//! Run the PostgreSQL server off a temporary data directory.
//!
//! See the [supported command-line arguments](Args).
use std::{
collections::HashMap,
ffi::{OsStr, OsString},
os::unix::fs::PermissionsExt,
path::PathBuf,
process::{ExitCode, ExitStatus, Stdio},
time::Duration,
};
use anyhow::{Context, bail};
use clap::{Parser, ValueHint, builder::NonEmptyStringValueParser};
use tokio::{process::Command, signal::unix::SignalKind, time::Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, trace, warn};
use url::Url;
// For clap: lets the `#[command(long_version = version!())]` attribute on `Args` use the
// crate-level `version()` function. Expands to exactly `crate::version()`.
macro_rules! version {
    () => (crate::version());
}
/// Extra text shown after the long help (`--help`); see `after_long_help` on `Args`.
///
/// Documents the exit statuses that `main` produces.
// NOTE(review): `color_print::cstr!` accepts color/style markup tags; none are visible in this
// copy of the text — confirm none were lost.
const EXTRA_HELP: &str = color_print::cstr!(
r#"Exit status:
Exits 0 on success, and >0 if an error occurs. Specifically, exits 2 on usage error, and 130 when interrupted. Propagates the exit status of the given command, if any.
"#
);
/// Command-line arguments.
#[derive(Debug, Clone, Parser)]
#[command(name = "temp-postgres", version, long_version = version!(), about, author, after_long_help = EXTRA_HELP, next_display_order = None)]
#[deny(clippy::missing_docs_in_private_items)]
struct Args {
/// Database name.
///
/// Defaults to the PostgreSQL user name.
#[arg(short = 'd', long, env = "PGDATABASE", value_parser = NonEmptyStringValueParser::new())]
dbname: Option,
/// PostgreSQL user name.
///
/// If not set, PostgreSQL uses the current operating system user.
#[arg(short = 'u', long, env = "PGUSER", value_parser = NonEmptyStringValueParser::new())]
username: Option,
/// Create a symbolic link to the temporary directory.
///
/// The given path must not exist.
///
/// The symbolic link will be created when the PostgreSQL server is ready.
/// It will be removed when the program ends if it (still) points to the temporary directory.
///
/// Example: Static client configuration.
///
/// ```sh
/// temp-postgres --symlink db
/// psql --host "$(realpath db)"
/// ```
// Clap ensures that arguments passed on the command-line are non-empty.
// Add a value parser when binding to an environment variable.
#[arg(long, value_name = "PATH", value_hint = ValueHint::FilePath)]
symlink: Option,
/// Log level.
#[arg(long, value_enum, default_value_t = LogLevel::Info)]
log_level: LogLevel,
/// Postgres startup timeout.
#[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")]
startup_timeout: Duration,
/// Graceful shutdown timeout.
///
/// `temp-postgres` performs a graceful shutdown upon SIGINT.
/// The signal is propagated to child processes, namely the `postgres` command, and the wrapped command, if any.
/// Child processes are killed (with SIGKILL) if they don't shut down in time.
#[arg(long, value_name = "DURATION", value_parser = humantime::parse_duration, default_value = "5s")]
shutdown_timeout: Duration,
/// Command to execute once the PostgreSQL server is ready.
///
/// The following environment variables will be passed to the command:
///
/// - PGHOST, the absolute path to the directory in which the UNIX domain socket file is stored
/// - PGDATABASE, the database name
/// - PGUSER, the PostgreSQL user name
/// - DATABASE_URL, a connection URI
///
/// See also the PostgreSQL documentation on environment variables and connection URIs:
///
/// -
/// -
///
/// Example: Wrap the `psql` command to connect to the temporary database once the server is ready.
///
/// ```sh
/// temp-postgres -- psql
/// ```
#[arg(last = true, value_hint = ValueHint::CommandWithArguments)]
command: Vec,
}
/// Log level.
// Converted into a `tracing_subscriber` `LevelFilter` by the `From` impl that follows.
// NOTE: variant order is significant. It defines the derived `PartialOrd`/`Ord` (`Off` is the
// smallest, `Trace` the largest) and the order in which the derived `clap::ValueEnum` lists the
// possible values. Plain `//` comments are used here because `///` comments on the variants
// would become per-value help text in `--help`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum)]
enum LogLevel {
Off,
Error,
Warn,
Info,
Debug,
Trace,
}
impl From for tracing_subscriber::filter::LevelFilter {
fn from(value: LogLevel) -> Self {
match value {
LogLevel::Off => tracing_subscriber::filter::LevelFilter::OFF,
LogLevel::Error => tracing_subscriber::filter::LevelFilter::ERROR,
LogLevel::Warn => tracing_subscriber::filter::LevelFilter::WARN,
LogLevel::Info => tracing_subscriber::filter::LevelFilter::INFO,
LogLevel::Debug => tracing_subscriber::filter::LevelFilter::DEBUG,
LogLevel::Trace => tracing_subscriber::filter::LevelFilter::TRACE,
}
}
}
#[tokio::main]
async fn main() -> ExitCode {
let args = Args::parse(); // May exit 0 or 2. Respects `NO_COLOR`.
let result = async {
let mut subscriber = tracing_subscriber::fmt()
.with_max_level(args.log_level.to_owned())
.with_writer(std::io::stderr); // Enable user to capture standard output of wrapped command.
if std::env::var_os("NO_COLOR").is_some() {
subscriber = subscriber.with_ansi(false);
}
subscriber.init();
debug!(version = version(), "starting");
debug!(arguments = ?args, "parsed command-line arguments");
run(args).await
};
match result.await {
Ok(_) => ExitCode::SUCCESS,
Err(err) => {
// Relying on tracing to print error.
//eprintln!("Error: {err:?}");
if err.is::() {
ExitCode::from(130)
} else if let Some(err) = err.downcast_ref::() {
match err.0.code().and_then(|code| u8::try_from(code).ok()) {
Some(code) => ExitCode::from(code),
None => ExitCode::FAILURE,
}
} else {
ExitCode::FAILURE
}
}
}
}
/// Runs the PostgreSQL server in a fresh temporary directory until the server task finishes or
/// SIGINT is received.
///
/// On SIGINT, the server task is asked to shut down gracefully through a cancellation token. It is
/// aborted if it has not finished within `args.shutdown_timeout`; that case is reported as an
/// interruption (`InterruptError`).
///
/// Cleanup always runs before the server task's result is returned, whatever that result is:
/// - the symlink (if requested) is removed when it still points to the temporary directory;
/// - the temporary directory is removed.
///
/// # Errors
///
/// Fails in these cases:
/// - the SIGINT handler cannot be registered;
/// - the temporary directory cannot be created;
/// - the server task fails or panics;
/// - the program is interrupted.
#[tracing::instrument(level = "info", skip_all, err(Debug))]
async fn run(args: Args) -> Result<(), anyhow::Error> {
    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
        .context("Failed to register SIGINT handler")?;
    debug!("registered SIGINT handler");
    let tmp = tempfile::Builder::new()
        .prefix("temp-postgres-")
        .permissions(std::fs::Permissions::from_mode(0o700))
        .tempdir()
        .context("Failed to create temporary directory")?;
    debug!(directory = %tmp.path().display(), "created temporary directory");
    let token = CancellationToken::new();
    let mut server = tokio::spawn(serve(args.clone(), tmp.path().to_path_buf(), token.clone()));
    debug!("spawned server task");
    // Join errors (abort or panic of the server task) are folded into `result` instead of being
    // propagated with `?`, so that the cleanup below always runs.
    let result = tokio::select! {
        _ = sigint.recv() => {
            info!("received SIGINT, shutting down ...");
            token.cancel();
            tokio::select! {
                _ = tokio::time::sleep(args.shutdown_timeout) => {
                    warn!("graceful shutdown timed out, aborting ...");
                    server.abort();
                    match server.await {
                        // We aborted the task ourselves in response to SIGINT.
                        Err(err) if err.is_cancelled() => Err(anyhow::Error::from(InterruptError {})),
                        r => r.unwrap_or_else(|err| Err(err.into())),
                    }
                }
                r = &mut server => {
                    debug!("server task finished");
                    r.unwrap_or_else(|err| Err(err.into()))
                }
            }
        }
        r = &mut server => {
            debug!("server task finished");
            r.unwrap_or_else(|err| Err(err.into()))
        },
    };
    if let Some(symlink) = &args.symlink {
        let missing = matches!(
            tokio::fs::symlink_metadata(symlink).await,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound
        );
        if missing {
            // E.g. the server task ended before the server became ready.
            debug!(?symlink, "symlink does not exist, nothing to remove");
        } else if let Ok(target) = tokio::fs::canonicalize(symlink).await
            && let Ok(tmp_canonical) = tokio::fs::canonicalize(tmp.path()).await
            && target == tmp_canonical
        {
            match tokio::fs::remove_file(symlink).await {
                Ok(()) => debug!(?symlink, "removed symlink"),
                Err(err) => warn!(?symlink, error = ?err, "failed to remove symlink"),
            }
        } else {
            warn!(?symlink, directory = %tmp.path().display(), "keeping symlink because it doesn't point to temporary directory");
        }
    }
    let tmp_dir = tmp.path().to_path_buf();
    match tmp.close() {
        Ok(()) => debug!(directory = %tmp_dir.display(), "removed temporary directory"),
        Err(err) => {
            warn!(directory = %tmp_dir.display(), error = ?err, "failed to remove temporary directory")
        }
    }
    result
}
/// Initializes a database cluster in `tmp_dir`, runs the PostgreSQL server off it, and runs the
/// wrapped command (if any) once the server is ready.
///
/// Steps:
/// 1. `initdb` the cluster (as `args.username`, if set);
/// 2. spawn `postgres`, not bound to any TCP/IP address, with its UNIX socket directory in
///    `tmp_dir`;
/// 3. poll `pg_isready` until the server is ready, failing early if `postgres` exits and failing
///    once `args.startup_timeout` has elapsed;
/// 4. `createdb` the database named `args.dbname`, which defaults to `args.username` here (when
///    both are unset, no name is passed to `createdb`);
/// 5. log the connection parameters, create the symlink (if requested), and spawn the wrapped
///    command with `PGHOST`, `DATABASE_URL` and, when set, `PGDATABASE`/`PGUSER` in its
///    environment;
/// 6. wait until `token` is cancelled, `postgres` exits, or the wrapped command exits, then shut
///    down the remaining child processes.
///
/// All child processes are spawned with `kill_on_drop(true)`, in case this task gets aborted.
///
/// # Errors
///
/// Fails in these cases:
/// - a command cannot be spawned or exits unsuccessfully;
/// - the server does not become ready in time;
/// - `token` is cancelled (`InterruptError`);
/// - the wrapped command exits unsuccessfully (`WrappedCommandError`).
#[tracing::instrument(level = "info", skip_all, err)]
async fn serve(
    args: Args,
    tmp_dir: PathBuf,
    token: CancellationToken,
) -> Result<(), anyhow::Error> {
    // WORKAROUND: PostgreSQL expects a non-empty dbname when username is set.
    let mut args = args;
    if args.dbname.is_none() {
        args.dbname = args.username.clone();
    }
    // TODO: Check token.is_cancelled() before every command?
    let mut optional_initdb_args = Vec::<&OsStr>::new();
    if let Some(ref username) = args.username {
        optional_initdb_args.push("--username".as_ref());
        optional_initdb_args.push(username.as_ref());
    }
    let status = Command::new("initdb")
        .arg("--pgdata")
        .arg(&tmp_dir)
        .args(optional_initdb_args)
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .kill_on_drop(true) // In case this task gets aborted
        .status()
        .await
        .context("Failed to spawn or wait for initdb command")?;
    if !status.success() {
        bail!("initdb command exited error: {status}");
    }
    debug!("initialized database");
    let mut postgres = Command::new("postgres")
        .arg("-h")
        .arg("") // Don't bind to a TCP/IP address
        .arg("-k")
        .arg(&tmp_dir)
        .arg("-D")
        .arg(&tmp_dir)
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .kill_on_drop(true)
        .spawn()
        .context("Failed to spawn postgres command")?;
    debug!("spawned postgres command");
    let started = Instant::now();
    tokio::time::sleep(Duration::from_millis(50)).await;
    for attempt in 1.. {
        // Fail fast when the server dies during startup, instead of polling pg_isready until the
        // startup timeout and reporting a misleading timeout.
        if let Some(status) = postgres
            .try_wait()
            .context("Failed to check status of postgres command")?
        {
            bail!("postgres command exited during startup: {status}");
        }
        let status = is_ready(tmp_dir.as_ref(), args.username.as_deref()).await?;
        if status.success() {
            debug!("PostgreSQL server is ready");
            break;
        } else if matches!(status.code(), Some(1 | 2)) {
            // Treated as "not ready yet": retry until the startup timeout.
            if started.elapsed() < args.startup_timeout {
                trace!("retrying pg_isready command ...");
            } else {
                bail!("pg_isready timed out after {attempt} attempts");
            }
        } else {
            bail!("pg_isready command exited error: {status}");
        }
        tokio::time::sleep(Duration::from_millis(200)).await;
    }
    let mut optional_createdb_args = Vec::<&OsStr>::new();
    if let Some(ref username) = args.username {
        optional_createdb_args.push("--username".as_ref());
        optional_createdb_args.push(username.as_ref());
    }
    if let Some(ref dbname) = args.dbname {
        // Positional argument
        optional_createdb_args.push(dbname.as_ref());
    }
    let status = Command::new("createdb")
        .arg("--host")
        .arg(&tmp_dir)
        .arg("--no-password")
        .args(optional_createdb_args)
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .kill_on_drop(true)
        .status()
        .await
        .context("Failed to spawn or wait for createdb command")?;
    if !status.success() {
        bail!("createdb command exited error: {status}");
    }
    debug!("Created PostgreSQL database");
    let mut url = Url::parse("postgresql://")?;
    if let Some(ref dbname) = args.dbname {
        url.set_path(dbname);
    }
    url.query_pairs_mut()
        .append_pair("host", &tmp_dir.display().to_string());
    if let Some(ref username) = args.username {
        url.query_pairs_mut().append_pair("user", username);
    }
    info!(
        PGHOST = %&tmp_dir.display(),
        PGDATABASE = &args.dbname,
        PGUSER = &args.username,
        DATABASE_URL = %&url,
        "You can connect using the following parameters"
    );
    if let Some(symlink) = &args.symlink {
        tokio::fs::symlink(&tmp_dir, symlink)
            .await
            .context("Failed to create symlink")?;
        debug!(?symlink, "created symlink");
    }
    let mut wrapped_command = match args.command.split_first() {
        None => None,
        Some((head, tail)) => {
            let mut optional_env = HashMap::new();
            if let Some(ref username) = args.username {
                optional_env.insert("PGUSER", username);
            }
            if let Some(ref dbname) = args.dbname {
                optional_env.insert("PGDATABASE", dbname);
            }
            let command = Command::new(head)
                .args(tail)
                .env("PGHOST", &tmp_dir)
                .env("DATABASE_URL", url.as_str()) // for SQLx with SQLX_OFFLINE=true
                .envs(optional_env)
                .kill_on_drop(true)
                .spawn()
                .context("Failed to spawn wrapped command")?;
            debug!("spawned wrapped command");
            Some(command)
        }
    };
    tokio::select! {
        _ = token.cancelled() => {
            debug!("token cancelled");
            debug!("shutting down child processes ...");
            // Propagating SIGINT to child processes in case it was send to temp-postgres only.
            // Halving shutdown timeout to avoid being aborted by parent task.
            let shutdown_timeout = args.shutdown_timeout / 2;
            let (wrapped_command_result, postgres_result) = tokio::join!(
                async {
                    if let Some(ref mut c) = wrapped_command {
                        interrupt_or_kill(c, shutdown_timeout).await
                    } else {
                        Ok(None)
                    }
                },
                interrupt_or_kill(&mut postgres, shutdown_timeout),
            );
            if wrapped_command.is_some() {
                match wrapped_command_result {
                    // Exit code is probably 130
                    Ok(status) => debug!(?status, "successfully shut down wrapped command"),
                    Err(err) => warn!(error = ?err, "failed to shut down wrapped command"),
                }
            }
            match postgres_result {
                // Exit code is probably 130
                Ok(status) => debug!(?status, "successfully shut down postgres command"),
                Err(err) => warn!(error = ?err, "failed to shut down postgres command"),
            }
            Err(InterruptError{})?
        },
        r = postgres.wait() => {
            debug!("postgres command finished");
            debug!("shutting down wrapped command ...");
            if let Some(ref mut c) = wrapped_command {
                match interrupt_or_kill(c, args.shutdown_timeout).await {
                    // Exit code is probably 130
                    Ok(status) => debug!(?status, "successfully shut down wrapped command"),
                    Err(err) => warn!(error = ?err, "failed to shut down wrapped command"),
                }
            }
            let status = r.context("failed to wait for postgres command")?;
            if !status.success() {
                bail!("postgres command exited error: {status}");
            }
            Ok(())
        },
        r = conditional_wait(&mut wrapped_command) => {
            debug!("wrapped command finished");
            debug!("shutting down postgres command ...");
            match interrupt_or_kill(&mut postgres, args.shutdown_timeout).await {
                // Exit code is probably 130
                Ok(status) => debug!(?status, "successfully shut down postgres command"),
                Err(err) => warn!(error = ?err, "failed to shut down postgres command"),
            }
            let status = r.context("failed to wait for wrapped command")?;
            if !status.success() {
                return Err(WrappedCommandError(status))?;
            }
            Ok(())
        },
    }
}
fn interrupt(child: &tokio::process::Child) -> Result<(), anyhow::Error> {
// std::os::unix::process::ChildExt::send_signal is nightly-only experimental
if let Some(id) = child.id() {
nix::sys::signal::kill(
nix::unistd::Pid::from_raw(id as i32),
nix::sys::signal::Signal::SIGINT,
)
.context("Failed to send SIGINT to process {id}")?;
}
Ok(())
}
async fn interrupt_or_kill(
child: &mut tokio::process::Child,
timeout: Duration,
) -> Result