//! Run the PostgreSQL server off a temporary data directory. //! //! See the [supported command-line arguments](Args). use std::{ collections::HashMap, ffi::{OsStr, OsString}, os::unix::fs::PermissionsExt, path::PathBuf, process::{ExitCode, ExitStatus, Stdio}, time::Duration, }; use anyhow::{Context, bail}; use clap::{Parser, ValueHint, builder::NonEmptyStringValueParser}; use tokio::{process::Command, signal::unix::SignalKind, time::Instant}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, trace, warn}; use url::Url; // For clap. macro_rules! version { () => { crate::version() }; } const EXTRA_HELP: &str = color_print::cstr!( r#"Exit status: Exits 0 on success, and >0 if an error occurs. Specifically, exits 2 on usage error, and 130 when interrupted. Propagates the exit status of the given command, if any. "# ); /// Command-line arguments. #[derive(Debug, Clone, Parser)] #[command(name = "temp-postgres", version, long_version = version!(), about, author, after_long_help = EXTRA_HELP, next_display_order = None)] #[deny(clippy::missing_docs_in_private_items)] struct Args { /// Database name. /// /// Defaults to the PostgreSQL user name. #[arg(short = 'd', long, env = "PGDATABASE", value_parser = NonEmptyStringValueParser::new())] dbname: Option, /// PostgreSQL user name. /// /// If not set, PostgreSQL uses the current operating system user. #[arg(short = 'u', long, env = "PGUSER", value_parser = NonEmptyStringValueParser::new())] username: Option, /// Create a symbolic link to the temporary directory. /// /// The given path must not exist. /// /// The symbolic link will be created when the PostgreSQL server is ready. /// It will be removed when the program ends if it (still) points to the temporary directory. /// /// Example: Static client configuration. /// /// ```sh /// temp-postgres --symlink db /// psql --host "$(realpath db)" /// ``` // Clap ensures that arguments passed on the command-line are non-empty. // Add a value parser when binding to an environment variable. #[arg(long, value_name = "PATH", value_hint = ValueHint::FilePath)] symlink: Option, /// Log level. #[arg(long, value_enum, default_value_t = LogLevel::Info)] log_level: LogLevel, /// Postgres startup timeout. #[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")] startup_timeout: Duration, /// Graceful shutdown timeout. /// /// `temp-postgres` performs a graceful shutdown upon SIGINT. /// The signal is propagated to child processes, namely the `postgres` command, and the wrapped command, if any. /// Child processes are killed (with SIGKILL) if they don't shut down in time. #[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")] shutdown_timeout: Duration, /// Command to execute once the PostgreSQL server is ready. /// /// The following environment variables will be passed to the command: /// /// - PGHOST, the absolute path to the directory in which the UNIX domain socket file is stored /// - PGDATABASE, the database name /// - PGUSER, the PostgreSQL user name /// - DATABASE_URL, a connection URI /// /// See also the PostgreSQL documentation on environment variables and connection URIs: /// /// - /// - /// /// Example: Wrap the `psql` command to connect to the temporary database once the server is ready. /// /// ```sh /// temp-postgres -- psql /// ``` #[arg(last = true, value_hint = ValueHint::CommandWithArguments)] command: Vec, } /// Log level. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)] enum LogLevel { Off, Error, Warn, Info, Debug, Trace, } impl From for tracing_subscriber::filter::LevelFilter { fn from(value: LogLevel) -> Self { match value { LogLevel::Off => tracing_subscriber::filter::LevelFilter::OFF, LogLevel::Error => tracing_subscriber::filter::LevelFilter::ERROR, LogLevel::Warn => tracing_subscriber::filter::LevelFilter::WARN, LogLevel::Info => tracing_subscriber::filter::LevelFilter::INFO, LogLevel::Debug => tracing_subscriber::filter::LevelFilter::DEBUG, LogLevel::Trace => tracing_subscriber::filter::LevelFilter::TRACE, } } } #[tokio::main] async fn main() -> ExitCode { let args = Args::parse(); // May exit 0 or 2. Respects `NO_COLOR`. let result = async { let mut subscriber = tracing_subscriber::fmt() .with_max_level(args.log_level.to_owned()) .with_writer(std::io::stderr); // Enable user to capture standard output of wrapped command. if std::env::var_os("NO_COLOR").is_some() { subscriber = subscriber.with_ansi(false); } subscriber.init(); debug!(version = version(), "starting"); debug!(arguments = ?args, "parsed command-line arguments"); run(args).await }; match result.await { Ok(_) => ExitCode::SUCCESS, Err(err) => { // Relying on tracing to print error. //eprintln!("Error: {err:?}"); if err.is::() { ExitCode::from(130) } else if let Some(err) = err.downcast_ref::() { match err.0.code().and_then(|code| u8::try_from(code).ok()) { Some(code) => ExitCode::from(code), None => ExitCode::FAILURE, } } else { ExitCode::FAILURE } } } } #[tracing::instrument(level = "info", skip_all, err(Debug))] async fn run(args: Args) -> Result<(), anyhow::Error> { let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()) .context("Failed to register SIGINT handler")?; debug!("registered SIGINT handler"); let tmp = tempfile::Builder::new() .prefix("temp-postgres-") .permissions(std::fs::Permissions::from_mode(0o700)) .tempdir() .context("Failed to create temporary directory")?; debug!(directory = %&tmp.path().display(), "created temporary directory"); let token = CancellationToken::new(); let mut server = tokio::spawn(serve(args.clone(), tmp.path().to_path_buf(), token.clone())); debug!("spawned server task"); let result = tokio::select! { _ = sigint.recv() => { info!("received SIGINT, shutting down ..."); token.cancel(); tokio::select! { _ = tokio::time::sleep(args.shutdown_timeout) => { warn!("graceful shutdown timed out, aborting ..."); server.abort(); server.await? } r = &mut server => { debug!("server task finished"); r? } } } r = &mut server => { debug!("server task finished"); r? }, }; if let Some(symlink) = &args.symlink { if let Ok(target) = tokio::fs::canonicalize(&symlink).await && let Ok(tmp) = tokio::fs::canonicalize(&tmp.path()).await && target == tmp { match tokio::fs::remove_file(symlink).await { Ok(()) => debug!(?symlink, "removed symlink"), Err(err) => warn!(?symlink, error = ?err, "failed to remove symlink"), } } else { warn!(?symlink, directory = %&tmp.path().display(), "keeping symlink because it doesn't point to temporary directory"); } } let tmp_dir = tmp.path().to_path_buf(); match tmp.close() { Ok(()) => debug!(directory = %tmp_dir.display(), "removed temporary directory"), Err(err) => { warn!(directory = %tmp_dir.display(), error = ?err, "failed to remove temporary directory") } } result } #[tracing::instrument(level = "info", skip_all, err)] async fn serve( args: Args, tmp_dir: PathBuf, token: CancellationToken, ) -> Result<(), anyhow::Error> { // WORKAROUND: PostgreSQL expects a non-empty dbname when username is set. let args = Args { dbname: args.dbname.clone().or(args.username.clone()), ..args }; // TODO: Check token.is_cancelled() before every command? let mut optional_initdb_args = Vec::<&OsStr>::new(); if let Some(ref username) = args.username { optional_initdb_args.push("--username".as_ref()); optional_initdb_args.push(username.as_ref()); } let status = Command::new("initdb") .arg("--pgdata") .arg(&tmp_dir) .args(optional_initdb_args) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .kill_on_drop(true) // In case this task gets aborted .status() .await .context("Failed to spawn or wait for initdb command")?; if !status.success() { bail!("initdb command exited error: {status}"); } debug!("initialized database"); let mut postgres = Command::new("postgres") .arg("-h") .arg("") // Don't bind to a TCP/IP address .arg("-k") .arg(&tmp_dir) .arg("-D") .arg(&tmp_dir) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .kill_on_drop(true) .spawn() .context("Failed to spawn postgres command")?; debug!("spawned postgres command"); let now = Instant::now(); tokio::time::sleep(Duration::from_millis(50)).await; for i in 1.. { match is_ready(tmp_dir.as_ref(), args.username.as_deref()).await { Ok(status) => { if status.success() { debug!("PostgreSQL server is ready"); break; } else if status.code() == Some(1) || status.code() == Some(2) { if now.elapsed() < args.startup_timeout { trace!("retrying pg_isready command ..."); } else { bail!("pg_isready timed out after {i} attempts"); } } else { bail!("pg_isready command exited error: {status}"); } } Err(err) => return Err(err), } tokio::time::sleep(Duration::from_millis(200)).await; } let mut optional_createdb_args = Vec::<&OsStr>::new(); if let Some(ref username) = args.username { optional_createdb_args.push("--username".as_ref()); optional_createdb_args.push(username.as_ref()); } if let Some(ref dbname) = args.dbname { // Positional argument optional_createdb_args.push(dbname.as_ref()); } let status = Command::new("createdb") .arg("--host") .arg(&tmp_dir) .arg("--no-password") .args(optional_createdb_args) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .kill_on_drop(true) .status() .await .context("Failed to spawn or wait for createdb command")?; if !status.success() { bail!("createdb command exited error: {status}"); } debug!("Created PostgreSQL database"); let mut url = Url::parse("postgresql://")?; if let Some(ref dbname) = args.dbname { url.set_path(dbname); } url.query_pairs_mut() .append_pair("host", &tmp_dir.display().to_string()); if let Some(ref username) = args.username { url.query_pairs_mut().append_pair("user", username); } info!( PGHOST = %&tmp_dir.display(), PGDATABASE = &args.dbname, PGUSER = &args.username, DATABASE_URL = %&url, "You can connect using the following parameters" ); if let Some(symlink) = &args.symlink { tokio::fs::symlink(&tmp_dir, symlink) .await .context("Failed to create symlink")?; debug!(?symlink, "created symlink"); } let mut wrapped_command = match args.command.split_first() { None => None, Some((head, tail)) => { let mut optional_env = HashMap::new(); if let Some(ref username) = args.username { optional_env.insert("PGUSER", username); } if let Some(ref dbname) = args.dbname { optional_env.insert("PGDATABASE", dbname); } let command = Command::new(head) .args(tail) .env("PGHOST", &tmp_dir) .env("DATABASE_URL", url.as_str()) // for SQLx with SQLX_OFFLINE=true .envs(optional_env) .kill_on_drop(true) .spawn() .context("Failed to spawn wrapped command")?; debug!("spawned wrapped command"); Some(command) } }; tokio::select! { _ = token.cancelled() => { debug!("token cancelled"); debug!("shutting down child processes ..."); // Propagating SIGINT to child processes in case it was send to temp-postgres only. // Halving shutdown timeout to avoid being aborted by parent task. let shutdown_timeout = args.shutdown_timeout.checked_div(2).unwrap(); let (wrapped_command_result, postgres_result) = tokio::join!( async { if let Some(ref mut c) = wrapped_command { interrupt_or_kill(c, shutdown_timeout).await } else { Ok(None) } }, interrupt_or_kill(&mut postgres, shutdown_timeout), ); if wrapped_command.is_some() { match wrapped_command_result { // Exit code is probably 130 Ok(status) => debug!(?status, "successfully shut down wrapped command"), Err(err) => warn!(error = ?err, "failed to shut down wrapped command"), } } match postgres_result { // Exit code is probably 130 Ok(status) => debug!(?status, "successfully shut down postgres command"), Err(err) => warn!(error = ?err, "failed to shut down postgres command"), } Err(InterruptError{})? }, r = postgres.wait() => { debug!("postgres command finished"); debug!("shutting down wrapped command ..."); if let Some(ref mut c) = wrapped_command { match interrupt_or_kill(c, args.shutdown_timeout).await { // Exit code is probably 130 Ok(status) => debug!(?status, "successfully shut down wrapped command"), Err(err) => warn!(error = ?err, "failed to shut down wrapped command"), } } let status = r.context("failed to wait for postgres command")?; if !status.success() { bail!("postgres command exited error: {status}"); } Ok(()) }, r = conditional_wait(&mut wrapped_command) => { debug!("wrapped command finished"); debug!("shutting down postgres command ..."); match interrupt_or_kill(&mut postgres, args.shutdown_timeout).await { // Exit code is probably 130 Ok(status) => debug!(?status, "successfully shut down postgres command"), Err(err) => warn!(error = ?err, "failed to shut down postgres command"), } let status = r.context("failed to wait for wrapped command")?; if !status.success() { return Err(WrappedCommandError(status))?; } Ok(()) }, } } fn interrupt(child: &tokio::process::Child) -> Result<(), anyhow::Error> { // std::os::unix::process::ChildExt::send_signal is nightly-only experimental if let Some(id) = child.id() { nix::sys::signal::kill( nix::unistd::Pid::from_raw(id as i32), nix::sys::signal::Signal::SIGINT, ) .context("Failed to send SIGINT to process {id}")?; } Ok(()) } async fn interrupt_or_kill( child: &mut tokio::process::Child, timeout: Duration, ) -> Result, anyhow::Error> { interrupt(child)?; tokio::select! { r = child.wait() => { r.map(Some).context("Failed to wait for child process") }, _ = tokio::time::sleep(timeout) => { child.kill().await.context("Failed to kill child process")?; Ok(None) } } } async fn is_ready(tmp_dir: &OsStr, username: Option<&str>) -> Result { let mut optional_isready_args = Vec::<&OsStr>::new(); if let Some(ref username) = username { optional_isready_args.push("--username".as_ref()); optional_isready_args.push(username.as_ref()); } Command::new("pg_isready") .arg("--host") .arg(tmp_dir) .arg("--dbname") .arg("dummy") // No database created yet .arg("--timeout") .arg("3") // Default .args(optional_isready_args) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .kill_on_drop(true) .status() .await .context("Failed to spawn or wait for pg_isready command") } async fn conditional_wait( child: &mut Option, ) -> Result { match child { None => std::future::pending().await, Some(child) => child.wait().await, } } #[derive(Debug)] struct WrappedCommandError(ExitStatus); impl std::fmt::Display for WrappedCommandError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Wrapped command exited error: {}", self.0) } } impl std::error::Error for WrappedCommandError {} #[derive(Debug)] struct InterruptError {} impl std::fmt::Display for InterruptError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Interrupted") } } impl std::error::Error for InterruptError {} fn version() -> String { // https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates let package_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); let revision = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown"); let dirty = option_env!("VERGEN_GIT_DIRTY").and_then(|s| s.trim().parse::().ok()); let mut version = format!("{package_version}-{revision}"); if let Some(true) = dirty { version.push_str("-dirty"); } version } #[test] fn cli() { use clap::CommandFactory; Args::command().debug_assert(); }