diff --git a/Cargo.lock b/Cargo.lock index 07fca3720e7..2488776b1c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,6 +230,7 @@ dependencies = [ "atuin-daemon", "atuin-dotfiles", "atuin-history", + "atuin-scripts", "atuin-server", "atuin-server-postgres", "clap", @@ -254,6 +255,7 @@ dependencies = [ "serde", "serde_json", "sysinfo", + "tempfile", "time", "tiny-bip39", "tokio", @@ -393,6 +395,28 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "atuin-scripts" +version = "18.5.0-beta.2" +dependencies = [ + "atuin-client", + "atuin-common", + "eyre", + "minijinja", + "pretty_assertions", + "rmp", + "serde", + "serde_json", + "sql-builder", + "sqlx", + "tempfile", + "tokio", + "tracing", + "tracing-subscriber", + "typed-builder", + "uuid", +] + [[package]] name = "atuin-server" version = "18.5.0-beta.2" @@ -2522,6 +2546,15 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minijinja" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98642a6dfca91122779a307b77cd07a4aa951fbe32232aaf5bad9febc66be754" +dependencies = [ + "serde", +] + [[package]] name = "minimal-lexical" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index 40ab7308a73..4b7bf08b235 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,9 @@ thiserror = "1.0" rustix = { version = "0.38.34", features = ["process", "fs"] } tower = "0.4" tracing = "0.1" +sql-builder = "3" +tempfile = { version = "3.19" } +minijinja = "2.9.0" [workspace.dependencies.tracing-subscriber] version = "0.3" diff --git a/crates/atuin-client/Cargo.toml b/crates/atuin-client/Cargo.toml index eccc13c2da3..a4f2fb24a52 100644 --- a/crates/atuin-client/Cargo.toml +++ b/crates/atuin-client/Cargo.toml @@ -43,7 +43,7 @@ minspan = "0.1.1" regex = "1.10.5" serde_regex = "1.1.0" fs-err = { workspace = true } -sql-builder = "3" +sql-builder = { workspace = true } memchr = "2.7" rmp = { version = "0.8.14" } typed-builder = { workspace = true } diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 91ccb6b8e7b..7af24d90cf7 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -30,6 +30,7 @@ pub const HOST_ID_FILENAME: &str = "host_id"; static EXAMPLE_CONFIG: &str = include_str!("../config.toml"); mod dotfiles; +mod scripts; #[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] pub enum SearchMode { @@ -515,6 +516,9 @@ pub struct Settings { #[serde(default)] pub theme: Theme, + + #[serde(default)] + pub scripts: scripts::Settings, } impl Settings { diff --git a/crates/atuin-client/src/settings/scripts.rs b/crates/atuin-client/src/settings/scripts.rs new file mode 100644 index 00000000000..e9d66c93aa4 --- /dev/null +++ b/crates/atuin-client/src/settings/scripts.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Settings { + pub database_path: String, +} + +impl Default for Settings { + fn default() -> Self { + let dir = atuin_common::utils::data_dir(); + let path = dir.join("scripts.db"); + + Self { + database_path: path.to_string_lossy().to_string(), + } + } +} diff --git a/crates/atuin-scripts/Cargo.toml b/crates/atuin-scripts/Cargo.toml new file mode 100644 index 00000000000..4d868757df1 --- /dev/null +++ b/crates/atuin-scripts/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "atuin-scripts" +edition = "2024" +version = { workspace = true } +description = "The scripts crate for Atuin" + +authors.workspace = true +rust-version.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +readme.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +atuin-client = { path = "../atuin-client", version = "18.5.0-beta.1" } +atuin-common = { path = "../atuin-common", version = "18.5.0-beta.1" } + +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +rmp = { version = "0.8.14" } +uuid = { workspace = true } +eyre = { workspace = true } +tokio = { workspace = true } +serde = { workspace = true } +typed-builder = { workspace = true } +pretty_assertions = { workspace = true } +sql-builder = { workspace = true } +sqlx = { workspace = true } +tempfile = { workspace = true } +minijinja = { workspace = true } +serde_json = { workspace = true } \ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql b/crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql new file mode 100644 index 00000000000..b2c5a36368a --- /dev/null +++ b/crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql @@ -0,0 +1,2 @@ +DROP TABLE scripts; +DROP TABLE script_tags; \ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql b/crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql new file mode 100644 index 00000000000..1b2f3688938 --- /dev/null +++ b/crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql @@ -0,0 +1,17 @@ +-- Add up migration script here +CREATE TABLE scripts ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT NOT NULL, + shebang TEXT NOT NULL, + script TEXT NOT NULL, + inserted_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) +); + +CREATE TABLE script_tags ( + id INTEGER PRIMARY KEY, + script_id TEXT NOT NULL, + tag TEXT NOT NULL +); + +CREATE UNIQUE INDEX idx_script_tags ON script_tags (script_id, tag); \ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql b/crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql new file mode 100644 index 00000000000..269b8cd9453 --- /dev/null +++ b/crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql @@ -0,0 +1,2 @@ +-- Add down migration script here +alter table scripts drop index name_uniq_idx; \ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql b/crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql new file mode 100644 index 00000000000..d2cdd02fef5 --- /dev/null +++ b/crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql @@ -0,0 +1,2 @@ +-- Add up migration script here +create unique index name_uniq_idx ON scripts(name); \ No newline at end of file diff --git a/crates/atuin-scripts/src/database.rs b/crates/atuin-scripts/src/database.rs new file mode 100644 index 00000000000..71da69ff91b --- /dev/null +++ b/crates/atuin-scripts/src/database.rs @@ -0,0 +1,358 @@ +use std::{path::Path, str::FromStr, time::Duration}; + +use atuin_common::utils; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use tokio::fs; +use tracing::debug; +use uuid::Uuid; + +use crate::store::script::Script; + +#[derive(Debug, Clone)] +pub struct Database { + pub pool: SqlitePool, +} + +impl Database { + pub async fn new(path: impl AsRef, timeout: f64) -> Result { + let path = path.as_ref(); + debug!("opening script sqlite database at {:?}", path); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Script sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir).await?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, s: &Script) -> Result<()> { + sqlx::query( + "insert or ignore into scripts(id, name, description, shebang, script) + values(?1, ?2, ?3, ?4, ?5)", + ) + .bind(s.id.to_string()) + .bind(s.name.as_str()) + .bind(s.description.as_str()) + .bind(s.shebang.as_str()) + .bind(s.script.as_str()) + .execute(&mut **tx) + .await?; + + for tag in s.tags.iter() { + sqlx::query( + "insert or ignore into script_tags(script_id, tag) + values(?1, ?2)", + ) + .bind(s.id.to_string()) + .bind(tag) + .execute(&mut **tx) + .await?; + } + + Ok(()) + } + + pub async fn save(&self, s: &Script) -> Result<()> { + debug!("saving script to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, s).await?; + tx.commit().await?; + + Ok(()) + } + + pub async fn save_bulk(&self, s: &[Script]) -> Result<()> { + debug!("saving scripts to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in s { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + fn query_script(row: SqliteRow) -> Script { + let id = row.get("id"); + let name = row.get("name"); + let description = row.get("description"); + let shebang = row.get("shebang"); + let script = row.get("script"); + + let id = Uuid::parse_str(id).unwrap(); + + Script { + id, + name, + description, + shebang, + script, + tags: vec![], + } + } + + fn query_script_tags(row: SqliteRow) -> String { + row.get("tag") + } + + #[allow(dead_code)] + async fn load(&self, id: &str) -> Result> { + debug!("loading script item {}", id); + + let res = sqlx::query("select * from scripts where id = ?1") + .bind(id) + .map(Self::query_script) + .fetch_optional(&self.pool) + .await?; + + // intentionally not joining, don't want to duplicate the script data in memory a whole bunch. + if let Some(mut script) = res { + let tags = sqlx::query("select tag from script_tags where script_id = ?1") + .bind(id) + .map(Self::query_script_tags) + .fetch_all(&self.pool) + .await?; + + script.tags = tags; + Ok(Some(script)) + } else { + Ok(None) + } + } + + pub async fn list(&self) -> Result> { + debug!("listing scripts"); + + let mut res = sqlx::query("select * from scripts") + .map(Self::query_script) + .fetch_all(&self.pool) + .await?; + + // Fetch all the tags for each script + for script in res.iter_mut() { + let tags = sqlx::query("select tag from script_tags where script_id = ?1") + .bind(script.id.to_string()) + .map(Self::query_script_tags) + .fetch_all(&self.pool) + .await?; + + script.tags = tags; + } + + Ok(res) + } + + pub async fn delete(&self, id: &str) -> Result<()> { + debug!("deleting script {}", id); + + sqlx::query("delete from scripts where id = ?1") + .bind(id) + .execute(&self.pool) + .await?; + + // delete all the tags for the script + sqlx::query("delete from script_tags where script_id = ?1") + .bind(id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn update(&self, s: &Script) -> Result<()> { + debug!("updating script {:?}", s); + + let mut tx = self.pool.begin().await?; + + // Update the script's base fields + sqlx::query("update scripts set name = ?1, description = ?2, shebang = ?3, script = ?4 where id = ?5") + .bind(s.name.as_str()) + .bind(s.description.as_str()) + .bind(s.shebang.as_str()) + .bind(s.script.as_str()) + .bind(s.id.to_string()) + .execute(&mut *tx) + .await?; + + // Delete all existing tags for this script + sqlx::query("delete from script_tags where script_id = ?1") + .bind(s.id.to_string()) + .execute(&mut *tx) + .await?; + + // Insert new tags + for tag in s.tags.iter() { + sqlx::query( + "insert or ignore into script_tags(script_id, tag) + values(?1, ?2)", + ) + .bind(s.id.to_string()) + .bind(tag) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + pub async fn get_by_name(&self, name: &str) -> Result> { + let res = sqlx::query("select * from scripts where name = ?1") + .bind(name) + .map(Self::query_script) + .fetch_optional(&self.pool) + .await?; + + let script = if let Some(mut script) = res { + let tags = sqlx::query("select tag from script_tags where script_id = ?1") + .bind(script.id.to_string()) + .map(Self::query_script_tags) + .fetch_all(&self.pool) + .await?; + + script.tags = tags; + Some(script) + } else { + None + }; + + Ok(script) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn test_list() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + let scripts = db.list().await.unwrap(); + assert_eq!(scripts.len(), 0); + + let script = Script::builder() + .name("test".to_string()) + .description("test".to_string()) + .shebang("test".to_string()) + .script("test".to_string()) + .build(); + + db.save(&script).await.unwrap(); + + let scripts = db.list().await.unwrap(); + assert_eq!(scripts.len(), 1); + assert_eq!(scripts[0].name, "test"); + } + + #[tokio::test] + async fn test_save_load() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let script = Script::builder() + .name("test name".to_string()) + .description("test description".to_string()) + .shebang("test shebang".to_string()) + .script("test script".to_string()) + .build(); + + db.save(&script).await.unwrap(); + + let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap(); + + assert_eq!(loaded, script); + } + + #[tokio::test] + async fn test_save_bulk() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let scripts = vec![ + Script::builder() + .name("test name".to_string()) + .description("test description".to_string()) + .shebang("test shebang".to_string()) + .script("test script".to_string()) + .build(), + Script::builder() + .name("test name 2".to_string()) + .description("test description 2".to_string()) + .shebang("test shebang 2".to_string()) + .script("test script 2".to_string()) + .build(), + ]; + + db.save_bulk(&scripts).await.unwrap(); + + let loaded = db.list().await.unwrap(); + assert_eq!(loaded.len(), 2); + assert_eq!(loaded[0].name, "test name"); + assert_eq!(loaded[1].name, "test name 2"); + } + + #[tokio::test] + async fn test_delete() { + let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); + + let script = Script::builder() + .name("test name".to_string()) + .description("test description".to_string()) + .shebang("test shebang".to_string()) + .script("test script".to_string()) + .build(); + + db.save(&script).await.unwrap(); + + assert_eq!(db.list().await.unwrap().len(), 1); + db.delete(&script.id.to_string()).await.unwrap(); + + let loaded = db.list().await.unwrap(); + assert_eq!(loaded.len(), 0); + } +} diff --git a/crates/atuin-scripts/src/execution.rs b/crates/atuin-scripts/src/execution.rs new file mode 100644 index 00000000000..90f7c4ebcea --- /dev/null +++ b/crates/atuin-scripts/src/execution.rs @@ -0,0 +1,287 @@ +use crate::store::script::Script; +use eyre::Result; +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::process::Stdio; +use tempfile::NamedTempFile; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::mpsc; +use tokio::task; +use tracing::debug; + +// Helper function to build a complete script with shebang +pub fn build_executable_script(script: String, shebang: String) -> String { + if shebang.is_empty() { + // Default to bash if no shebang is provided + format!("#!/usr/bin/env bash\n{}", script) + } else if script.starts_with("#!") { + format!("{}\n{}", shebang, script) + } else { + format!("#!{}\n{}", shebang, script) + } +} + +/// Represents the communication channels for an interactive script +pub struct ScriptSession { + /// Channel to send input to the script + pub stdin_tx: mpsc::Sender, + /// Exit code of the process once it completes + pub exit_code_rx: mpsc::Receiver, +} + +impl ScriptSession { + /// Send input to the running script + pub async fn send_input(&self, input: String) -> Result<(), mpsc::error::SendError> { + self.stdin_tx.send(input).await + } + + /// Wait for the script to complete and get the exit code + pub async fn wait_for_exit(&mut self) -> Option { + self.exit_code_rx.recv().await + } +} + +fn setup_template(script: &Script) -> Result { + let mut env = minijinja::Environment::new(); + env.set_trim_blocks(true); + env.add_template("script", script.script.as_str())?; + + Ok(env) +} + +/// Template a script with the given context +pub fn template_script( + script: &Script, + context: &HashMap, +) -> Result { + let env = setup_template(script)?; + let template = env.get_template("script")?; + let rendered = template.render(context)?; + + Ok(rendered) +} + +/// Get the variables that need to be templated in a script +pub fn template_variables(script: &Script) -> Result> { + let env = setup_template(script)?; + let template = env.get_template("script")?; + + Ok(template.undeclared_variables(true)) +} + +/// Execute a script interactively, allowing for ongoing stdin/stdout interaction +pub async fn execute_script_interactive( + script: String, + shebang: String, +) -> Result> { + // Create a temporary file for the script + let temp_file = NamedTempFile::new()?; + let temp_path = temp_file.path().to_path_buf(); + + debug!("creating temp file at {}", temp_path.display()); + + // Extract interpreter from shebang for fallback execution + let interpreter = if !shebang.is_empty() { + shebang.trim_start_matches("#!").trim().to_string() + } else { + "/usr/bin/env bash".to_string() + }; + + // Write script content to the temp file, including the shebang + let full_script_content = build_executable_script(script.clone(), shebang.clone()); + + debug!("writing script content to temp file"); + tokio::fs::write(&temp_path, &full_script_content).await?; + + // Make it executable on Unix systems + #[cfg(unix)] + { + debug!("making script executable"); + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&temp_path)?.permissions(); + perms.set_mode(0o755); + fs::set_permissions(&temp_path, perms)?; + } + + // Store the temp_file to prevent it from being dropped + // This ensures it won't be deleted while the script is running + let _keep_temp_file = temp_file; + + debug!("attempting direct script execution"); + let mut child_result = tokio::process::Command::new(temp_path.to_str().unwrap()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn(); + + // If direct execution fails, try using the interpreter + if let Err(e) = &child_result { + debug!("direct execution failed: {}, trying with interpreter", e); + + // When falling back to interpreter, remove the shebang from the file + // Some interpreters don't handle scripts with shebangs well + debug!("writing script content without shebang for interpreter execution"); + tokio::fs::write(&temp_path, &script).await?; + + // Parse the interpreter command + let parts: Vec<&str> = interpreter.split_whitespace().collect(); + if !parts.is_empty() { + let mut cmd = tokio::process::Command::new(parts[0]); + + // Add any interpreter args + for i in parts.iter().skip(1) { + cmd.arg(i); + } + + // Add the script path + cmd.arg(temp_path.to_str().unwrap()); + + // Try with the interpreter + child_result = cmd + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn(); + } + } + + // If it still fails, return the error + let mut child = match child_result { + Ok(child) => child, + Err(e) => { + return Err(format!("Failed to execute script: {}", e).into()); + } + }; + + // Get handles to stdin, stdout, stderr + let mut stdin = child + .stdin + .take() + .ok_or_else(|| "Failed to open child process stdin".to_string())?; + let stdout = child + .stdout + .take() + .ok_or_else(|| "Failed to open child process stdout".to_string())?; + let stderr = child + .stderr + .take() + .ok_or_else(|| "Failed to open child process stderr".to_string())?; + + // Create channels for the interactive session + let (stdin_tx, mut stdin_rx) = mpsc::channel::(32); + let (exit_code_tx, exit_code_rx) = mpsc::channel::(1); + + // handle user stdin + debug!("spawning stdin handler"); + tokio::spawn(async move { + while let Some(input) = stdin_rx.recv().await { + if let Err(e) = stdin.write_all(input.as_bytes()).await { + eprintln!("Error writing to stdin: {}", e); + break; + } + if let Err(e) = stdin.flush().await { + eprintln!("Error flushing stdin: {}", e); + break; + } + } + // when the channel closes (sender dropped), we let stdin close naturally + }); + + // handle stdout + debug!("spawning stdout handler"); + let stdout_handle = task::spawn(async move { + let mut stdout_reader = BufReader::new(stdout); + let mut buffer = [0u8; 1024]; + let mut stdout_writer = tokio::io::stdout(); + + loop { + match stdout_reader.read(&mut buffer).await { + Ok(0) => break, // End of stdout + Ok(n) => { + if let Err(e) = stdout_writer.write_all(&buffer[0..n]).await { + eprintln!("Error writing to stdout: {}", e); + break; + } + if let Err(e) = stdout_writer.flush().await { + eprintln!("Error flushing stdout: {}", e); + break; + } + } + Err(e) => { + eprintln!("Error reading from process stdout: {}", e); + break; + } + } + } + }); + + // Process stderr in a separate task + debug!("spawning stderr handler"); + let stderr_handle = task::spawn(async move { + let mut stderr_reader = BufReader::new(stderr); + let mut buffer = [0u8; 1024]; + let mut stderr_writer = tokio::io::stderr(); + + loop { + match stderr_reader.read(&mut buffer).await { + Ok(0) => break, // End of stderr + Ok(n) => { + if let Err(e) = stderr_writer.write_all(&buffer[0..n]).await { + eprintln!("Error writing to stderr: {}", e); + break; + } + if let Err(e) = stderr_writer.flush().await { + eprintln!("Error flushing stderr: {}", e); + break; + } + } + Err(e) => { + eprintln!("Error reading from process stderr: {}", e); + break; + } + } + } + }); + + // Spawn a task to wait for the child process to complete + debug!("spawning exit code handler"); + let _keep_temp_file_clone = _keep_temp_file; + tokio::spawn(async move { + // Keep the temp file alive until the process completes + let _temp_file_ref = _keep_temp_file_clone; + + // Wait for the child process to complete + let status = match child.wait().await { + Ok(status) => { + debug!("Process exited with status: {:?}", status); + status + } + Err(e) => { + eprintln!("Error waiting for child process: {}", e); + // Send a default error code + let _ = exit_code_tx.send(-1).await; + return; + } + }; + + // Wait for stdout/stderr tasks to complete + if let Err(e) = stdout_handle.await { + eprintln!("Error joining stdout task: {}", e); + } + + if let Err(e) = stderr_handle.await { + eprintln!("Error joining stderr task: {}", e); + } + + // Send the exit code + let exit_code = status.code().unwrap_or(-1); + debug!("Sending exit code: {}", exit_code); + let _ = exit_code_tx.send(exit_code).await; + }); + + // Return the communication channels as a ScriptSession + Ok(ScriptSession { + stdin_tx, + exit_code_rx, + }) +} diff --git a/crates/atuin-scripts/src/lib.rs b/crates/atuin-scripts/src/lib.rs new file mode 100644 index 00000000000..c79c7089faa --- /dev/null +++ b/crates/atuin-scripts/src/lib.rs @@ -0,0 +1,4 @@ +pub mod database; +pub mod execution; +pub mod settings; +pub mod store; diff --git a/crates/atuin-scripts/src/settings.rs b/crates/atuin-scripts/src/settings.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/crates/atuin-scripts/src/settings.rs @@ -0,0 +1 @@ + diff --git a/crates/atuin-scripts/src/store.rs b/crates/atuin-scripts/src/store.rs new file mode 100644 index 00000000000..ba7a1ca148c --- /dev/null +++ b/crates/atuin-scripts/src/store.rs @@ -0,0 +1,109 @@ +use eyre::{Result, bail}; + +use atuin_client::record::sqlite_store::SqliteStore; +use atuin_client::record::{encryption::PASETO_V4, store::Store}; +use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx}; +use record::ScriptRecord; +use script::{SCRIPT_TAG, SCRIPT_VERSION, Script}; + +use crate::database::Database; + +pub mod record; +pub mod script; + +#[derive(Debug, Clone)] +pub struct ScriptStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +impl ScriptStore { + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { + ScriptStore { + store, + host_id, + encryption_key, + } + } + + async fn push_record(&self, record: ScriptRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .store + .last(self.host_id, SCRIPT_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(SCRIPT_VERSION.to_string()) + .tag(SCRIPT_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.store + .push(&record.encrypt::(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + pub async fn create(&self, script: Script) -> Result<()> { + let record = ScriptRecord::Create(script); + self.push_record(record).await?; + Ok(()) + } + + pub async fn update(&self, script: Script) -> Result<()> { + let record = ScriptRecord::Update(script); + self.push_record(record).await?; + Ok(()) + } + + pub async fn delete(&self, script_id: uuid::Uuid) -> Result<()> { + let record = ScriptRecord::Delete(script_id); + self.push_record(record).await?; + Ok(()) + } + + pub async fn scripts(&self) -> Result> { + let records = self.store.all_tagged(SCRIPT_TAG).await?; + let mut ret = Vec::with_capacity(records.len()); + + for record in records.into_iter() { + let script = match record.version.as_str() { + SCRIPT_VERSION => { + let decrypted = record.decrypt::(&self.encryption_key)?; + + ScriptRecord::deserialize(&decrypted.data, SCRIPT_VERSION) + } + version => bail!("unknown history version {version:?}"), + }?; + + ret.push(script); + } + + Ok(ret) + } + + pub async fn build(&self, database: Database) -> Result<()> { + // Get all the scripts from the database - they are already sorted by timestamp + let scripts = self.scripts().await?; + + for script in scripts { + match script { + ScriptRecord::Create(script) => { + database.save(&script).await?; + } + ScriptRecord::Update(script) => database.update(&script).await?, + ScriptRecord::Delete(id) => database.delete(&id.to_string()).await?, + } + } + + Ok(()) + } +} diff --git a/crates/atuin-scripts/src/store/record.rs b/crates/atuin-scripts/src/store/record.rs new file mode 100644 index 00000000000..4c925be3a2f --- /dev/null +++ b/crates/atuin-scripts/src/store/record.rs @@ -0,0 +1,215 @@ +use atuin_common::record::DecryptedData; +use eyre::{Result, eyre}; +use uuid::Uuid; + +use crate::store::script::SCRIPT_VERSION; + +use super::script::Script; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ScriptRecord { + Create(Script), + Update(Script), + Delete(Uuid), +} + +impl ScriptRecord { + pub fn serialize(&self) -> Result { + use rmp::encode; + + let mut output = vec![]; + + match self { + ScriptRecord::Create(script) => { + // 0 -> a script create + encode::write_u8(&mut output, 0)?; + + let bytes = script.serialize()?; + + encode::write_bin(&mut output, &bytes.0)?; + } + + ScriptRecord::Delete(id) => { + // 1 -> a script delete + encode::write_u8(&mut output, 1)?; + encode::write_str(&mut output, id.to_string().as_str())?; + } + + ScriptRecord::Update(script) => { + // 2 -> a script update + encode::write_u8(&mut output, 2)?; + let bytes = script.serialize()?; + encode::write_bin(&mut output, &bytes.0)?; + } + }; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(data: &DecryptedData, version: &str) -> Result { + use rmp::decode; + + fn error_report(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match version { + SCRIPT_VERSION => { + let mut bytes = decode::Bytes::new(&data.0); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // create + 0 => { + // written by encode::write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + let script = Script::deserialize(bytes.remaining_slice())?; + Ok(ScriptRecord::Create(script)) + } + + // delete + 1 => { + let bytes = bytes.remaining_slice(); + let (id, _) = decode::read_str_from_slice(bytes).map_err(error_report)?; + Ok(ScriptRecord::Delete(Uuid::parse_str(id)?)) + } + + // update + 2 => { + // written by encode::write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + let script = Script::deserialize(bytes.remaining_slice())?; + Ok(ScriptRecord::Update(script)) + } + + _ => Err(eyre!("unknown script record type {record_type}")), + } + } + _ => Err(eyre!("unknown version {version:?}")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialize_create() { + let script = Script::builder() + .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap()) + .name("test".to_string()) + .description("test".to_string()) + .shebang("test".to_string()) + .tags(vec!["test".to_string()]) + .script("test".to_string()) + .build(); + + let record = ScriptRecord::Create(script); + + let serialized = record.serialize().unwrap(); + + assert_eq!( + serialized.0, + vec![ + 204, 0, 196, 65, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, + 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, + 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, + 116, 145, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116 + ] + ); + } + + #[test] + fn test_serialize_delete() { + let record = ScriptRecord::Delete( + uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(), + ); + + let serialized = record.serialize().unwrap(); + + assert_eq!( + serialized.0, + vec![ + 204, 1, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, + 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54 + ] + ); + } + + #[test] + fn test_serialize_update() { + let script = Script::builder() + .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap()) + .name(String::from("test")) + .description(String::from("test")) + .shebang(String::from("test")) + .tags(vec![String::from("test"), String::from("test2")]) + .script(String::from("test")) + .build(); + + let record = ScriptRecord::Update(script); + + let serialized = record.serialize().unwrap(); + + assert_eq!( + serialized.0, + vec![ + 204, 2, 196, 71, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, + 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, + 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, + 116, 146, 164, 116, 101, 115, 116, 165, 116, 101, 115, 116, 50, 164, 116, 101, 115, + 116 + ], + ); + } + + #[test] + fn test_serialize_deserialize_create() { + let script = Script::builder() + .name("test".to_string()) + .description("test".to_string()) + .shebang("test".to_string()) + .tags(vec!["test".to_string()]) + .script("test".to_string()) + .build(); + + let record = ScriptRecord::Create(script); + + let serialized = record.serialize().unwrap(); + let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap(); + + assert_eq!(record, deserialized); + } + + #[test] + fn test_serialize_deserialize_delete() { + let record = ScriptRecord::Delete( + uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(), + ); + + let serialized = record.serialize().unwrap(); + let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap(); + + assert_eq!(record, deserialized); + } + + #[test] + fn test_serialize_deserialize_update() { + let script = Script::builder() + .name("test".to_string()) + .description("test".to_string()) + .shebang("test".to_string()) + .tags(vec!["test".to_string()]) + .script("test".to_string()) + .build(); + + let record = ScriptRecord::Update(script); + + let serialized = record.serialize().unwrap(); + let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap(); + + assert_eq!(record, deserialized); + } +} diff --git a/crates/atuin-scripts/src/store/script.rs b/crates/atuin-scripts/src/store/script.rs new file mode 100644 index 00000000000..af180320844 --- /dev/null +++ b/crates/atuin-scripts/src/store/script.rs @@ -0,0 +1,151 @@ +use atuin_common::record::DecryptedData; +use eyre::{Result, bail, ensure}; +use uuid::Uuid; + +use rmp::{ + decode::{self, Bytes}, + encode, +}; +use typed_builder::TypedBuilder; + +pub const SCRIPT_VERSION: &str = "v0"; +pub const SCRIPT_TAG: &str = "script"; +pub const SCRIPT_LEN: usize = 20000; // 20kb max total len + +#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)] +/// A script is a set of commands that can be run, with the specified shebang +pub struct Script { + /// The id of the script + #[builder(default = uuid::Uuid::new_v4())] + pub id: Uuid, + + /// The name of the script + pub name: String, + + /// The description of the script + #[builder(default = String::new())] + pub description: String, + + /// The interpreter of the script + #[builder(default = String::new())] + pub shebang: String, + + /// The tags of the script + #[builder(default = Vec::new())] + pub tags: Vec, + + /// The script content + pub script: String, +} + +impl Script { + pub fn serialize(&self) -> Result { + // sort the tags first, to ensure consistent ordering + let mut tags = self.tags.clone(); + tags.sort(); + + let mut output = vec![]; + + encode::write_array_len(&mut output, 6)?; + encode::write_str(&mut output, &self.id.to_string())?; + encode::write_str(&mut output, &self.name)?; + encode::write_str(&mut output, &self.description)?; + encode::write_str(&mut output, &self.shebang)?; + encode::write_array_len(&mut output, self.tags.len() as u32)?; + + for tag in &tags { + encode::write_str(&mut output, tag)?; + } + + encode::write_str(&mut output, &self.script)?; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(bytes: &[u8]) -> Result { + let mut bytes = decode::Bytes::new(bytes); + let nfields = decode::read_array_len(&mut bytes).unwrap(); + + ensure!(nfields == 6, "too many entries in v0 script record"); + + let bytes = bytes.remaining_slice(); + + let (id, bytes) = decode::read_str_from_slice(bytes).unwrap(); + let (name, bytes) = decode::read_str_from_slice(bytes).unwrap(); + let (description, bytes) = decode::read_str_from_slice(bytes).unwrap(); + let (shebang, bytes) = decode::read_str_from_slice(bytes).unwrap(); + + let mut bytes = Bytes::new(bytes); + let tags_len = decode::read_array_len(&mut bytes).unwrap(); + + let mut bytes = bytes.remaining_slice(); + + let mut tags = Vec::new(); + for _ in 0..tags_len { + let (tag, remaining) = decode::read_str_from_slice(bytes).unwrap(); + tags.push(tag.to_owned()); + bytes = remaining; + } + + let (script, bytes) = decode::read_str_from_slice(bytes).unwrap(); + + if !bytes.is_empty() { + bail!("trailing bytes in encoded script record. malformed") + } + + Ok(Script { + id: Uuid::parse_str(id).unwrap(), + name: name.to_owned(), + description: description.to_owned(), + shebang: shebang.to_owned(), + tags, + script: script.to_owned(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialize() { + let script = Script { + id: uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(), + name: "test".to_string(), + description: "test".to_string(), + shebang: "test".to_string(), + tags: vec!["test".to_string()], + script: "test".to_string(), + }; + + let serialized = script.serialize().unwrap(); + assert_eq!( + serialized.0, + vec![ + 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, 56, + 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54, 164, + 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 145, 164, + 116, 101, 115, 116, 164, 116, 101, 115, 116 + ] + ); + } + + #[test] + fn test_serialize_deserialize() { + let script = Script { + id: uuid::Uuid::new_v4(), + name: "test".to_string(), + description: "test".to_string(), + shebang: "test".to_string(), + tags: vec!["test".to_string()], + script: "test".to_string(), + }; + + let serialized = script.serialize().unwrap(); + + let deserialized = Script::deserialize(&serialized.0).unwrap(); + + assert_eq!(script, deserialized); + } +} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 8aa10dff657..881dc659dad 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -49,6 +49,7 @@ atuin-common = { path = "../atuin-common", version = "18.5.0-beta.2" } atuin-dotfiles = { path = "../atuin-dotfiles", version = "18.5.0-beta.2" } atuin-history = { path = "../atuin-history", version = "18.5.0-beta.2" } atuin-daemon = { path = "../atuin-daemon", version = "18.5.0-beta.2", optional = true, default-features = false } +atuin-scripts = { path = "../atuin-scripts", version = "18.5.0-beta.2" } log = { workspace = true } time = { workspace = true } @@ -80,6 +81,7 @@ tracing-subscriber = { workspace = true } uuid = { workspace = true } sysinfo = "0.30.7" regex = "1.10.5" +tempfile = { workspace = true } [target.'cfg(any(target_os = "windows", target_os = "macos"))'.dependencies] arboard = { version = "3.4", optional = true } diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index 723bb974e85..330fef0cb96 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -25,6 +25,7 @@ mod import; mod info; mod init; mod kv; +mod scripts; mod search; mod stats; mod store; @@ -67,6 +68,10 @@ pub enum Cmd { #[command(subcommand)] Dotfiles(dotfiles::Cmd), + /// Manage your scripts with Atuin + #[command(subcommand)] + Scripts(scripts::Cmd), + /// Print Atuin's shell init script #[command()] Init(init::Cmd), @@ -159,6 +164,8 @@ impl Cmd { Self::Dotfiles(dotfiles) => dotfiles.run(&settings, sqlite_store).await, + Self::Scripts(scripts) => scripts.run(&settings, sqlite_store, &db).await, + Self::Info => { info::run(&settings); Ok(()) diff --git a/crates/atuin/src/command/client/scripts.rs b/crates/atuin/src/command/client/scripts.rs new file mode 100644 index 00000000000..3993d6b8590 --- /dev/null +++ b/crates/atuin/src/command/client/scripts.rs @@ -0,0 +1,572 @@ +use std::collections::HashMap; +use std::collections::HashSet; +use std::path::PathBuf; + +use atuin_scripts::execution::template_script; +use atuin_scripts::{ + execution::{build_executable_script, execute_script_interactive, template_variables}, + store::{ScriptStore, script::Script}, +}; +use clap::{Parser, Subcommand}; +use eyre::{Result, bail}; +use tempfile::NamedTempFile; + +use atuin_client::{database::Database, record::sqlite_store::SqliteStore, settings::Settings}; +use tracing::debug; + +#[derive(Parser, Debug)] +pub struct NewScript { + pub name: String, + + #[arg(short, long)] + pub description: Option, + + #[arg(short, long)] + pub tags: Vec, + + #[arg(short, long)] + pub shebang: Option, + + #[arg(long)] + pub script: Option, + + #[allow(clippy::option_option)] + #[arg(long)] + /// Use the last command as the script content + /// Optionally specify a number to use the last N commands + pub last: Option>, + + #[arg(long)] + /// Skip opening editor when using --last + pub no_edit: bool, +} + +#[derive(Parser, Debug)] +pub struct Run { + pub name: String, + + /// Specify template variables in the format KEY=VALUE + /// Example: -v name=John -v greeting="Hello there" + #[arg(short, long = "var")] + pub var: Vec, +} + +#[derive(Parser, Debug)] +pub struct List {} + +#[derive(Parser, Debug)] +pub struct Get { + pub name: String, + + #[arg(short, long)] + /// Display only the executable script with shebang + pub script: bool, +} + +#[derive(Parser, Debug)] +pub struct Edit { + pub name: String, + + #[arg(short, long)] + pub description: Option, + + /// Replace all existing tags with these new tags + #[arg(short, long)] + pub tags: Vec, + + /// Remove all tags from the script + #[arg(long)] + pub no_tags: bool, + + /// Rename the script + #[arg(long)] + pub rename: Option, + + #[arg(short, long)] + pub shebang: Option, + + #[arg(long)] + pub script: Option, + + #[allow(clippy::struct_field_names)] + /// Skip opening editor + #[arg(long)] + pub no_edit: bool, +} + +#[derive(Parser, Debug)] +pub struct Delete { + pub name: String, + + #[arg(short, long)] + pub force: bool, +} + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + New(NewScript), + Run(Run), + List(List), + + Get(Get), + Edit(Edit), + Delete(Delete), +} + +impl Cmd { + // Helper function to open an editor with optional initial content + fn open_editor(initial_content: Option<&str>) -> Result { + // Create a temporary file + let temp_file = NamedTempFile::new()?; + let path = temp_file.into_temp_path(); + + // Write initial content to the temp file if provided + if let Some(content) = initial_content { + std::fs::write(&path, content)?; + } + + // Open the file in the user's preferred editor + let editor = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); + let status = std::process::Command::new(editor).arg(&path).status()?; + if !status.success() { + bail!("failed to open editor"); + } + + // Read back the edited content + let content = std::fs::read_to_string(&path)?; + path.close()?; + + Ok(content) + } + + // Helper function to execute a script and manage stdin/stdout/stderr + async fn execute_script(script_content: String, shebang: String) -> Result { + let mut session = execute_script_interactive(script_content, shebang) + .await + .expect("failed to execute script"); + + // Create a channel to signal when the process exits + let (exit_tx, mut exit_rx) = tokio::sync::oneshot::channel(); + + // Set up a task to read from stdin and forward to the script + let sender = session.stdin_tx.clone(); + let stdin_task = tokio::spawn(async move { + use tokio::io::AsyncReadExt; + use tokio::select; + + let stdin = tokio::io::stdin(); + let mut reader = tokio::io::BufReader::new(stdin); + let mut buffer = vec![0u8; 1024]; // Read in chunks for efficiency + + loop { + // Use select to either read from stdin or detect when the process exits + select! { + // Check if the script process has exited + _ = &mut exit_rx => { + break; + } + // Try to read from stdin + read_result = reader.read(&mut buffer) => { + match read_result { + Ok(0) => break, // EOF + Ok(n) => { + // Convert the bytes to a string and forward to script + let input = String::from_utf8_lossy(&buffer[0..n]).to_string(); + if let Err(e) = sender.send(input).await { + eprintln!("Error sending input to script: {e}"); + break; + } + }, + Err(e) => { + eprintln!("Error reading from stdin: {e}"); + break; + } + } + } + } + } + }); + + // Wait for the script to complete + let exit_code = session.wait_for_exit().await; + + // Signal the stdin task to stop + let _ = exit_tx.send(()); + let _ = stdin_task.await; + + let code = exit_code.unwrap_or(-1); + if code != 0 { + eprintln!("Script exited with code {code}"); + } + + Ok(code) + } + + async fn handle_new_script( + settings: &Settings, + new_script: NewScript, + script_store: ScriptStore, + script_db: atuin_scripts::database::Database, + history_db: &impl Database, + ) -> Result<()> { + let script_content = if let Some(count_opt) = new_script.last { + // Get the last N commands from history, plus 1 to exclude the command that runs this script + let count = count_opt.unwrap_or(1) + 1; // Add 1 to the count to exclude the current command + let context = atuin_client::database::current_context(); + + // Get the last N+1 commands, filtering by the default mode + let filters = [settings.default_filter_mode()]; + + let mut history = history_db + .list(&filters, &context, Some(count), false, false) + .await?; + + // Reverse to get chronological order + history.reverse(); + + // Skip the most recent command (which would be the atuin scripts new command itself) + if !history.is_empty() { + history.pop(); // Remove the most recent command + } + + // Format the commands into a script + let commands: Vec = history.iter().map(|h| h.command.clone()).collect(); + + if commands.is_empty() { + bail!("No commands found in history"); + } + + let script_text = commands.join("\n"); + + // Only open editor if --no-edit is not specified + if new_script.no_edit { + Some(script_text) + } else { + // Open the editor with the commands pre-loaded + Some(Self::open_editor(Some(&script_text))?) + } + } else if let Some(script_path) = new_script.script { + let script_content = std::fs::read_to_string(script_path)?; + Some(script_content) + } else { + // Open editor with empty file + Some(Self::open_editor(None)?) + }; + + let script = Script::builder() + .name(new_script.name) + .description(new_script.description.unwrap_or_default()) + .shebang(new_script.shebang.unwrap_or_default()) + .tags(new_script.tags) + .script(script_content.unwrap_or_default()) + .build(); + + script_store.create(script).await?; + + script_store.build(script_db).await?; + + Ok(()) + } + + async fn handle_run( + _settings: &Settings, + run: Run, + script_db: atuin_scripts::database::Database, + ) -> Result<()> { + let script = script_db.get_by_name(&run.name).await?; + + if let Some(script) = script { + // Get variables used in the template + let variables = template_variables(&script)?; + + // Create a hashmap to store variable values + let mut variable_values: HashMap = HashMap::new(); + + // Parse variables from command-line arguments first + for var_str in &run.var { + if let Some((key, value)) = var_str.split_once('=') { + // Add to variable values + variable_values.insert( + key.to_string(), + serde_json::Value::String(value.to_string()), + ); + debug!("Using CLI variable: {}={}", key, value); + } else { + eprintln!("Warning: Ignoring malformed variable specification: {var_str}"); + eprintln!("Variables should be specified as KEY=VALUE"); + } + } + + // Collect variables that are still needed (not specified via CLI) + let remaining_vars: HashSet = variables + .into_iter() + .filter(|var| !variable_values.contains_key(var)) + .collect(); + + // If there are variables in the template that weren't specified on the command line, prompt for them + if !remaining_vars.is_empty() { + println!("This script contains template variables that need values:"); + + let stdin = std::io::stdin(); + let mut input = String::new(); + + for var in remaining_vars { + input.clear(); + + println!("Enter value for '{var}': "); + + if stdin.read_line(&mut input).is_err() { + eprintln!("Failed to read input for variable '{var}'"); + // Provide an empty string as fallback + variable_values.insert(var, serde_json::Value::String(String::new())); + continue; + } + + let value = input.trim().to_string(); + variable_values.insert(var, serde_json::Value::String(value)); + } + } + + let final_script = if variable_values.is_empty() { + // No variables to template, just use the original script + script.script.clone() + } else { + // If we have variables, we need to template the script + debug!("Templating script with variables: {:?}", variable_values); + template_script(&script, &variable_values)? + }; + + // Execute the script (either templated or original) + Self::execute_script(final_script, script.shebang.clone()).await?; + } else { + bail!("script not found"); + } + Ok(()) + } + + async fn handle_list( + _settings: &Settings, + _list: List, + script_db: atuin_scripts::database::Database, + ) -> Result<()> { + let scripts = script_db.list().await?; + + if scripts.is_empty() { + println!("No scripts found"); + } else { + println!("Available scripts:"); + for script in scripts { + if script.tags.is_empty() { + println!("- {} ", script.name); + } else { + println!("- {} [tags: {}]", script.name, script.tags.join(", ")); + } + + // Print description if it's not empty + if !script.description.is_empty() { + println!(" Description: {}", script.description); + } + } + } + + Ok(()) + } + + async fn handle_get( + _settings: &Settings, + get: Get, + script_db: atuin_scripts::database::Database, + ) -> Result<()> { + let script = script_db.get_by_name(&get.name).await?; + + if let Some(script) = script { + if get.script { + // Just print the executable script with shebang + print!( + "{}", + build_executable_script(script.script.clone(), script.shebang) + ); + return Ok(()); + } + + // Create a YAML representation of the script + println!("---"); + println!("name: {}", script.name); + println!("id: {}", script.id); + + if script.description.is_empty() { + println!("description: \"\""); + } else { + println!("description: |"); + // Indent multiline descriptions properly for YAML + for line in script.description.lines() { + println!(" {line}"); + } + } + + if script.tags.is_empty() { + println!("tags: []"); + } else { + println!("tags:"); + for tag in &script.tags { + println!(" - {tag}"); + } + } + + println!("shebang: {}", script.shebang); + + println!("script: |"); + // Indent the script content for proper YAML multiline format + for line in script.script.lines() { + println!(" {line}"); + } + + Ok(()) + } else { + bail!("script '{}' not found", get.name); + } + } + + async fn handle_edit( + _settings: &Settings, + edit: Edit, + script_store: ScriptStore, + script_db: atuin_scripts::database::Database, + ) -> Result<()> { + debug!("editing script {:?}", edit); + // Find the existing script + let existing_script = script_db.get_by_name(&edit.name).await?; + debug!("existing script {:?}", existing_script); + + if let Some(mut script) = existing_script { + // Update the script with new values if provided + if let Some(description) = edit.description { + script.description = description; + } + + // Handle renaming if requested + if let Some(new_name) = edit.rename { + // Check if a script with the new name already exists + if (script_db.get_by_name(&new_name).await?).is_some() { + bail!("A script named '{}' already exists", new_name); + } + + // Update the name + script.name = new_name; + } + + // Handle tag updates with priority: + // 1. If --no-tags is provided, clear all tags + // 2. If --tags is provided, replace all tags + // 3. If neither is provided, tags remain unchanged + if edit.no_tags { + // Clear all tags + script.tags.clear(); + } else if !edit.tags.is_empty() { + // Replace all tags + script.tags = edit.tags; + } + // If none of the above conditions are met, tags remain unchanged + + if let Some(shebang) = edit.shebang { + script.shebang = shebang; + } + + // Handle script content update + let script_content = if let Some(script_path) = edit.script { + // Load script from provided file + std::fs::read_to_string(script_path)? + } else if !edit.no_edit { + // Open the script in editor for interactive editing if --no-edit is not specified + Self::open_editor(Some(&script.script))? + } else { + // If --no-edit is specified, keep the existing script content + script.script.clone() + }; + + // Update the script content + script.script = script_content; + + // Update the script in the store + script_store.update(script).await?; + + // Rebuild the database to apply changes + script_store.build(script_db).await?; + + println!("Script '{}' updated successfully!", edit.name); + + Ok(()) + } else { + bail!("script '{}' not found", edit.name); + } + } + + async fn handle_delete( + _settings: &Settings, + delete: Delete, + script_store: ScriptStore, + script_db: atuin_scripts::database::Database, + ) -> Result<()> { + // Find the script by name + let script = script_db.get_by_name(&delete.name).await?; + + if let Some(script) = script { + // If not force, confirm deletion + if !delete.force { + println!( + "Are you sure you want to delete script '{}'? [y/N]", + delete.name + ); + let mut input = String::new(); + std::io::stdin().read_line(&mut input)?; + + let input = input.trim().to_lowercase(); + if input != "y" && input != "yes" { + println!("Deletion cancelled"); + return Ok(()); + } + } + + // Delete the script + script_store.delete(script.id).await?; + + // Rebuild the database to apply changes + script_store.build(script_db).await?; + + println!("Script '{}' deleted successfully", delete.name); + Ok(()) + } else { + bail!("script '{}' not found", delete.name); + } + } + + pub async fn run( + self, + settings: &Settings, + store: SqliteStore, + history_db: &impl Database, + ) -> Result<()> { + let host_id = Settings::host_id().expect("failed to get host_id"); + let encryption_key: [u8; 32] = atuin_client::encryption::load_key(settings)?.into(); + + let script_store = ScriptStore::new(store, host_id, encryption_key); + let script_db = + atuin_scripts::database::Database::new(settings.scripts.database_path.clone(), 1.0) + .await?; + + match self { + Self::New(new_script) => { + Self::handle_new_script(settings, new_script, script_store, script_db, history_db) + .await + } + Self::Run(run) => Self::handle_run(settings, run, script_db).await, + Self::List(list) => Self::handle_list(settings, list, script_db).await, + Self::Get(get) => Self::handle_get(settings, get, script_db).await, + Self::Edit(edit) => Self::handle_edit(settings, edit, script_store, script_db).await, + Self::Delete(delete) => { + Self::handle_delete(settings, delete, script_store, script_db).await + } + } + } +} diff --git a/crates/atuin/src/command/client/store/rebuild.rs b/crates/atuin/src/command/client/store/rebuild.rs index e5cea38061c..6fdd3ca4919 100644 --- a/crates/atuin/src/command/client/store/rebuild.rs +++ b/crates/atuin/src/command/client/store/rebuild.rs @@ -1,4 +1,5 @@ use atuin_dotfiles::store::{AliasStore, var::VarStore}; +use atuin_scripts::store::ScriptStore; use clap::Args; use eyre::{Result, bail}; @@ -33,6 +34,10 @@ impl Rebuild { self.rebuild_dotfiles(settings, store.clone()).await?; } + "scripts" => { + self.rebuild_scripts(settings, store.clone()).await?; + } + tag => bail!("unknown tag: {tag}"), } @@ -68,4 +73,17 @@ impl Rebuild { Ok(()) } + + async fn rebuild_scripts(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + let host_id = Settings::host_id().expect("failed to get host_id"); + let script_store = ScriptStore::new(store, host_id, encryption_key); + let database = + atuin_scripts::database::Database::new(settings.scripts.database_path.clone(), 1.0) + .await?; + + script_store.build(database).await?; + + Ok(()) + } } diff --git a/crates/atuin/src/sync.rs b/crates/atuin/src/sync.rs index d0dfb3b4a2c..2d7502e99df 100644 --- a/crates/atuin/src/sync.rs +++ b/crates/atuin/src/sync.rs @@ -1,4 +1,5 @@ use atuin_dotfiles::store::{AliasStore, var::VarStore}; +use atuin_scripts::store::ScriptStore; use eyre::{Context, Result}; use atuin_client::{ @@ -30,11 +31,16 @@ pub async fn build( let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); let alias_store = AliasStore::new(store.clone(), host_id, encryption_key); let var_store = VarStore::new(store.clone(), host_id, encryption_key); + let script_store = ScriptStore::new(store.clone(), host_id, encryption_key); history_store.incremental_build(db, downloaded).await?; alias_store.build().await?; var_store.build().await?; + let script_db = + atuin_scripts::database::Database::new(settings.scripts.database_path.clone(), 1.0).await?; + script_store.build(script_db).await?; + Ok(()) }