diff --git a/Cargo.lock b/Cargo.lock index 03748cc86..c386af287 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,9 +182,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", @@ -850,9 +850,11 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a6ddc50d113188cb707839b8670faabdbab39c052846e2430ea8d47d893b18d" dependencies = [ + "async-trait", "cgroups-rs", "command-fds", "containerd-shim-protos", + "futures", "go-flag", "lazy_static", "libc", @@ -867,8 +869,10 @@ dependencies = [ "serde_json", "sha2", "signal-hook", + "signal-hook-tokio", "thiserror 2.0.11", "time", + "tokio", "which 7.0.1", "windows-sys 0.52.0", ] @@ -886,6 +890,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb8db604974f81d1e350d30f274872f43b45e79203ebb8b1ff714e7b18d24e81" dependencies = [ + "async-trait", "protobuf 3.2.0", "ttrpc", "ttrpc-codegen", @@ -899,6 +904,7 @@ dependencies = [ "containerd-shim-wasm", "log", "serial_test", + "tokio", "wamr-rust-sdk", ] @@ -907,6 +913,7 @@ name = "containerd-shim-wasm" version = "0.9.0" dependencies = [ "anyhow", + "async-trait", "caps", "chrono", "containerd-client", @@ -936,11 +943,14 @@ dependencies = [ "tempfile", "thiserror 2.0.11", "tokio", + "tokio-async-drop", "tokio-stream", + "tokio-util", "toml", "tracing", "tracing-opentelemetry", "tracing-subscriber", + "trait-variant", "ttrpc-codegen", "wasmparser 0.226.0", "wat", @@ -966,6 +976,7 @@ dependencies = [ "libc", "log", "serial_test", + "tokio", "wasmedge-sdk", ] @@ -3066,6 +3077,15 @@ dependencies = [ "libc", ] +[[package]] +name = "memoffset" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +dependencies = [ + "autocfg", +] + [[package]] name = "memoffset" version = "0.7.1" @@ -3181,6 +3201,18 @@ dependencies = [ "cc", ] +[[package]] +name = "nix" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" +dependencies = [ + "bitflags 1.3.2", + "cfg-if 1.0.0", + "libc", + "memoffset 0.6.5", +] + [[package]] name = "nix" version = "0.25.1" @@ -4455,7 +4487,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -4475,9 +4507,11 @@ checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-channel", "futures-core", "futures-util", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "http-body-util", @@ -4501,6 +4535,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 1.0.1", + "system-configuration 0.6.1", "tokio", "tokio-native-tls", "tokio-rustls 0.26.0", @@ -5107,6 +5142,18 @@ dependencies = [ "libc", ] +[[package]] +name = "signal-hook-tokio" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213241f76fb1e37e27de3b6aa1b068a2c333233b59cca6634f634b80a27ecf1e" +dependencies = [ + "futures-core", + "libc", + "signal-hook", + "tokio", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -5320,7 +5367,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "system-configuration-sys 0.6.0", ] [[package]] @@ -5333,6 +5391,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "system-interface" version = "0.27.2" @@ -5533,6 +5601,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", @@ -5648,6 +5717,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-vsock" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52a15c15b1bc91f90902347eff163b5b682643aff0c8e972912cca79bd9208dd" +dependencies = [ + "bytes", + "futures", + "libc", + "tokio", + "vsock 0.3.0", +] + [[package]] name = "tokio-vsock" version = "0.6.0" @@ -5658,7 +5740,7 @@ dependencies = [ "futures", "libc", "tokio", - "vsock", + "vsock 0.5.1", ] [[package]] @@ -5944,7 +6026,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tokio-vsock", + "tokio-vsock 0.6.0", "trapeze-codegen", "trapeze-macros", "windows-sys 0.59.0", @@ -5985,8 +6067,10 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c580c498a547b4c083ec758be543e11a0772e03013aef4cdb1fbe77c8b62cae" dependencies = [ + "async-trait", "byteorder", "crossbeam", + "futures", "home", "libc", "log", @@ -5994,6 +6078,8 @@ dependencies = [ "protobuf 3.2.0", "protobuf-codegen 3.2.0", "thiserror 1.0.69", + "tokio", + "tokio-vsock 0.4.0", "windows-sys 0.48.0", ] @@ -6234,6 +6320,16 @@ dependencies = [ "virtual-mio", ] +[[package]] +name = "vsock" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8e1df0bf1e1b28095c24564d1b90acae64ca69b097ed73896e342fa6649c57" +dependencies = [ + "libc", + "nix 0.24.3", +] + [[package]] name = "vsock" version = "0.5.1" diff --git a/crates/containerd-shim-wamr/Cargo.toml b/crates/containerd-shim-wamr/Cargo.toml index 97e2d4f37..0e3404248 100644 --- a/crates/containerd-shim-wamr/Cargo.toml +++ b/crates/containerd-shim-wamr/Cargo.toml @@ -15,6 +15,7 @@ wamr-rust-sdk = { git = "https://github.com/bytecodealliance/wamr-rust-sdk", tag [dev-dependencies] containerd-shim-wasm = { workspace = true, features = ["testing"] } serial_test = { workspace = true } +tokio = { workspace = true } [[bin]] name = "containerd-shim-wamr-v1" diff --git a/crates/containerd-shim-wamr/src/tests.rs b/crates/containerd-shim-wamr/src/tests.rs index 040323f68..7eff9bf7f 100644 --- a/crates/containerd-shim-wamr/src/tests.rs +++ b/crates/containerd-shim-wamr/src/tests.rs @@ -7,21 +7,28 @@ use serial_test::serial; use crate::instance::WamrInstance as WasiInstance; -#[test] +#[tokio::test] #[serial] -fn test_delete_after_create() -> anyhow::Result<()> { - WasiTest::::builder()?.build()?.delete()?; +async fn test_delete_after_create() -> anyhow::Result<()> { + WasiTest::::builder()? + .build() + .await? + .delete() + .await?; Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world() -> anyhow::Result<()> { +async fn test_hello_world() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -29,42 +36,54 @@ fn test_hello_world() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world_oci() -> anyhow::Result<()> { +async fn test_hello_world_oci() -> anyhow::Result<()> { let (builder, _oci_cleanup) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? .as_oci_image(None, None)?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_unreachable() -> anyhow::Result<()> { +async fn test_unreachable() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(UNREACHABLE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_ne!(exit_code, 0); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_seccomp() -> anyhow::Result<()> { +async fn test_seccomp() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(SECCOMP)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout.trim(), "current working dir: /"); @@ -72,45 +91,54 @@ fn test_seccomp() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_has_default_devices() -> anyhow::Result<()> { +async fn test_has_default_devices() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(HAS_DEFAULT_DEVICES)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); Ok(()) } -#[test] +#[tokio::test] #[ignore = "disabled because the WAMR SDK doesn't expose exit code yet"] // See https://github.com/containerd/runwasi/pull/716#discussion_r1827086060 -fn test_exit_code() -> anyhow::Result<()> { +async fn test_exit_code() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(EXIT_CODE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 42); Ok(()) } -#[test] +#[tokio::test] #[ignore] // See https://github.com/containerd/runwasi/pull/716#issuecomment-2458200081 -fn test_custom_entrypoint() -> anyhow::Result<()> { +async fn test_custom_entrypoint() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_start_fn("foo") .with_wasm(CUSTOM_ENTRYPOINT)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); diff --git a/crates/containerd-shim-wasm/Cargo.toml b/crates/containerd-shim-wasm/Cargo.toml index 61e191639..b539edd9d 100644 --- a/crates/containerd-shim-wasm/Cargo.toml +++ b/crates/containerd-shim-wasm/Cargo.toml @@ -14,7 +14,7 @@ doctest = false [dependencies] anyhow = { workspace = true } chrono = { workspace = true } -containerd-shim = { workspace = true } +containerd-shim = { workspace = true, features = ["async"] } containerd-shim-wasm-test-modules = { workspace = true, optional = true } oci-tar-builder = { workspace = true, optional = true } env_logger = { workspace = true, optional = true } @@ -36,13 +36,17 @@ sha256 = { workspace = true } serde_bytes = "0.11" prost = "0.13" toml = "0.8" +async-trait = "0.1.86" +trait-variant = "0.1" +tokio-async-drop = "0.1.0" +tokio-util = "0.7.13" # tracing # note: it's important to keep the version of tracing in sync with tracing-subscriber tracing = { workspace = true, optional = true } # does not include `tracing-log` feature due to https://github.com/spinkube/containerd-shim-spin/issues/61 tracing-subscriber = { version = "0.3", default-features = false, features = [ - "smallvec", # Enables performance optimizations + "smallvec", # Enables performance optimizations "fmt", "ansi", "std", @@ -52,7 +56,9 @@ tracing-subscriber = { version = "0.3", default-features = false, features = [ ], optional = true } # opentelemetry -opentelemetry = { version = "0.23", features = ["trace"], optional = true, default-features = false} +opentelemetry = { version = "0.23", features = [ + "trace", +], optional = true, default-features = false } opentelemetry-otlp = { version = "0.16.0", default-features = false, features = [ "grpc-tonic", "http-proto", @@ -115,4 +121,4 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"] [package.metadata.cargo-machete] # used as part of a derive macro -ignored = ["serde_bytes"] \ No newline at end of file +ignored = ["serde_bytes"] diff --git a/crates/containerd-shim-wasm/src/container/tests.rs b/crates/containerd-shim-wasm/src/container/tests.rs index 669bc88df..bf7126986 100644 --- a/crates/containerd-shim-wasm/src/container/tests.rs +++ b/crates/containerd-shim-wasm/src/container/tests.rs @@ -21,16 +21,17 @@ impl Engine for EngineFailingValidation { type InstanceFailingValidation = Instance; -#[test] +#[tokio::test] #[cfg(unix)] // not yet implemented on Windows -fn test_validation_error() -> anyhow::Result<()> { +async fn test_validation_error() -> anyhow::Result<()> { // A validation error should fail when creating the container // as opposed to failing when starting it. let result = WasiTest::::builder()? .with_start_fn("foo") .with_wasm("/invalid_entrypoint.wasm")? - .build(); + .build() + .await; assert!(result.is_err()); diff --git a/crates/containerd-shim-wasm/src/sandbox/cli.rs b/crates/containerd-shim-wasm/src/sandbox/cli.rs index 6690f161b..2fefa835f 100644 --- a/crates/containerd-shim-wasm/src/sandbox/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/cli.rs @@ -80,6 +80,8 @@ //! use std::path::PathBuf; +use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, LazyLock}; use containerd_shim::{Config, parse, run}; @@ -91,6 +93,7 @@ pub mod r#impl { pub use git_version::git_version; } +use super::async_utils::AmbientRuntime as _; pub use crate::{revision, version}; /// Get the crate version from Cargo.toml. @@ -116,9 +119,10 @@ macro_rules! revision { } #[cfg(target_os = "linux")] -fn get_mem(pid: u32) -> (usize, usize) { +fn get_stats(pid: u32) -> (usize, usize, usize) { let mut rss = 0; let mut total = 0; + let mut threads = 0; for line in std::fs::read_to_string(format!("/proc/{pid}/status")) .unwrap() .lines() @@ -135,19 +139,48 @@ fn get_mem(pid: u32) -> (usize, usize) { if let Some(rest) = rest.strip_suffix("kB") { rss = rest.trim().parse().unwrap_or(0); } + } else if let Some(rest) = line.strip_prefix("Threads:") { + threads = rest.trim().parse().unwrap_or(0); } } - (rss, total) + (rss, total, threads) } #[cfg(target_os = "linux")] -fn log_mem() { +fn monitor_treads() -> usize { + use std::sync::atomic::Ordering::SeqCst; + use std::time::Duration; + + use tokio::time::sleep; + + static NUM_THREADS: LazyLock> = LazyLock::new(|| { + let pid = std::process::id(); + let num_threads = Arc::new(AtomicUsize::new(0)); + let n = num_threads.clone(); + async move { + loop { + let (_, _, threads) = get_stats(pid); + n.fetch_max(threads, SeqCst); + sleep(Duration::from_millis(10)).await; + } + } + .spawn(); + num_threads + }); + NUM_THREADS.load(SeqCst) +} + +#[cfg(target_os = "linux")] +fn log_stats() { let pid = std::process::id(); - let (rss, tot) = get_mem(pid); + let (rss, tot, _) = get_stats(pid); log::info!("Shim peak memory usage was: peak resident set {rss} kB, peak total {tot} kB"); + let threads = monitor_treads(); + log::info!("Shim peak number of threads was {threads}"); + let pid = zygote::Zygote::global().run(|_| std::process::id(), ()); - let (rss, tot) = get_mem(pid); + let (rss, tot, _) = get_stats(pid); log::info!("Zygote peak memory usage was: peak resident set {rss} kB, peak total {tot} kB"); } @@ -169,33 +202,33 @@ pub fn shim_main<'a, I>( #[cfg(unix)] zygote::Zygote::init(); - #[cfg(feature = "opentelemetry")] - if otel_traces_enabled() { - // opentelemetry uses tokio, so we need to initialize a runtime - use tokio::runtime::Runtime; - let rt = Runtime::new().unwrap(); - rt.block_on(async { + #[cfg(unix)] + monitor_treads(); + + async { + #[cfg(feature = "opentelemetry")] + if otel_traces_enabled() { + // opentelemetry uses tokio, so we need to initialize a runtime let otlp_config = OtlpConfig::build_from_env().expect("Failed to build OtelConfig."); let _guard = otlp_config .init() .expect("Failed to initialize OpenTelemetry."); - shim_main_inner::(name, version, revision, shim_version, config); - }); - } else { - shim_main_inner::(name, version, revision, shim_version, config); - } + shim_main_inner::(name, version, revision, shim_version, config).await; + } else { + shim_main_inner::(name, version, revision, shim_version, config).await; + }; - #[cfg(not(feature = "opentelemetry"))] - { - shim_main_inner::(name, version, revision, shim_version, config); + #[cfg(not(feature = "opentelemetry"))] + shim_main_inner::(name, version, revision, shim_version, config).await; } + .block_on(); #[cfg(target_os = "linux")] - log_mem(); + log_stats(); } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] -fn shim_main_inner<'a, I>( +async fn shim_main_inner<'a, I>( name: &str, version: &str, revision: impl Into> + std::fmt::Debug, @@ -241,5 +274,5 @@ fn shim_main_inner<'a, I>( let lower_name = name.to_lowercase(); let shim_id = format!("io.containerd.{lower_name}.{shim_version}"); - run::>(&shim_id, config); + run::>(&shim_id, config).await; } diff --git a/crates/containerd-shim-wasm/src/sandbox/instance.rs b/crates/containerd-shim-wasm/src/sandbox/instance.rs index ec8c0a681..2d34ecf38 100644 --- a/crates/containerd-shim-wasm/src/sandbox/instance.rs +++ b/crates/containerd-shim-wasm/src/sandbox/instance.rs @@ -32,28 +32,29 @@ pub struct InstanceConfig { /// Instance is a trait that gets implemented by consumers of this library. /// This trait requires that any type implementing it is `'static`, similar to `std::any::Any`. /// This means that the type cannot contain a non-`'static` reference. +#[trait_variant::make(Send)] pub trait Instance: 'static { /// The WASI engine type type Engine: Send + Sync + Clone; /// Create a new instance - fn new(id: String, cfg: &InstanceConfig) -> Result + async fn new(id: String, cfg: &InstanceConfig) -> Result where Self: Sized; /// Start the instance /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. - fn start(&self) -> Result; + async fn start(&self) -> Result; /// Send a signal to the instance - fn kill(&self, signal: u32) -> Result<(), Error>; + async fn kill(&self, signal: u32) -> Result<(), Error>; /// Delete any reference to the instance /// This is called after the instance has exited. - fn delete(&self) -> Result<(), Error>; + async fn delete(&self) -> Result<(), Error>; /// Waits for the instance to finish and returns its exit code /// This is an async call. - fn wait(&self) -> impl Future)> + Send; + async fn wait(&self) -> (u32, DateTime); } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs index a85aeba7a..50e7cfa69 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs @@ -2,10 +2,11 @@ use std::env::current_dir; use std::fmt::Debug; use std::sync::Arc; +use async_trait::async_trait; use chrono::Utc; use containerd_shim::error::Error as ShimError; use containerd_shim::publisher::RemotePublisher; -use containerd_shim::util::write_address; +use containerd_shim::util::write_str_to_file; use containerd_shim::{self as shim, ExitSignal, api}; use oci_spec::runtime::Spec; use shim::Flags; @@ -37,6 +38,7 @@ where } } +#[async_trait] impl shim::Shim for Cli where I: Instance + Sync + Send, @@ -45,7 +47,7 @@ where type T = Local; #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn new(_runtime_id: &str, args: &Flags, _config: &mut shim::Config) -> Self { + async fn new(_runtime_id: &str, args: &Flags, _config: &mut shim::Config) -> Self { Cli { engine: Default::default(), namespace: args.namespace.to_string(), @@ -56,7 +58,7 @@ where } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn start_shim(&mut self, opts: containerd_shim::StartOpts) -> shim::Result { + async fn start_shim(&mut self, opts: containerd_shim::StartOpts) -> shim::Result { let dir = current_dir().map_err(|err| ShimError::Other(err.to_string()))?; let spec = Spec::load(dir.join("config.json")).map_err(|err| { shim::Error::InvalidArgument(format!("error loading runtime spec: {}", err)) @@ -69,23 +71,23 @@ where .and_then(|a| a.get("io.kubernetes.cri.sandbox-id")) .unwrap_or(&id); - let (_child, address) = shim::spawn(opts, grouping, vec![])?; + let address = shim::spawn(opts, grouping, vec![]).await?; - write_address(&address)?; + write_str_to_file("address", &address).await?; Ok(address) } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn wait(&mut self) { - self.exit.wait(); + async fn wait(&mut self) { + self.exit.wait().await; } #[cfg_attr( feature = "tracing", tracing::instrument(skip(publisher), level = "Info") )] - fn create_task_service(&self, publisher: RemotePublisher) -> Self::T { + async fn create_task_service(&self, publisher: RemotePublisher) -> Self::T { let events = RemoteEventSender::new(&self.namespace, publisher); let exit = self.exit.clone(); let engine = self.engine.clone(); @@ -99,7 +101,7 @@ where } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn delete_shim(&mut self) -> shim::Result { + async fn delete_shim(&mut self) -> shim::Result { Ok(api::DeleteResponse { exit_status: 137, exited_at: Some(Utc::now().to_timestamp()).into(), diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/events.rs b/crates/containerd-shim-wasm/src/sandbox/shim/events.rs index c727fe051..ef7e093e4 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/events.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/events.rs @@ -7,7 +7,7 @@ use log::warn; use protobuf::well_known_types::timestamp::Timestamp; pub trait EventSender: Clone + Send + Sync + 'static { - fn send(&self, event: impl Event); + fn send(&self, event: impl Event) -> impl Future + Send; } #[derive(Clone)] @@ -33,12 +33,13 @@ impl RemoteEventSender { } impl EventSender for RemoteEventSender { - fn send(&self, event: impl Event) { + async fn send(&self, event: impl Event) { let topic = event.topic(); let event = Box::new(event); let publisher = &self.inner.publisher; - if let Err(err) = - publisher.publish(Default::default(), &topic, &self.inner.namespace, event) + if let Err(err) = publisher + .publish(Default::default(), &topic, &self.inner.namespace, event) + .await { warn!("failed to publish event, topic: {}: {}", &topic, err) } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs b/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs index 46c973af3..1d1723890 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/instance_data.rs @@ -1,6 +1,7 @@ -use std::sync::{OnceLock, RwLock}; +use std::sync::OnceLock; use chrono::{DateTime, Utc}; +use tokio::sync::RwLock; use crate::sandbox::shim::task_state::TaskState; use crate::sandbox::{Instance, InstanceConfig, Result}; @@ -14,9 +15,12 @@ pub(super) struct InstanceData { impl InstanceData { #[cfg_attr(feature = "tracing", tracing::instrument(level = "Debug"))] - pub fn new(id: impl AsRef + std::fmt::Debug, config: InstanceConfig) -> Result { + pub async fn new( + id: impl AsRef + std::fmt::Debug, + config: InstanceConfig, + ) -> Result { let id = id.as_ref().to_string(); - let instance = T::new(id, &config)?; + let instance = T::new(id, &config).await?; Ok(Self { instance, config, @@ -31,11 +35,11 @@ impl InstanceData { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn start(&self) -> Result { - let mut s = self.state.write().unwrap(); + pub async fn start(&self) -> Result { + let mut s = self.state.write().await; s.start()?; - let res = self.instance.start(); + let res = self.instance.start().await; // These state transitions are always `Ok(())` because // we hold the lock since `s.start()` @@ -51,19 +55,19 @@ impl InstanceData { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn kill(&self, signal: u32) -> Result<()> { - let mut s = self.state.write().unwrap(); + pub async fn kill(&self, signal: u32) -> Result<()> { + let mut s = self.state.write().await; s.kill()?; - self.instance.kill(signal) + self.instance.kill(signal).await } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub fn delete(&self) -> Result<()> { - let mut s = self.state.write().unwrap(); + pub async fn delete(&self) -> Result<()> { + let mut s = self.state.write().await; s.delete()?; - let res = self.instance.delete(); + let res = self.instance.delete().await; if res.is_err() { // Always `Ok(())` because we hold the lock since `s.delete()` @@ -76,7 +80,7 @@ impl InstanceData { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] pub async fn wait(&self) -> (u32, DateTime) { let res = self.instance.wait().await; - let mut s = self.state.write().unwrap(); + let mut s = self.state.write().await; *s = TaskState::Exited; res } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs index 0682ea3d1..4eca73af1 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs @@ -2,12 +2,13 @@ use std::collections::HashMap; use std::fs::create_dir_all; use std::ops::Not; use std::path::Path; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::thread; #[cfg(feature = "opentelemetry")] use std::time::Duration; use anyhow::ensure; +use async_trait::async_trait; use containerd_shim::api::{ ConnectRequest, ConnectResponse, CreateTaskRequest, CreateTaskResponse, DeleteRequest, Empty, KillRequest, ShutdownRequest, StartRequest, StartResponse, StateRequest, StateResponse, @@ -15,16 +16,16 @@ use containerd_shim::api::{ }; use containerd_shim::error::Error as ShimError; use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, TaskIO, TaskStart}; -use containerd_shim::protos::shim::shim_ttrpc::Task; use containerd_shim::protos::types::task::Status; use containerd_shim::util::IntoOption; -use containerd_shim::{DeleteResponse, ExitSignal, TtrpcContext, TtrpcResult}; +use containerd_shim::{DeleteResponse, ExitSignal, Task, TtrpcContext, TtrpcResult}; use futures::FutureExt as _; use log::debug; use oci_spec::runtime::Spec; use prost::Message; use protobuf::well_known_types::any::Any; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; #[cfg(feature = "opentelemetry")] use tracing_opentelemetry::OpenTelemetrySpanExt as _; @@ -122,26 +123,26 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - pub(super) fn get_instance(&self, id: &str) -> Result>> { - let instance = self.instances.read().unwrap().get(id).cloned(); + pub(super) async fn get_instance(&self, id: &str) -> Result>> { + let instance = self.instances.read().await.get(id).cloned(); instance.ok_or_else(|| Error::NotFound(id.to_string())) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn has_instance(&self, id: &str) -> bool { - self.instances.read().unwrap().contains_key(id) + async fn has_instance(&self, id: &str) -> bool { + self.instances.read().await.contains_key(id) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn is_empty(&self) -> bool { - self.instances.read().unwrap().is_empty() + async fn is_empty(&self) -> bool { + self.instances.read().await.is_empty() } } // These are the same functions as in Task, but without the TtrcpContext, which is useful for testing impl Local { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_create(&self, req: CreateTaskRequest) -> Result { + async fn task_create(&self, req: CreateTaskRequest) -> Result { let config = Config::get_from_options(req.options.as_ref()) .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?; @@ -155,7 +156,7 @@ impl Local { )); } - if self.has_instance(&req.id) { + if self.has_instance(&req.id).await { return Err(Error::AlreadyExists(req.id)); } @@ -200,11 +201,11 @@ impl Local { }; // Check if this is a cri container - let instance = InstanceData::new(req.id(), cfg)?; + let instance = InstanceData::new(req.id(), cfg).await?; self.instances .write() - .unwrap() + .await .insert(req.id().to_string(), Arc::new(instance)); self.events.send(TaskCreate { @@ -219,7 +220,7 @@ impl Local { }) .into(), ..Default::default() - }); + }).await; debug!("create done"); @@ -234,19 +235,19 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_start(&self, req: StartRequest) -> Result { + async fn task_start(&self, req: StartRequest) -> Result { if req.exec_id().is_empty().not() { return Err(ShimError::Unimplemented("exec is not supported".to_string()).into()); } - let i = self.get_instance(req.id())?; - let pid = i.start()?; + let i = self.get_instance(req.id()).await?; + let pid = i.start().await?; self.events.send(TaskStart { container_id: req.id().into(), pid, ..Default::default() - }); + }).await; let events = self.events.clone(); @@ -261,7 +262,7 @@ impl Local { pid, id, ..Default::default() - }); + }).await; } .spawn(); @@ -274,29 +275,32 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_kill(&self, req: KillRequest) -> Result { + async fn task_kill(&self, req: KillRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - self.get_instance(req.id())?.kill(req.signal())?; + self.get_instance(req.id()) + .await? + .kill(req.signal()) + .await?; Ok(Empty::new()) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_delete(&self, req: DeleteRequest) -> Result { + async fn task_delete(&self, req: DeleteRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; - i.delete()?; + i.delete().await?; let pid = i.pid().unwrap_or_default(); let (exit_code, timestamp) = i.wait().now_or_never().unzip(); let timestamp = timestamp.map(ToTimestamp::to_timestamp); - self.instances.write().unwrap().remove(req.id()); + self.instances.write().await.remove(req.id()); self.events.send(TaskDelete { container_id: req.id().into(), @@ -304,7 +308,7 @@ impl Local { exit_status: exit_code.unwrap_or_default(), exited_at: timestamp.clone().into(), ..Default::default() - }); + }).await; Ok(DeleteResponse { pid, @@ -315,13 +319,13 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_wait(&self, req: WaitRequest) -> Result { + async fn task_wait(&self, req: WaitRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; - let (exit_code, timestamp) = i.wait().block_on(); + let i = self.get_instance(req.id()).await?; + let (exit_code, timestamp) = i.wait().await; debug!("wait finishes"); Ok(WaitResponse { @@ -332,12 +336,12 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_state(&self, req: StateRequest) -> Result { + async fn task_state(&self, req: StateRequest) -> Result { if !req.exec_id().is_empty() { return Err(Error::InvalidArgument("exec is not supported".to_string())); } - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; let pid = i.pid(); let (exit_code, timestamp) = i.wait().now_or_never().unzip(); let timestamp = timestamp.map(ToTimestamp::to_timestamp); @@ -364,8 +368,8 @@ impl Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))] - fn task_stats(&self, req: StatsRequest) -> Result { - let i = self.get_instance(req.id())?; + async fn task_stats(&self, req: StatsRequest) -> Result { + let i = self.get_instance(req.id()).await?; let pid = i .pid() .ok_or_else(|| Error::InvalidArgument("task is not running".to_string()))?; @@ -379,9 +383,10 @@ impl Local { } } +#[async_trait] impl Task for Local { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn create( + async fn create( &self, _ctx: &TtrpcContext, req: CreateTaskRequest, @@ -391,41 +396,41 @@ impl Task for Local { #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_create(req)?) + Ok(self.task_create(req).await?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult { + async fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult { debug!("start: {:?}", req); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_start(req)?) + Ok(self.task_start(req).await?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult { + async fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult { debug!("kill: {:?}", req); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_kill(req)?) + Ok(self.task_kill(req).await?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult { + async fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult { debug!("delete: {:?}", req); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_delete(req)?) + Ok(self.task_delete(req).await?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult { + async fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult { debug!("wait: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -447,7 +452,7 @@ impl Task for Local { } } }); - let result = self.task_wait(req)?; + let result = self.task_wait(req).await?; tx.send(()).unwrap(); Ok(result) } @@ -459,13 +464,17 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn connect(&self, _ctx: &TtrpcContext, req: ConnectRequest) -> TtrpcResult { + async fn connect( + &self, + _ctx: &TtrpcContext, + req: ConnectRequest, + ) -> TtrpcResult { debug!("connect: {:?}", req); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - let i = self.get_instance(req.id())?; + let i = self.get_instance(req.id()).await?; let shim_pid = std::process::id(); let task_pid = i.pid().unwrap_or_default(); Ok(ConnectResponse { @@ -476,35 +485,35 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult { + async fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult { debug!("state: {:?}", req); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_state(req)?) + Ok(self.task_state(req).await?) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult { + async fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult { debug!("shutdown"); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - if self.is_empty() { + if self.is_empty().await { self.exit.signal(); } Ok(Empty::new()) } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult { + async fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult { debug!("stats: {:?}", req); #[cfg(feature = "opentelemetry")] tracing::Span::current().set_parent(extract_context(&_ctx.metadata)); - Ok(self.task_stats(req)?) + Ok(self.task_stats(req).await?) } } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs index 99ed5e3c2..d8ec9e594 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local/tests.rs @@ -1,6 +1,4 @@ use std::fs::{File, create_dir}; -use std::sync::mpsc::{Sender, channel}; -use std::thread; use std::time::Duration; use anyhow::Context; @@ -10,6 +8,8 @@ use containerd_shim::event::Event; use protobuf::{MessageDyn, SpecialFields}; use serde_json as json; use tempfile::tempdir; +use tokio::sync::mpsc::{UnboundedSender as Sender, unbounded_channel as channel}; +use tokio_async_drop::tokio_async_drop; use super::*; use crate::sandbox::shim::events::EventSender; @@ -24,19 +24,19 @@ pub struct InstanceStub { impl Instance for InstanceStub { type Engine = (); - fn new(_id: String, _cfg: &InstanceConfig) -> Result { + async fn new(_id: String, _cfg: &InstanceConfig) -> Result { Ok(InstanceStub { exit_code: WaitableCell::new(), }) } - fn start(&self) -> Result { + async fn start(&self) -> Result { Ok(std::process::id()) } - fn kill(&self, _signal: u32) -> Result<(), Error> { + async fn kill(&self, _signal: u32) -> Result<(), Error> { let _ = self.exit_code.set((1, Utc::now())); Ok(()) } - fn delete(&self) -> Result<(), Error> { + async fn delete(&self) -> Result<(), Error> { Ok(()) } async fn wait(&self) -> (u32, DateTime) { @@ -55,22 +55,20 @@ impl LocalWithDestructor { } impl EventSender for Sender<(String, Box)> { - fn send(&self, event: impl Event) { + async fn send(&self, event: impl Event) { let _ = self.send((event.topic(), Box::new(event))); } } impl Drop for LocalWithDestructor { fn drop(&mut self) { - self.local - .instances - .write() - .unwrap() - .iter() - .for_each(|(_, v)| { - let _ = v.kill(9); - v.delete().unwrap(); - }); + tokio_async_drop!({ + let instances = self.local.instances.write().await; + for (_, instance) in instances.iter() { + let _ = instance.kill(9); + let _ = instance.delete().await; + } + }) } } @@ -98,8 +96,8 @@ fn create_bundle(dir: &std::path::Path, spec: Option) -> Result<()> { Ok(()) } -#[test] -fn test_delete_after_create() { +#[tokio::test] +async fn test_delete_after_create() -> anyhow::Result<()> { let dir = tempdir().unwrap(); let id = "test-delete-after-create"; create_bundle(dir.path(), None).unwrap(); @@ -120,18 +118,20 @@ fn test_delete_after_create() { bundle: dir.path().to_str().unwrap().to_string(), ..Default::default() }) - .unwrap(); + .await?; local .task_delete(DeleteRequest { id: id.to_string(), ..Default::default() }) - .unwrap(); + .await?; + + Ok(()) } -#[test] -fn test_cri_task() -> Result<()> { +#[tokio::test] +async fn test_cri_task() -> Result<()> { // Currently the relationship between the "base" container and the "instances" are pretty weak. // When a cri sandbox is specified we just assume it's the sandbox container and treat it as such by not actually running the code (which is going to be wasm). let (etx, _erx) = channel(); @@ -151,39 +151,49 @@ fn test_cri_task() -> Result<()> { let sandbox_id = "test-cri-task".to_string(); create_bundle(dir, Some(with_cri_sandbox(None, sandbox_id.clone())))?; - local.task_create(CreateTaskRequest { - id: "testbase".to_string(), - bundle: dir.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "testbase".to_string(), + bundle: dir.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); // make sure that the instance exists - let _i = local.get_instance("testbase")?; + let _i = local.get_instance("testbase").await?; - local.task_start(StartRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); let ll = local.clone(); - let (base_tx, base_rx) = channel(); - thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "testbase".to_string(), - ..Default::default() - }); + let (base_tx, mut base_rx) = channel(); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await; base_tx.send(resp).unwrap(); }); base_rx.try_recv().unwrap_err(); @@ -192,72 +202,96 @@ fn test_cri_task() -> Result<()> { let dir2 = temp2.path(); create_bundle(dir2, Some(with_cri_sandbox(None, sandbox_id)))?; - local.task_create(CreateTaskRequest { - id: "testinstance".to_string(), - bundle: dir2.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "testinstance".to_string(), + bundle: dir2.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); // make sure that the instance exists - let _i = local.get_instance("testinstance")?; + let _i = local.get_instance("testinstance").await?; - local.task_start(StartRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - let stats = local.task_stats(StatsRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let stats = local + .task_stats(StatsRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert!(stats.has_stats()); let ll = local.clone(); - let (instance_tx, instance_rx) = channel(); - std::thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "testinstance".to_string(), - ..Default::default() - }); + let (instance_tx, mut instance_rx) = channel(); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await; instance_tx.send(resp).unwrap(); }); instance_rx.try_recv().unwrap_err(); - local.task_kill(KillRequest { - id: "testinstance".to_string(), - signal: 9, - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "testinstance".to_string(), + signal: 9, + ..Default::default() + }) + .await?; - instance_rx.recv_timeout(Duration::from_secs(50)).unwrap()?; + instance_rx + .recv() + .with_timeout(Duration::from_secs(50)) + .await + .flatten() + .unwrap()?; - let state = local.task_state(StateRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "testinstance".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "testinstance".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "testinstance".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} @@ -265,34 +299,48 @@ fn test_cri_task() -> Result<()> { } base_rx.try_recv().unwrap_err(); - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - local.task_kill(KillRequest { - id: "testbase".to_string(), - signal: 9, - ..Default::default() - })?; - - base_rx.recv_timeout(Duration::from_secs(5)).unwrap()?; - let state = local.task_state(StateRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "testbase".to_string(), + signal: 9, + ..Default::default() + }) + .await?; + + base_rx + .recv() + .with_timeout(Duration::from_secs(5)) + .await + .flatten() + .unwrap()?; + let state = local + .task_state(StateRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "testbase".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "testbase".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "testbase".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} @@ -302,8 +350,8 @@ fn test_cri_task() -> Result<()> { Ok(()) } -#[test] -fn test_task_lifecycle() -> Result<()> { +#[tokio::test] +async fn test_task_lifecycle() -> Result<()> { let (etx, _erx) = channel(); // TODO: check events let exit_signal = Arc::new(ExitSignal::default()); let local = Arc::new(Local::::new( @@ -325,17 +373,20 @@ fn test_task_lifecycle() -> Result<()> { id: "test".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} e => return Err(e), } - local.task_create(CreateTaskRequest { - id: "test".to_string(), - bundle: dir.to_str().unwrap().to_string(), - ..Default::default() - })?; + local + .task_create(CreateTaskRequest { + id: "test".to_string(), + bundle: dir.to_str().unwrap().to_string(), + ..Default::default() + }) + .await?; match local .task_create(CreateTaskRequest { @@ -343,73 +394,95 @@ fn test_task_lifecycle() -> Result<()> { bundle: dir.to_str().unwrap().to_string(), ..Default::default() }) + .await .unwrap_err() { Error::AlreadyExists(_) => {} e => return Err(e), } - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::CREATED); - local.task_start(StartRequest { - id: "test".to_string(), - ..Default::default() - })?; + local + .task_start(StartRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::RUNNING); - let (tx, rx) = channel(); + let (tx, mut rx) = channel(); let ll = local.clone(); - thread::spawn(move || { - let resp = ll.task_wait(WaitRequest { - id: "test".to_string(), - ..Default::default() - }); + tokio::spawn(async move { + let resp = ll + .task_wait(WaitRequest { + id: "test".to_string(), + ..Default::default() + }) + .await; tx.send(resp).unwrap(); }); rx.try_recv().unwrap_err(); - let res = local.task_stats(StatsRequest { - id: "test".to_string(), - ..Default::default() - })?; + let res = local + .task_stats(StatsRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert!(res.has_stats()); - local.task_kill(KillRequest { - id: "test".to_string(), - signal: 9, - ..Default::default() - })?; + local + .task_kill(KillRequest { + id: "test".to_string(), + signal: 9, + ..Default::default() + }) + .await?; - rx.recv_timeout(Duration::from_secs(5)).unwrap()?; + rx.recv() + .with_timeout(Duration::from_secs(5)) + .await + .flatten() + .unwrap()?; - let state = local.task_state(StateRequest { - id: "test".to_string(), - ..Default::default() - })?; + let state = local + .task_state(StateRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; assert_eq!(state.status(), Status::STOPPED); - local.task_delete(DeleteRequest { - id: "test".to_string(), - ..Default::default() - })?; + local + .task_delete(DeleteRequest { + id: "test".to_string(), + ..Default::default() + }) + .await?; match local .task_state(StateRequest { id: "test".to_string(), ..Default::default() }) + .await .unwrap_err() { Error::NotFound(_) => {} diff --git a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs index 47f64f757..58aa9b646 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs @@ -32,11 +32,11 @@ impl SandboxInstance for Instance { type Engine = E; #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn new(id: String, cfg: &InstanceConfig) -> Result { + async fn new(id: String, cfg: &InstanceConfig) -> Result { // check if container is OCI image with wasm layers and attempt to read the module - let (modules, platform) = containerd::Client::connect(&cfg.containerd_address, &cfg.namespace).block_on()? + let (modules, platform) = containerd::Client::connect(&cfg.containerd_address, &cfg.namespace).await? .load_modules(&id, &E::default()) - .block_on() + .await .unwrap_or_else(|e| { log::warn!("Error obtaining wasm layers for container {id}. Will attempt to use files inside container image. Error: {e}"); (vec![], Platform::default()) @@ -85,7 +85,7 @@ impl SandboxInstance for Instance { /// The returned value should be a unique ID (such as a PID) for the instance. /// Nothing internally should be using this ID, but it is returned to containerd where a user may want to use it. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn start(&self) -> Result { + async fn start(&self) -> Result { log::info!("starting instance: {}", self.id); // make sure we have an exit code by the time we finish (even if there's a panic) let guard = self.exit_code.clone().set_guard_with(|| (137, Utc::now())); @@ -95,7 +95,7 @@ impl SandboxInstance for Instance { // Use a pidfd FD so that we can wait for the process to exit asynchronously. // This should be created BEFORE calling container.start() to ensure we never // miss the SIGCHLD event. - let pidfd = PidFd::new(pid)?; + let pidfd = PidFd::new(pid).await?; self.container.start()?; @@ -125,7 +125,7 @@ impl SandboxInstance for Instance { /// Send a signal to the instance #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn kill(&self, signal: u32) -> Result<(), SandboxError> { + async fn kill(&self, signal: u32) -> Result<(), SandboxError> { log::info!("sending signal {signal} to instance: {}", self.id); self.container.kill(signal)?; Ok(()) @@ -134,7 +134,7 @@ impl SandboxInstance for Instance { /// Delete any reference to the instance /// This is called after the instance has exited. #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn delete(&self) -> Result<(), SandboxError> { + async fn delete(&self) -> Result<(), SandboxError> { log::info!("deleting instance: {}", self.id); self.container.delete()?; Ok(()) diff --git a/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs b/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs index e63ac1910..050f90ed4 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs @@ -8,6 +8,8 @@ use nix::sys::wait::{Id, WaitPidFlag, WaitStatus, waitid}; use nix::unistd::Pid; use tokio::io::unix::AsyncFd; +use crate::sandbox::async_utils::AmbientRuntime; + pub(super) struct PidFd { fd: OwnedFd, pid: pid_t, @@ -15,10 +17,10 @@ pub(super) struct PidFd { } impl PidFd { - pub(super) fn new(pid: impl Into) -> anyhow::Result { + pub(super) async fn new(pid: impl Into) -> anyhow::Result { use libc::{PIDFD_NONBLOCK, SYS_pidfd_open, syscall}; let pid = pid.into(); - let subs = monitor_subscribe(Topic::Pid)?; + let subs = monitor_subscribe(Topic::Pid).await?; let pidfd = unsafe { syscall(SYS_pidfd_open, pid, PIDFD_NONBLOCK) }; if pidfd == -1 { return Err(std::io::Error::last_os_error().into()); @@ -58,18 +60,19 @@ impl PidFd { } } -pub async fn try_wait_pid(pid: i32, s: Subscription) -> Result { - tokio::task::spawn_blocking(move || { - while let Ok(ExitEvent { subject, exit_code }) = s.rx.recv_timeout(Duration::from_secs(2)) { - let Subject::Pid(p) = subject else { - continue; - }; - if pid == p { - return Ok(exit_code); - } +pub async fn try_wait_pid(pid: i32, mut s: Subscription) -> Result { + while let Some(ExitEvent { subject, exit_code }) = + s.rx.recv() + .with_timeout(Duration::from_secs(2)) + .await + .flatten() + { + let Subject::Pid(p) = subject else { + continue; + }; + if pid == p { + return Ok(exit_code); } - Err(Errno::ECHILD) - }) - .await - .map_err(|_| Errno::ECHILD)? + } + Err(Errno::ECHILD) } diff --git a/crates/containerd-shim-wasm/src/test/signals.rs b/crates/containerd-shim-wasm/src/test/signals.rs index b57f9575b..23320349d 100644 --- a/crates/containerd-shim-wasm/src/test/signals.rs +++ b/crates/containerd-shim-wasm/src/test/signals.rs @@ -28,6 +28,7 @@ use std::time::Duration; use anyhow::{Result, bail}; use containerd_shim_wasm_test_modules::HELLO_WORLD; +use tokio_util::task::TaskTracker; use crate::container::{Engine, Instance, RuntimeContext}; use crate::testing::WasiTest; @@ -77,70 +78,68 @@ impl Drop for KillGuard { } } -#[test] -fn test_handling_signals() -> Result<()> { +#[tokio::test] +async fn test_handling_signals() -> Result<()> { zygote::Zygote::global(); - // use a thread scope to ensure we join all threads at the end - std::thread::scope(|s| -> Result<()> { - let mut containers = vec![]; + let mut containers = vec![]; - for i in 0..20 { - let builder = WasiTest::::builder()? - .with_name(format!("test-{i}")) - .with_start_fn(format!("test-{i}")) - .with_wasm(HELLO_WORLD)?; + for i in 0..20 { + let builder = WasiTest::::builder()? + .with_name(format!("test-{i}")) + .with_start_fn(format!("test-{i}")) + .with_wasm(HELLO_WORLD)?; - // In CI /proc/self/fd/1 doesn't seem to be available - let builder = match canonicalize("/proc/self/fd/1") { - Ok(stdout) => builder.with_stdout(stdout)?, - _ => builder, - }; + // In CI /proc/self/fd/1 doesn't seem to be available + let builder = match canonicalize("/proc/self/fd/1") { + Ok(stdout) => builder.with_stdout(stdout)?, + _ => builder, + }; - let container = builder.build()?; - containers.push(Arc::new(container)); - } + let container = builder.build().await?; + containers.push(Arc::new(container)); + } - let _guard: Vec<_> = containers.iter().cloned().map(KillGuard).collect(); + let _guard: Vec<_> = containers.iter().cloned().map(KillGuard).collect(); - for container in containers.iter() { - container.start()?; - } + for container in containers.iter() { + container.start().await?; + } - let (tx, rx) = channel(); - - for (i, container) in containers.iter().cloned().enumerate() { - let tx = tx.clone(); - s.spawn(move || -> anyhow::Result<()> { - println!("shim> waiting for container {i}"); - let (code, ..) = container.wait(Duration::from_secs(10000))?; - println!("shim> container test-{i} exited with code {code}"); - tx.send(i)?; - Ok(()) - }); - } + let (tx, rx) = channel(); + + let tasks = TaskTracker::new(); + for (i, container) in containers.iter().cloned().enumerate() { + let tx = tx.clone(); + tasks.spawn(async move { + println!("shim> waiting for container {i}"); + let (code, ..) = container.wait(Duration::from_secs(10000)).await?; + println!("shim> container test-{i} exited with code {code}"); + tx.send(i)?; + Ok::<_, anyhow::Error>(()) + }); + } - 'outer: for (i, container) in containers.iter().enumerate() { - for _ in 0..100 { - let stderr = container.read_stderr()?.unwrap_or_default(); - if stderr.contains("ready") { - continue 'outer; - } - sleep(Duration::from_millis(1)); + 'outer: for (i, container) in containers.iter().enumerate() { + for _ in 0..100 { + let stderr = container.read_stderr()?.unwrap_or_default(); + if stderr.contains("ready") { + continue 'outer; } - bail!("timeout waiting for container test-{i}"); + sleep(Duration::from_millis(1)); } + bail!("timeout waiting for container test-{i}"); + } - println!("shim> all containers ready"); + println!("shim> all containers ready"); - for (i, container) in containers.iter().enumerate() { - println!("shim> sending ctrl-c to container test-{i}"); - let _ = container.ctrl_c()?; - let id = rx.recv_timeout(Duration::from_secs(5))?; - println!("shim> received exit from container test-{id} (expected test-{i})"); - assert_eq!(id, i); - } + for (i, container) in containers.iter().enumerate() { + println!("shim> sending ctrl-c to container test-{i}"); + let _ = container.ctrl_c().await?; + let id = rx.recv_timeout(Duration::from_secs(5))?; + println!("shim> received exit from container test-{id} (expected test-{i})"); + assert_eq!(id, i); + } - Ok(()) - }) + Ok(()) } diff --git a/crates/containerd-shim-wasm/src/testing.rs b/crates/containerd-shim-wasm/src/testing.rs index 00fa93b77..f1be9e9dc 100644 --- a/crates/containerd-shim-wasm/src/testing.rs +++ b/crates/containerd-shim-wasm/src/testing.rs @@ -185,7 +185,7 @@ where )) } - pub fn build(self) -> Result> { + pub async fn build(self) -> Result> { let tempdir = self.tempdir; let dir = tempdir.path(); @@ -225,7 +225,7 @@ where ..Default::default() }; - let instance = WasiInstance::new(self.container_name, &cfg)?; + let instance = WasiInstance::new(self.container_name, &cfg).await?; Ok(WasiTest { instance, tempdir }) } } @@ -242,44 +242,44 @@ where &self.instance } - pub fn start(&self) -> Result<&Self> { + pub async fn start(&self) -> Result<&Self> { log::info!("starting wasi test"); - let pid = self.instance.start()?; + let pid = self.instance.start().await?; log::info!("wasi test pid {pid}"); Ok(self) } - pub fn delete(&self) -> Result<&Self> { + pub async fn delete(&self) -> Result<&Self> { log::info!("deleting wasi test"); - self.instance.delete()?; + self.instance.delete().await?; Ok(self) } - pub fn ctrl_c(&self) -> Result<&Self> { + pub async fn ctrl_c(&self) -> Result<&Self> { log::info!("sending SIGINT"); - self.instance.kill(SIGINT as u32)?; + self.instance.kill(SIGINT as u32).await?; Ok(self) } - pub fn terminate(&self) -> Result<&Self> { + pub async fn terminate(&self) -> Result<&Self> { log::info!("sending SIGTERM"); - self.instance.kill(SIGTERM as u32)?; + self.instance.kill(SIGTERM as u32).await?; Ok(self) } - pub fn kill(&self) -> Result<&Self> { + pub async fn kill(&self) -> Result<&Self> { log::info!("sending SIGKILL"); - self.instance.kill(SIGKILL as u32)?; + self.instance.kill(SIGKILL as u32).await?; Ok(self) } - pub fn wait(&self, t: Duration) -> Result<(u32, String, String)> { + pub async fn wait(&self, t: Duration) -> Result<(u32, String, String)> { log::info!("waiting wasi test"); - let (status, _) = match self.instance.wait().with_timeout(t).block_on() { + let (status, _) = match self.instance.wait().with_timeout(t).await { Some(res) => res, None => { - self.instance.kill(SIGKILL)?; + self.instance.kill(SIGKILL).await?; bail!("timeout while waiting for module to finish"); } }; @@ -287,7 +287,7 @@ where let stdout = self.read_stdout()?.unwrap_or_default(); let stderr = self.read_stderr()?.unwrap_or_default(); - self.instance.delete()?; + self.instance.delete().await?; log::info!("wasi test status is {status}"); diff --git a/crates/containerd-shim-wasmedge/Cargo.toml b/crates/containerd-shim-wasmedge/Cargo.toml index e059f2a66..16178cd88 100644 --- a/crates/containerd-shim-wasmedge/Cargo.toml +++ b/crates/containerd-shim-wasmedge/Cargo.toml @@ -15,6 +15,7 @@ wasmedge-sdk = { version = "0.14.0", default-features = false } containerd-shim-wasm = { workspace = true, features = ["testing"] } libc = { workspace = true } serial_test = { workspace = true } +tokio = { workspace = true } [features] default = ["standalone", "static"] diff --git a/crates/containerd-shim-wasmedge/src/tests.rs b/crates/containerd-shim-wasmedge/src/tests.rs index cb117f170..0c86921f0 100644 --- a/crates/containerd-shim-wasmedge/src/tests.rs +++ b/crates/containerd-shim-wasmedge/src/tests.rs @@ -7,21 +7,28 @@ use serial_test::serial; use crate::instance::WasmEdgeInstance as WasiInstance; -#[test] +#[tokio::test] #[serial] -fn test_delete_after_create() -> anyhow::Result<()> { - WasiTest::::builder()?.build()?.delete()?; +async fn test_delete_after_create() -> anyhow::Result<()> { + WasiTest::::builder()? + .build() + .await? + .delete() + .await?; Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world() -> anyhow::Result<()> { +async fn test_hello_world() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -29,14 +36,20 @@ fn test_hello_world() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world_oci() -> anyhow::Result<()> { +async fn test_hello_world_oci() -> anyhow::Result<()> { let (builder, _oci_cleanup) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? .as_oci_image(None, None)?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -44,15 +57,18 @@ fn test_hello_world_oci() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_custom_entrypoint() -> anyhow::Result<()> { +async fn test_custom_entrypoint() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_start_fn("foo") .with_wasm(CUSTOM_ENTRYPOINT)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -60,42 +76,51 @@ fn test_custom_entrypoint() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_unreachable() -> anyhow::Result<()> { +async fn test_unreachable() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(UNREACHABLE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_ne!(exit_code, 0); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_exit_code() -> anyhow::Result<()> { +async fn test_exit_code() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(EXIT_CODE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 42); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_seccomp() -> anyhow::Result<()> { +async fn test_seccomp() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(SECCOMP)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout.trim(), "current working dir: /"); @@ -103,14 +128,17 @@ fn test_seccomp() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_has_default_devices() -> anyhow::Result<()> { +async fn test_has_default_devices() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(HAS_DEFAULT_DEVICES)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); diff --git a/crates/containerd-shim-wasmer/src/tests.rs b/crates/containerd-shim-wasmer/src/tests.rs index 7dae636c0..cc2196e42 100644 --- a/crates/containerd-shim-wasmer/src/tests.rs +++ b/crates/containerd-shim-wasmer/src/tests.rs @@ -7,21 +7,28 @@ use serial_test::serial; use crate::instance::WasmerInstance as WasiInstance; -#[test] +#[tokio::test] #[serial] -fn test_delete_after_create() -> anyhow::Result<()> { - WasiTest::::builder()?.build()?.delete()?; +async fn test_delete_after_create() -> anyhow::Result<()> { + WasiTest::::builder()? + .build() + .await? + .delete() + .await?; Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world() -> anyhow::Result<()> { +async fn test_hello_world() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -29,14 +36,20 @@ fn test_hello_world() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world_oci() -> anyhow::Result<()> { +async fn test_hello_world_oci() -> anyhow::Result<()> { let (builder, _oci_cleanup) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? .as_oci_image(None, None)?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -44,15 +57,18 @@ fn test_hello_world_oci() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_custom_entrypoint() -> anyhow::Result<()> { +async fn test_custom_entrypoint() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_start_fn("foo") .with_wasm(CUSTOM_ENTRYPOINT)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -60,42 +76,51 @@ fn test_custom_entrypoint() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_unreachable() -> anyhow::Result<()> { +async fn test_unreachable() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(UNREACHABLE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_ne!(exit_code, 0); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_exit_code() -> anyhow::Result<()> { +async fn test_exit_code() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(EXIT_CODE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 42); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_seccomp() -> anyhow::Result<()> { +async fn test_seccomp() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(SECCOMP)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout.trim(), "current working dir: /"); @@ -103,14 +128,17 @@ fn test_seccomp() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_has_default_devices() -> anyhow::Result<()> { +async fn test_has_default_devices() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(HAS_DEFAULT_DEVICES)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); diff --git a/crates/containerd-shim-wasmtime/Cargo.toml b/crates/containerd-shim-wasmtime/Cargo.toml index ad0ed3e91..a16803ded 100644 --- a/crates/containerd-shim-wasmtime/Cargo.toml +++ b/crates/containerd-shim-wasmtime/Cargo.toml @@ -19,7 +19,7 @@ wasmtime-wasi-http = { workspace = true } [dev-dependencies] containerd-shim-wasm = { workspace = true, features = ["testing"] } serial_test = { workspace = true } -reqwest = { version = "0.12", default-features=false, features = ["blocking"] } +reqwest = { version = "0.12" } [[bin]] name = "containerd-shim-wasmtime-v1" diff --git a/crates/containerd-shim-wasmtime/src/tests.rs b/crates/containerd-shim-wasmtime/src/tests.rs index 1ffb3ad4c..5eba0f660 100644 --- a/crates/containerd-shim-wasmtime/src/tests.rs +++ b/crates/containerd-shim-wasmtime/src/tests.rs @@ -12,21 +12,28 @@ use crate::instance::WasmtimeEngine; // https://github.com/containerd/runwasi/issues/357 type WasmtimeTestInstance = Instance; -#[test] +#[tokio::test] #[serial] -fn test_delete_after_create() -> anyhow::Result<()> { - WasiTest::::builder()?.build()?.delete()?; +async fn test_delete_after_create() -> anyhow::Result<()> { + WasiTest::::builder()? + .build() + .await? + .delete() + .await?; Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world() -> anyhow::Result<()> { +async fn test_hello_world() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -34,14 +41,20 @@ fn test_hello_world() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world_oci() -> anyhow::Result<()> { +async fn test_hello_world_oci() -> anyhow::Result<()> { let (builder, _oci_cleanup) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? .as_oci_image(None, None)?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -49,9 +62,9 @@ fn test_hello_world_oci() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world_oci_uses_precompiled() -> anyhow::Result<()> { +async fn test_hello_world_oci_uses_precompiled() -> anyhow::Result<()> { let (builder, _oci_cleanup1) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? .as_oci_image( @@ -59,7 +72,13 @@ fn test_hello_world_oci_uses_precompiled() -> anyhow::Result<()> { Some("c1".to_string()), )?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -79,7 +98,13 @@ fn test_hello_world_oci_uses_precompiled() -> anyhow::Result<()> { Some("c2".to_string()), )?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -87,9 +112,9 @@ fn test_hello_world_oci_uses_precompiled() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_hello_world_oci_uses_precompiled_when_content_removed() -> anyhow::Result<()> { +async fn test_hello_world_oci_uses_precompiled_when_content_removed() -> anyhow::Result<()> { let (builder, _oci_cleanup1) = WasiTest::::builder()? .with_wasm(HELLO_WORLD)? .as_oci_image( @@ -97,7 +122,13 @@ fn test_hello_world_oci_uses_precompiled_when_content_removed() -> anyhow::Resul Some("c1".to_string()), )?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -119,7 +150,13 @@ fn test_hello_world_oci_uses_precompiled_when_content_removed() -> anyhow::Resul Some("c2".to_string()), )?; - let (exit_code, stdout, _) = builder.build()?.start()?.wait(Duration::from_secs(10))?; + let (exit_code, stdout, _) = builder + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -127,15 +164,18 @@ fn test_hello_world_oci_uses_precompiled_when_content_removed() -> anyhow::Resul Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_custom_entrypoint() -> anyhow::Result<()> { +async fn test_custom_entrypoint() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_start_fn("foo") .with_wasm(CUSTOM_ENTRYPOINT)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "hello world\n"); @@ -143,42 +183,51 @@ fn test_custom_entrypoint() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_unreachable() -> anyhow::Result<()> { +async fn test_unreachable() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(UNREACHABLE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_ne!(exit_code, 0); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_exit_code() -> anyhow::Result<()> { +async fn test_exit_code() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(EXIT_CODE)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 42); Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_seccomp() -> anyhow::Result<()> { +async fn test_seccomp() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(SECCOMP)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout.trim(), "current working dir: /"); @@ -186,14 +235,17 @@ fn test_seccomp() -> anyhow::Result<()> { Ok(()) } -#[test] +#[tokio::test] #[serial] -fn test_has_default_devices() -> anyhow::Result<()> { +async fn test_has_default_devices() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(HAS_DEFAULT_DEVICES)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); @@ -205,15 +257,18 @@ fn test_has_default_devices() -> anyhow::Result<()> { // The current limitation is that there is no way to pass arguments // to the exported function. // Issue that tracks this: https://github.com/containerd/runwasi/issues/414 -#[test] +#[tokio::test] #[serial] -fn test_simple_component() -> anyhow::Result<()> { +async fn test_simple_component() -> anyhow::Result<()> { let (exit_code, _, _) = WasiTest::::builder()? .with_wasm(SIMPLE_COMPONENT)? .with_start_fn("thunk") - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); @@ -228,14 +283,17 @@ fn test_simple_component() -> anyhow::Result<()> { // The wasm component is built and copied over from // https://github.com/Mossaka/wasm-component-hello-world. See // README.md for how to build the component. -#[test] +#[tokio::test] #[serial] -fn test_wasip2_component() -> anyhow::Result<()> { +async fn test_wasip2_component() -> anyhow::Result<()> { let (exit_code, stdout, _) = WasiTest::::builder()? .with_wasm(COMPONENT_HELLO_WORLD)? - .build()? - .start()? - .wait(Duration::from_secs(10))?; + .build() + .await? + .start() + .await? + .wait(Duration::from_secs(10)) + .await?; assert_eq!(exit_code, 0); assert_eq!(stdout, "Hello, world!\n"); @@ -250,24 +308,25 @@ fn test_wasip2_component() -> anyhow::Result<()> { // // The wasm component is built using cargo component as illustrated in the following example:: // https://opensource.microsoft.com/blog/2024/09/25/distributing-webassembly-components-using-oci-registries/ -#[test] +#[tokio::test] #[serial] -fn test_wasip2_component_http_proxy() -> anyhow::Result<()> { +async fn test_wasip2_component_http_proxy() -> anyhow::Result<()> { let srv = WasiTest::::builder()? .with_wasm(HELLO_WASI_HTTP)? .with_host_network() - .build()?; + .build() + .await?; - let srv = srv.start()?; - let response = http_get(); + let srv = srv.start().await?; + let response = http_get().await; let response = response.expect("Server did not start in time"); assert!(response.status().is_success()); - let body = response.text().unwrap(); + let body = response.text().await.unwrap(); assert_eq!(body, "Hello, this is your first wasi:http/proxy world!\n"); - let (exit_code, _, _) = srv.ctrl_c()?.wait(Duration::from_secs(5))?; + let (exit_code, _, _) = srv.ctrl_c().await?.wait(Duration::from_secs(5)).await?; assert_eq!(exit_code, 0); Ok(()) @@ -276,65 +335,67 @@ fn test_wasip2_component_http_proxy() -> anyhow::Result<()> { // The wasm component is built using componentize-dotnet as illustrated in the following example:: // https://bytecodealliance.org/articles/simplifying-components-for-dotnet-developers-with-componentize-dotnet // this ensures we are able to use wasm built from other languages https://github.com/containerd/runwasi/pull/723 -#[test] +#[tokio::test] #[serial] -fn test_wasip2_component_http_proxy_csharp() -> anyhow::Result<()> { +async fn test_wasip2_component_http_proxy_csharp() -> anyhow::Result<()> { let srv = WasiTest::::builder()? .with_wasm(HELLO_WASI_HTTP_CSHARP)? .with_host_network() - .build()?; + .build() + .await?; - let srv = srv.start()?; + let srv = srv.start().await?; // dotnet takes a bit longer to start up // Todo: find out why this doesn't happen in wasmtime directly - let response = http_get_with_backoff_secs(2); + let response = http_get_with_backoff_secs(2).await; let response = response.expect("Server did not start in time"); assert!(response.status().is_success()); - let body = response.text().unwrap(); + let body = response.text().await.unwrap(); assert_eq!(body, "Hello, from C#!"); - let (exit_code, _, _) = srv.ctrl_c()?.wait(Duration::from_secs(5))?; + let (exit_code, _, _) = srv.ctrl_c().await?.wait(Duration::from_secs(5)).await?; assert_eq!(exit_code, 0); Ok(()) } // Test that the shim can terminate component targeting wasi:http/proxy by sending SIGTERM. -#[test] +#[tokio::test] #[serial] -fn test_wasip2_component_http_proxy_force_shutdown() -> anyhow::Result<()> { +async fn test_wasip2_component_http_proxy_force_shutdown() -> anyhow::Result<()> { let srv = WasiTest::::builder()? .with_wasm(HELLO_WASI_HTTP)? .with_host_network() - .build()?; + .build() + .await?; - let srv = srv.start()?; - assert!(http_get().unwrap().status().is_success()); + let srv = srv.start().await?; + assert!(http_get().await.unwrap().status().is_success()); // Send SIGTERM - let (exit_code, _, _) = srv.terminate()?.wait(Duration::from_secs(5))?; + let (exit_code, _, _) = srv.terminate().await?.wait(Duration::from_secs(5)).await?; // The exit code indicates that the process did not exit cleanly assert_eq!(exit_code, 128 + libc::SIGTERM as u32); Ok(()) } -fn http_get() -> reqwest::Result { - http_get_with_backoff_secs(1) +async fn http_get() -> reqwest::Result { + http_get_with_backoff_secs(1).await } // Helper method to make a `GET` request -fn http_get_with_backoff_secs(backoff: u64) -> reqwest::Result { +async fn http_get_with_backoff_secs(backoff: u64) -> reqwest::Result { const MAX_ATTEMPTS: u32 = 10; let backoff_duration: Duration = Duration::from_secs(backoff); let mut attempts = 0; loop { - match reqwest::blocking::get("http://127.0.0.1:8080") { + match reqwest::get("http://127.0.0.1:8080").await { Ok(resp) => break Ok(resp), Err(err) if attempts == MAX_ATTEMPTS => break Err(err), Err(_) => {