diff options
Diffstat (limited to 'src/main.rs')
| -rw-r--r-- | src/main.rs | 540 |
1 files changed, 540 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..9a139ad --- /dev/null +++ b/src/main.rs @@ -0,0 +1,540 @@ +//! Run the PostgreSQL server off a temporary data directory. +//! +//! See the [supported command-line arguments](Args). + +use std::{ + collections::HashMap, + ffi::{OsStr, OsString}, + os::unix::fs::PermissionsExt, + path::PathBuf, + process::{ExitCode, ExitStatus, Stdio}, + time::Duration, +}; + +use anyhow::{Context, bail}; +use clap::{Parser, ValueHint, builder::NonEmptyStringValueParser}; +use tokio::{process::Command, signal::unix::SignalKind, time::Instant}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info, trace, warn}; +use url::Url; + +// For clap. +macro_rules! version { + () => { + crate::version() + }; +} + +const EXTRA_HELP: &str = color_print::cstr!( + r#"<strong><underline>Exit status:</underline></strong> + +Exits 0 on success, and >0 if an error occurs. Specifically, exits 2 on usage error, and 130 when interrupted. Propagates the exit status of the given command, if any. +"# +); + +/// Command-line arguments. +#[derive(Debug, Clone, Parser)] +#[command(name = "temp-postgres", version, long_version = version!(), about, author, after_long_help = EXTRA_HELP, next_display_order = None)] +#[deny(clippy::missing_docs_in_private_items)] +struct Args { + /// Database name. + /// + /// Defaults to the PostgreSQL user name. + #[arg(short = 'd', long, env = "PGDATABASE", value_parser = NonEmptyStringValueParser::new())] + dbname: Option<String>, + + /// PostgreSQL user name. + /// + /// If not set, PostgreSQL uses the current operating system user. + #[arg(short = 'u', long, env = "PGUSER", value_parser = NonEmptyStringValueParser::new())] + username: Option<String>, + + /// Create a symbolic link to the temporary directory. + /// + /// The given path must not exist. + /// + /// The symbolic link will be created when the PostgreSQL server is ready. + /// It will be removed when the program ends if it (still) points to the temporary directory. + /// + /// Example: Static client configuration. + /// + /// ```sh + /// temp-postgres --symlink db + /// psql --host "$(realpath db)" + /// ``` + // Clap ensures that arguments passed on the command-line are non-empty. + // Add a value parser when binding to an environment variable. + #[arg(long, value_name = "PATH", value_hint = ValueHint::FilePath)] + symlink: Option<OsString>, + + /// Log level. + #[arg(long, value_enum, default_value_t = LogLevel::Info)] + log_level: LogLevel, + + /// Postgres startup timeout. + #[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")] + startup_timeout: Duration, + + /// Graceful shutdown timeout. + /// + /// `temp-postgres` performs a graceful shutdown upon SIGINT. + /// The signal is propagated to child processes, namely the `postgres` command, and the wrapped command, if any. + /// Child processes are killed (with SIGKILL) if they don't shut down in time. + #[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")] + shutdown_timeout: Duration, + + /// Command to execute once the PostgreSQL server is ready. + /// + /// The following environment variables will be passed to the command: + /// + /// - PGHOST, the absolute path to the directory in which the UNIX domain socket file is stored + /// - PGDATABASE, the database name + /// - PGUSER, the PostgreSQL user name + /// - DATABASE_URL, a connection URI + /// + /// See also the PostgreSQL documentation on environment variables and connection URIs: + /// + /// - <https://www.postgresql.org/docs/current/libpq-envars.html> + /// - <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS> + /// + /// Example: Wrap the `psql` command to connect to the temporary database once the server is ready. + /// + /// ```sh + /// temp-postgres -- psql + /// ``` + #[arg(last = true, value_hint = ValueHint::CommandWithArguments)] + command: Vec<OsString>, +} + +/// Log level. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)] +enum LogLevel { + Off, + Error, + Warn, + Info, + Debug, + Trace, +} + +impl From<LogLevel> for tracing_subscriber::filter::LevelFilter { + fn from(value: LogLevel) -> Self { + match value { + LogLevel::Off => tracing_subscriber::filter::LevelFilter::OFF, + LogLevel::Error => tracing_subscriber::filter::LevelFilter::ERROR, + LogLevel::Warn => tracing_subscriber::filter::LevelFilter::WARN, + LogLevel::Info => tracing_subscriber::filter::LevelFilter::INFO, + LogLevel::Debug => tracing_subscriber::filter::LevelFilter::DEBUG, + LogLevel::Trace => tracing_subscriber::filter::LevelFilter::TRACE, + } + } +} + +#[tokio::main] +async fn main() -> ExitCode { + let args = Args::parse(); // May exit 0 or 2. Respects `NO_COLOR`. + let result = async { + let mut subscriber = tracing_subscriber::fmt() + .with_max_level(args.log_level.to_owned()) + .with_writer(std::io::stderr); // Enable user to capture standard output of wrapped command. + if std::env::var_os("NO_COLOR").is_some() { + subscriber = subscriber.with_ansi(false); + } + subscriber.init(); + debug!(version = version(), "starting"); + debug!(arguments = ?args, "parsed command-line arguments"); + + run(args).await + }; + match result.await { + Ok(_) => ExitCode::SUCCESS, + Err(err) => { + // Relying on tracing to print error. + //eprintln!("Error: {err:?}"); + if err.is::<InterruptError>() { + ExitCode::from(130) + } else if let Some(err) = err.downcast_ref::<WrappedCommandError>() { + match err.0.code().and_then(|code| u8::try_from(code).ok()) { + Some(code) => ExitCode::from(code), + None => ExitCode::FAILURE, + } + } else { + ExitCode::FAILURE + } + } + } +} + +#[tracing::instrument(level = "info", skip_all, err(Debug))] +async fn run(args: Args) -> Result<(), anyhow::Error> { + let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()) + .context("Failed to register SIGINT handler")?; + debug!("registered SIGINT handler"); + + let tmp = tempfile::Builder::new() + .prefix("temp-postgres-") + .permissions(std::fs::Permissions::from_mode(0o700)) + .tempdir() + .context("Failed to create temporary directory")?; + debug!(directory = %&tmp.path().display(), "created temporary directory"); + + let token = CancellationToken::new(); + let mut server = tokio::spawn(serve(args.clone(), tmp.path().to_path_buf(), token.clone())); + debug!("spawned server task"); + + let result = tokio::select! { + _ = sigint.recv() => { + info!("received SIGINT, shutting down ..."); + token.cancel(); + tokio::select! { + _ = tokio::time::sleep(args.shutdown_timeout) => { + warn!("graceful shutdown timed out, aborting ..."); + server.abort(); + server.await? + } + r = &mut server => { + debug!("server task finished"); + r? + } + } + } + r = &mut server => { + debug!("server task finished"); + r? + }, + }; + + if let Some(symlink) = &args.symlink { + if let Ok(target) = tokio::fs::canonicalize(&symlink).await + && let Ok(tmp) = tokio::fs::canonicalize(&tmp.path()).await + && target == tmp + { + match tokio::fs::remove_file(symlink).await { + Ok(()) => debug!(?symlink, "removed symlink"), + Err(err) => warn!(?symlink, error = ?err, "failed to remove symlink"), + } + } else { + warn!(?symlink, directory = %&tmp.path().display(), "keeping symlink because it doesn't point to temporary directory"); + } + } + + let tmp_dir = tmp.path().to_path_buf(); + match tmp.close() { + Ok(()) => debug!(directory = %tmp_dir.display(), "removed temporary directory"), + Err(err) => { + warn!(directory = %tmp_dir.display(), error = ?err, "failed to remove temporary directory") + } + } + result +} + +#[tracing::instrument(level = "info", skip_all, err)] +async fn serve( + args: Args, + tmp_dir: PathBuf, + token: CancellationToken, +) -> Result<(), anyhow::Error> { + // WORKAROUND: PostgreSQL expects a non-empty dbname when username is set. + let args = Args { + dbname: args.dbname.clone().or(args.username.clone()), + ..args + }; + + // TODO: Check token.is_cancelled() before every command? + let mut optional_initdb_args = Vec::<&OsStr>::new(); + if let Some(ref username) = args.username { + optional_initdb_args.push("--username".as_ref()); + optional_initdb_args.push(username.as_ref()); + } + let status = Command::new("initdb") + .arg("--pgdata") + .arg(&tmp_dir) + .args(optional_initdb_args) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .kill_on_drop(true) // In case this task gets aborted + .status() + .await + .context("Failed to spawn or wait for initdb command")?; + if !status.success() { + bail!("initdb command exited error: {status}"); + } + debug!("initialized database"); + + let mut postgres = Command::new("postgres") + .arg("-h") + .arg("") // Don't bind to a TCP/IP address + .arg("-k") + .arg(&tmp_dir) + .arg("-D") + .arg(&tmp_dir) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .kill_on_drop(true) + .spawn() + .context("Failed to spawn postgres command")?; + debug!("spawned postgres command"); + + let now = Instant::now(); + tokio::time::sleep(Duration::from_millis(50)).await; + for i in 1.. { + match is_ready(tmp_dir.as_ref(), args.username.as_deref()).await { + Ok(status) => { + if status.success() { + debug!("PostgreSQL server is ready"); + break; + } else if status.code() == Some(1) || status.code() == Some(2) { + if now.elapsed() < args.startup_timeout { + trace!("retrying pg_isready command ..."); + } else { + bail!("pg_isready timed out after {i} attempts"); + } + } else { + bail!("pg_isready command exited error: {status}"); + } + } + Err(err) => return Err(err), + } + tokio::time::sleep(Duration::from_millis(200)).await; + } + + let mut optional_createdb_args = Vec::<&OsStr>::new(); + if let Some(ref username) = args.username { + optional_createdb_args.push("--username".as_ref()); + optional_createdb_args.push(username.as_ref()); + } + if let Some(ref dbname) = args.dbname { + // Positional argument + optional_createdb_args.push(dbname.as_ref()); + } + let status = Command::new("createdb") + .arg("--host") + .arg(&tmp_dir) + .arg("--no-password") + .args(optional_createdb_args) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .kill_on_drop(true) + .status() + .await + .context("Failed to spawn or wait for createdb command")?; + if !status.success() { + bail!("createdb command exited error: {status}"); + } + debug!("Created PostgreSQL database"); + + let mut url = Url::parse("postgresql://")?; + if let Some(ref dbname) = args.dbname { + url.set_path(dbname); + } + url.query_pairs_mut() + .append_pair("host", &tmp_dir.display().to_string()); + if let Some(ref username) = args.username { + url.query_pairs_mut().append_pair("user", username); + } + + info!( + PGHOST = %&tmp_dir.display(), + PGDATABASE = &args.dbname, + PGUSER = &args.username, + DATABASE_URL = %&url, + "You can connect using the following parameters" + ); + + if let Some(symlink) = &args.symlink { + tokio::fs::symlink(&tmp_dir, symlink) + .await + .context("Failed to create symlink")?; + debug!(?symlink, "created symlink"); + } + + let mut wrapped_command = match args.command.split_first() { + None => None, + Some((head, tail)) => { + let mut optional_env = HashMap::new(); + if let Some(ref username) = args.username { + optional_env.insert("PGUSER", username); + } + if let Some(ref dbname) = args.dbname { + optional_env.insert("PGDATABASE", dbname); + } + let command = Command::new(head) + .args(tail) + .env("PGHOST", &tmp_dir) + .env("DATABASE_URL", url.as_str()) // for SQLx with SQLX_OFFLINE=true + .envs(optional_env) + .kill_on_drop(true) + .spawn() + .context("Failed to spawn wrapped command")?; + debug!("spawned wrapped command"); + Some(command) + } + }; + + tokio::select! { + _ = token.cancelled() => { + debug!("token cancelled"); + debug!("shutting down child processes ..."); + // Propagating SIGINT to child processes in case it was send to temp-postgres only. + // Halving shutdown timeout to avoid being aborted by parent task. + let shutdown_timeout = args.shutdown_timeout.checked_div(2).unwrap(); + let (wrapped_command_result, postgres_result) = tokio::join!( + async { + if let Some(ref mut c) = wrapped_command { + interrupt_or_kill(c, shutdown_timeout).await + } else { + Ok(None) + } + }, + interrupt_or_kill(&mut postgres, shutdown_timeout), + ); + if wrapped_command.is_some() { + match wrapped_command_result { + // Exit code is probably 130 + Ok(status) => debug!(?status, "successfully shut down wrapped command"), + Err(err) => warn!(error = ?err, "failed to shut down wrapped command"), + } + } + match postgres_result { + // Exit code is probably 130 + Ok(status) => debug!(?status, "successfully shut down postgres command"), + Err(err) => warn!(error = ?err, "failed to shut down postgres command"), + } + Err(InterruptError{})? + }, + r = postgres.wait() => { + debug!("postgres command finished"); + debug!("shutting down wrapped command ..."); + if let Some(ref mut c) = wrapped_command { + match interrupt_or_kill(c, args.shutdown_timeout).await { + // Exit code is probably 130 + Ok(status) => debug!(?status, "successfully shut down wrapped command"), + Err(err) => warn!(error = ?err, "failed to shut down wrapped command"), + } + } + let status = r.context("failed to wait for postgres command")?; + if !status.success() { + bail!("postgres command exited error: {status}"); + } + Ok(()) + }, + r = conditional_wait(&mut wrapped_command) => { + debug!("wrapped command finished"); + debug!("shutting down postgres command ..."); + match interrupt_or_kill(&mut postgres, args.shutdown_timeout).await { + // Exit code is probably 130 + Ok(status) => debug!(?status, "successfully shut down postgres command"), + Err(err) => warn!(error = ?err, "failed to shut down postgres command"), + } + let status = r.context("failed to wait for wrapped command")?; + if !status.success() { + return Err(WrappedCommandError(status))?; + } + Ok(()) + }, + } +} + +fn interrupt(child: &tokio::process::Child) -> Result<(), anyhow::Error> { + // std::os::unix::process::ChildExt::send_signal is nightly-only experimental + if let Some(id) = child.id() { + nix::sys::signal::kill( + nix::unistd::Pid::from_raw(id as i32), + nix::sys::signal::Signal::SIGINT, + ) + .context("Failed to send SIGINT to process {id}")?; + } + Ok(()) +} + +async fn interrupt_or_kill( + child: &mut tokio::process::Child, + timeout: Duration, +) -> Result<Option<ExitStatus>, anyhow::Error> { + interrupt(child)?; + tokio::select! { + r = child.wait() => { + r.map(Some).context("Failed to wait for child process") + }, + _ = tokio::time::sleep(timeout) => { + child.kill().await.context("Failed to kill child process")?; + Ok(None) + } + } +} + +async fn is_ready(tmp_dir: &OsStr, username: Option<&str>) -> Result<ExitStatus, anyhow::Error> { + let mut optional_isready_args = Vec::<&OsStr>::new(); + if let Some(ref username) = username { + optional_isready_args.push("--username".as_ref()); + optional_isready_args.push(username.as_ref()); + } + Command::new("pg_isready") + .arg("--host") + .arg(tmp_dir) + .arg("--dbname") + .arg("dummy") // No database created yet + .arg("--timeout") + .arg("3") // Default + .args(optional_isready_args) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .kill_on_drop(true) + .status() + .await + .context("Failed to spawn or wait for pg_isready command") +} + +async fn conditional_wait( + child: &mut Option<tokio::process::Child>, +) -> Result<std::process::ExitStatus, std::io::Error> { + match child { + None => std::future::pending().await, + Some(child) => child.wait().await, + } +} + +#[derive(Debug)] +struct WrappedCommandError(ExitStatus); + +impl std::fmt::Display for WrappedCommandError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Wrapped command exited error: {}", self.0) + } +} + +impl std::error::Error for WrappedCommandError {} + +#[derive(Debug)] +struct InterruptError {} + +impl std::fmt::Display for InterruptError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Interrupted") + } +} + +impl std::error::Error for InterruptError {} + +fn version() -> String { + // https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates + let package_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); + let revision = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown"); + let dirty = option_env!("VERGEN_GIT_DIRTY").and_then(|s| s.trim().parse::<bool>().ok()); + + let mut version = format!("{package_version}-{revision}"); + if let Some(true) = dirty { + version.push_str("-dirty"); + } + version +} + +#[test] +fn cli() { + use clap::CommandFactory; + Args::command().debug_assert(); +} |