summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Kreutz <mail@skreutz.com>2026-04-30 10:10:46 +0200
committerStefan Kreutz <mail@skreutz.com>2026-04-30 10:10:46 +0200
commit46a3d2ba70decd1931e13c190bfa49217e57718d (patch)
tree49bc767c52d0cb4cf8443782cae1cc641ef59343 /src
parent47421e41def84ab92a52906f01266b1044fbfe29 (diff)
downloadtemp-postgres-46a3d2ba70decd1931e13c190bfa49217e57718d.tar.gz
Rewrite in async Rust
Diffstat (limited to 'src')
-rw-r--r--src/main.rs540
1 files changed, 540 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..9a139ad
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,540 @@
+//! Run the PostgreSQL server off a temporary data directory.
+//!
+//! See the [supported command-line arguments](Args).
+
+use std::{
+ collections::HashMap,
+ ffi::{OsStr, OsString},
+ os::unix::fs::PermissionsExt,
+ path::PathBuf,
+ process::{ExitCode, ExitStatus, Stdio},
+ time::Duration,
+};
+
+use anyhow::{Context, bail};
+use clap::{Parser, ValueHint, builder::NonEmptyStringValueParser};
+use tokio::{process::Command, signal::unix::SignalKind, time::Instant};
+use tokio_util::sync::CancellationToken;
+use tracing::{debug, info, trace, warn};
+use url::Url;
+
+// For clap.
+macro_rules! version {
+ () => {
+ crate::version()
+ };
+}
+
+const EXTRA_HELP: &str = color_print::cstr!(
+ r#"<strong><underline>Exit status:</underline></strong>
+
+Exits 0 on success, and >0 if an error occurs. Specifically, exits 2 on usage error, and 130 when interrupted. Propagates the exit status of the given command, if any.
+"#
+);
+
+/// Command-line arguments.
+#[derive(Debug, Clone, Parser)]
+#[command(name = "temp-postgres", version, long_version = version!(), about, author, after_long_help = EXTRA_HELP, next_display_order = None)]
+#[deny(clippy::missing_docs_in_private_items)]
+struct Args {
+ /// Database name.
+ ///
+ /// Defaults to the PostgreSQL user name.
+ #[arg(short = 'd', long, env = "PGDATABASE", value_parser = NonEmptyStringValueParser::new())]
+ dbname: Option<String>,
+
+ /// PostgreSQL user name.
+ ///
+ /// If not set, PostgreSQL uses the current operating system user.
+ #[arg(short = 'u', long, env = "PGUSER", value_parser = NonEmptyStringValueParser::new())]
+ username: Option<String>,
+
+ /// Create a symbolic link to the temporary directory.
+ ///
+ /// The given path must not exist.
+ ///
+ /// The symbolic link will be created when the PostgreSQL server is ready.
+ /// It will be removed when the program ends if it (still) points to the temporary directory.
+ ///
+ /// Example: Static client configuration.
+ ///
+ /// ```sh
+ /// temp-postgres --symlink db
+ /// psql --host "$(realpath db)"
+ /// ```
+ // Clap ensures that arguments passed on the command-line are non-empty.
+ // Add a value parser when binding to an environment variable.
+ #[arg(long, value_name = "PATH", value_hint = ValueHint::FilePath)]
+ symlink: Option<OsString>,
+
+ /// Log level.
+ #[arg(long, value_enum, default_value_t = LogLevel::Info)]
+ log_level: LogLevel,
+
+ /// Postgres startup timeout.
+ #[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")]
+ startup_timeout: Duration,
+
+ /// Graceful shutdown timeout.
+ ///
+ /// `temp-postgres` performs a graceful shutdown upon SIGINT.
+ /// The signal is propagated to child processes, namely the `postgres` command, and the wrapped command, if any.
+ /// Child processes are killed (with SIGKILL) if they don't shut down in time.
+ #[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")]
+ shutdown_timeout: Duration,
+
+ /// Command to execute once the PostgreSQL server is ready.
+ ///
+ /// The following environment variables will be passed to the command:
+ ///
+ /// - PGHOST, the absolute path to the directory in which the UNIX domain socket file is stored
+ /// - PGDATABASE, the database name
+ /// - PGUSER, the PostgreSQL user name
+ /// - DATABASE_URL, a connection URI
+ ///
+ /// See also the PostgreSQL documentation on environment variables and connection URIs:
+ ///
+ /// - <https://www.postgresql.org/docs/current/libpq-envars.html>
+ /// - <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS>
+ ///
+ /// Example: Wrap the `psql` command to connect to the temporary database once the server is ready.
+ ///
+ /// ```sh
+ /// temp-postgres -- psql
+ /// ```
+ #[arg(last = true, value_hint = ValueHint::CommandWithArguments)]
+ command: Vec<OsString>,
+}
+
+/// Log level.
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
+enum LogLevel {
+ Off,
+ Error,
+ Warn,
+ Info,
+ Debug,
+ Trace,
+}
+
+impl From<LogLevel> for tracing_subscriber::filter::LevelFilter {
+ fn from(value: LogLevel) -> Self {
+ match value {
+ LogLevel::Off => tracing_subscriber::filter::LevelFilter::OFF,
+ LogLevel::Error => tracing_subscriber::filter::LevelFilter::ERROR,
+ LogLevel::Warn => tracing_subscriber::filter::LevelFilter::WARN,
+ LogLevel::Info => tracing_subscriber::filter::LevelFilter::INFO,
+ LogLevel::Debug => tracing_subscriber::filter::LevelFilter::DEBUG,
+ LogLevel::Trace => tracing_subscriber::filter::LevelFilter::TRACE,
+ }
+ }
+}
+
+#[tokio::main]
+async fn main() -> ExitCode {
+ let args = Args::parse(); // May exit 0 or 2. Respects `NO_COLOR`.
+ let result = async {
+ let mut subscriber = tracing_subscriber::fmt()
+ .with_max_level(args.log_level.to_owned())
+ .with_writer(std::io::stderr); // Enable user to capture standard output of wrapped command.
+ if std::env::var_os("NO_COLOR").is_some() {
+ subscriber = subscriber.with_ansi(false);
+ }
+ subscriber.init();
+ debug!(version = version(), "starting");
+ debug!(arguments = ?args, "parsed command-line arguments");
+
+ run(args).await
+ };
+ match result.await {
+ Ok(_) => ExitCode::SUCCESS,
+ Err(err) => {
+ // Relying on tracing to print error.
+ //eprintln!("Error: {err:?}");
+ if err.is::<InterruptError>() {
+ ExitCode::from(130)
+ } else if let Some(err) = err.downcast_ref::<WrappedCommandError>() {
+ match err.0.code().and_then(|code| u8::try_from(code).ok()) {
+ Some(code) => ExitCode::from(code),
+ None => ExitCode::FAILURE,
+ }
+ } else {
+ ExitCode::FAILURE
+ }
+ }
+ }
+}
+
+#[tracing::instrument(level = "info", skip_all, err(Debug))]
+async fn run(args: Args) -> Result<(), anyhow::Error> {
+ let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
+ .context("Failed to register SIGINT handler")?;
+ debug!("registered SIGINT handler");
+
+ let tmp = tempfile::Builder::new()
+ .prefix("temp-postgres-")
+ .permissions(std::fs::Permissions::from_mode(0o700))
+ .tempdir()
+ .context("Failed to create temporary directory")?;
+ debug!(directory = %&tmp.path().display(), "created temporary directory");
+
+ let token = CancellationToken::new();
+ let mut server = tokio::spawn(serve(args.clone(), tmp.path().to_path_buf(), token.clone()));
+ debug!("spawned server task");
+
+ let result = tokio::select! {
+ _ = sigint.recv() => {
+ info!("received SIGINT, shutting down ...");
+ token.cancel();
+ tokio::select! {
+ _ = tokio::time::sleep(args.shutdown_timeout) => {
+ warn!("graceful shutdown timed out, aborting ...");
+ server.abort();
+ server.await?
+ }
+ r = &mut server => {
+ debug!("server task finished");
+ r?
+ }
+ }
+ }
+ r = &mut server => {
+ debug!("server task finished");
+ r?
+ },
+ };
+
+ if let Some(symlink) = &args.symlink {
+ if let Ok(target) = tokio::fs::canonicalize(&symlink).await
+ && let Ok(tmp) = tokio::fs::canonicalize(&tmp.path()).await
+ && target == tmp
+ {
+ match tokio::fs::remove_file(symlink).await {
+ Ok(()) => debug!(?symlink, "removed symlink"),
+ Err(err) => warn!(?symlink, error = ?err, "failed to remove symlink"),
+ }
+ } else {
+ warn!(?symlink, directory = %&tmp.path().display(), "keeping symlink because it doesn't point to temporary directory");
+ }
+ }
+
+ let tmp_dir = tmp.path().to_path_buf();
+ match tmp.close() {
+ Ok(()) => debug!(directory = %tmp_dir.display(), "removed temporary directory"),
+ Err(err) => {
+ warn!(directory = %tmp_dir.display(), error = ?err, "failed to remove temporary directory")
+ }
+ }
+ result
+}
+
+#[tracing::instrument(level = "info", skip_all, err)]
+async fn serve(
+ args: Args,
+ tmp_dir: PathBuf,
+ token: CancellationToken,
+) -> Result<(), anyhow::Error> {
+ // WORKAROUND: PostgreSQL expects a non-empty dbname when username is set.
+ let args = Args {
+ dbname: args.dbname.clone().or(args.username.clone()),
+ ..args
+ };
+
+ // TODO: Check token.is_cancelled() before every command?
+ let mut optional_initdb_args = Vec::<&OsStr>::new();
+ if let Some(ref username) = args.username {
+ optional_initdb_args.push("--username".as_ref());
+ optional_initdb_args.push(username.as_ref());
+ }
+ let status = Command::new("initdb")
+ .arg("--pgdata")
+ .arg(&tmp_dir)
+ .args(optional_initdb_args)
+ .stdin(Stdio::null())
+ .stdout(Stdio::null())
+ .stderr(Stdio::null())
+ .kill_on_drop(true) // In case this task gets aborted
+ .status()
+ .await
+ .context("Failed to spawn or wait for initdb command")?;
+ if !status.success() {
+ bail!("initdb command exited error: {status}");
+ }
+ debug!("initialized database");
+
+ let mut postgres = Command::new("postgres")
+ .arg("-h")
+ .arg("") // Don't bind to a TCP/IP address
+ .arg("-k")
+ .arg(&tmp_dir)
+ .arg("-D")
+ .arg(&tmp_dir)
+ .stdin(Stdio::null())
+ .stdout(Stdio::null())
+ .stderr(Stdio::null())
+ .kill_on_drop(true)
+ .spawn()
+ .context("Failed to spawn postgres command")?;
+ debug!("spawned postgres command");
+
+ let now = Instant::now();
+ tokio::time::sleep(Duration::from_millis(50)).await;
+ for i in 1.. {
+ match is_ready(tmp_dir.as_ref(), args.username.as_deref()).await {
+ Ok(status) => {
+ if status.success() {
+ debug!("PostgreSQL server is ready");
+ break;
+ } else if status.code() == Some(1) || status.code() == Some(2) {
+ if now.elapsed() < args.startup_timeout {
+ trace!("retrying pg_isready command ...");
+ } else {
+ bail!("pg_isready timed out after {i} attempts");
+ }
+ } else {
+ bail!("pg_isready command exited error: {status}");
+ }
+ }
+ Err(err) => return Err(err),
+ }
+ tokio::time::sleep(Duration::from_millis(200)).await;
+ }
+
+ let mut optional_createdb_args = Vec::<&OsStr>::new();
+ if let Some(ref username) = args.username {
+ optional_createdb_args.push("--username".as_ref());
+ optional_createdb_args.push(username.as_ref());
+ }
+ if let Some(ref dbname) = args.dbname {
+ // Positional argument
+ optional_createdb_args.push(dbname.as_ref());
+ }
+ let status = Command::new("createdb")
+ .arg("--host")
+ .arg(&tmp_dir)
+ .arg("--no-password")
+ .args(optional_createdb_args)
+ .stdin(Stdio::null())
+ .stdout(Stdio::null())
+ .stderr(Stdio::null())
+ .kill_on_drop(true)
+ .status()
+ .await
+ .context("Failed to spawn or wait for createdb command")?;
+ if !status.success() {
+ bail!("createdb command exited error: {status}");
+ }
+ debug!("Created PostgreSQL database");
+
+ let mut url = Url::parse("postgresql://")?;
+ if let Some(ref dbname) = args.dbname {
+ url.set_path(dbname);
+ }
+ url.query_pairs_mut()
+ .append_pair("host", &tmp_dir.display().to_string());
+ if let Some(ref username) = args.username {
+ url.query_pairs_mut().append_pair("user", username);
+ }
+
+ info!(
+ PGHOST = %&tmp_dir.display(),
+ PGDATABASE = &args.dbname,
+ PGUSER = &args.username,
+ DATABASE_URL = %&url,
+ "You can connect using the following parameters"
+ );
+
+ if let Some(symlink) = &args.symlink {
+ tokio::fs::symlink(&tmp_dir, symlink)
+ .await
+ .context("Failed to create symlink")?;
+ debug!(?symlink, "created symlink");
+ }
+
+ let mut wrapped_command = match args.command.split_first() {
+ None => None,
+ Some((head, tail)) => {
+ let mut optional_env = HashMap::new();
+ if let Some(ref username) = args.username {
+ optional_env.insert("PGUSER", username);
+ }
+ if let Some(ref dbname) = args.dbname {
+ optional_env.insert("PGDATABASE", dbname);
+ }
+ let command = Command::new(head)
+ .args(tail)
+ .env("PGHOST", &tmp_dir)
+ .env("DATABASE_URL", url.as_str()) // for SQLx with SQLX_OFFLINE=true
+ .envs(optional_env)
+ .kill_on_drop(true)
+ .spawn()
+ .context("Failed to spawn wrapped command")?;
+ debug!("spawned wrapped command");
+ Some(command)
+ }
+ };
+
+ tokio::select! {
+ _ = token.cancelled() => {
+ debug!("token cancelled");
+ debug!("shutting down child processes ...");
+ // Propagating SIGINT to child processes in case it was send to temp-postgres only.
+ // Halving shutdown timeout to avoid being aborted by parent task.
+ let shutdown_timeout = args.shutdown_timeout.checked_div(2).unwrap();
+ let (wrapped_command_result, postgres_result) = tokio::join!(
+ async {
+ if let Some(ref mut c) = wrapped_command {
+ interrupt_or_kill(c, shutdown_timeout).await
+ } else {
+ Ok(None)
+ }
+ },
+ interrupt_or_kill(&mut postgres, shutdown_timeout),
+ );
+ if wrapped_command.is_some() {
+ match wrapped_command_result {
+ // Exit code is probably 130
+ Ok(status) => debug!(?status, "successfully shut down wrapped command"),
+ Err(err) => warn!(error = ?err, "failed to shut down wrapped command"),
+ }
+ }
+ match postgres_result {
+ // Exit code is probably 130
+ Ok(status) => debug!(?status, "successfully shut down postgres command"),
+ Err(err) => warn!(error = ?err, "failed to shut down postgres command"),
+ }
+ Err(InterruptError{})?
+ },
+ r = postgres.wait() => {
+ debug!("postgres command finished");
+ debug!("shutting down wrapped command ...");
+ if let Some(ref mut c) = wrapped_command {
+ match interrupt_or_kill(c, args.shutdown_timeout).await {
+ // Exit code is probably 130
+ Ok(status) => debug!(?status, "successfully shut down wrapped command"),
+ Err(err) => warn!(error = ?err, "failed to shut down wrapped command"),
+ }
+ }
+ let status = r.context("failed to wait for postgres command")?;
+ if !status.success() {
+ bail!("postgres command exited error: {status}");
+ }
+ Ok(())
+ },
+ r = conditional_wait(&mut wrapped_command) => {
+ debug!("wrapped command finished");
+ debug!("shutting down postgres command ...");
+ match interrupt_or_kill(&mut postgres, args.shutdown_timeout).await {
+ // Exit code is probably 130
+ Ok(status) => debug!(?status, "successfully shut down postgres command"),
+ Err(err) => warn!(error = ?err, "failed to shut down postgres command"),
+ }
+ let status = r.context("failed to wait for wrapped command")?;
+ if !status.success() {
+ return Err(WrappedCommandError(status))?;
+ }
+ Ok(())
+ },
+ }
+}
+
+fn interrupt(child: &tokio::process::Child) -> Result<(), anyhow::Error> {
+ // std::os::unix::process::ChildExt::send_signal is nightly-only experimental
+ if let Some(id) = child.id() {
+ nix::sys::signal::kill(
+ nix::unistd::Pid::from_raw(id as i32),
+ nix::sys::signal::Signal::SIGINT,
+ )
+ .context("Failed to send SIGINT to process {id}")?;
+ }
+ Ok(())
+}
+
+async fn interrupt_or_kill(
+ child: &mut tokio::process::Child,
+ timeout: Duration,
+) -> Result<Option<ExitStatus>, anyhow::Error> {
+ interrupt(child)?;
+ tokio::select! {
+ r = child.wait() => {
+ r.map(Some).context("Failed to wait for child process")
+ },
+ _ = tokio::time::sleep(timeout) => {
+ child.kill().await.context("Failed to kill child process")?;
+ Ok(None)
+ }
+ }
+}
+
+async fn is_ready(tmp_dir: &OsStr, username: Option<&str>) -> Result<ExitStatus, anyhow::Error> {
+ let mut optional_isready_args = Vec::<&OsStr>::new();
+ if let Some(ref username) = username {
+ optional_isready_args.push("--username".as_ref());
+ optional_isready_args.push(username.as_ref());
+ }
+ Command::new("pg_isready")
+ .arg("--host")
+ .arg(tmp_dir)
+ .arg("--dbname")
+ .arg("dummy") // No database created yet
+ .arg("--timeout")
+ .arg("3") // Default
+ .args(optional_isready_args)
+ .stdin(Stdio::null())
+ .stdout(Stdio::null())
+ .stderr(Stdio::null())
+ .kill_on_drop(true)
+ .status()
+ .await
+ .context("Failed to spawn or wait for pg_isready command")
+}
+
+async fn conditional_wait(
+ child: &mut Option<tokio::process::Child>,
+) -> Result<std::process::ExitStatus, std::io::Error> {
+ match child {
+ None => std::future::pending().await,
+ Some(child) => child.wait().await,
+ }
+}
+
+#[derive(Debug)]
+struct WrappedCommandError(ExitStatus);
+
+impl std::fmt::Display for WrappedCommandError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "Wrapped command exited error: {}", self.0)
+ }
+}
+
+impl std::error::Error for WrappedCommandError {}
+
+#[derive(Debug)]
+struct InterruptError {}
+
+impl std::fmt::Display for InterruptError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "Interrupted")
+ }
+}
+
+impl std::error::Error for InterruptError {}
+
+fn version() -> String {
+ // https://doc.rust-lang.org/cargo/reference/environment-variables.html#environment-variables-cargo-sets-for-crates
+ let package_version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
+ let revision = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown");
+ let dirty = option_env!("VERGEN_GIT_DIRTY").and_then(|s| s.trim().parse::<bool>().ok());
+
+ let mut version = format!("{package_version}-{revision}");
+ if let Some(true) = dirty {
+ version.push_str("-dirty");
+ }
+ version
+}
+
+#[test]
+fn cli() {
+ use clap::CommandFactory;
+ Args::command().debug_assert();
+}
Generated by cgit. See skreutz.com for my tech blog and contact information.