diff --git a/Cargo.lock b/Cargo.lock
index 5f50358..80b1729 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -840,6 +840,15 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "fs-err"
+version = "3.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8bb60e7409f34ef959985bc9d9c5ee8f5db24ee46ed9775850548021710f807f"
+dependencies = [
+ "autocfg",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.31"
@@ -974,6 +983,12 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "hashbrown"
+version = "0.14.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
+
 [[package]]
 name = "hashbrown"
 version = "0.15.1"
@@ -1338,7 +1353,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da"
 dependencies = [
  "equivalent",
- "hashbrown",
+ "hashbrown 0.15.1",
 ]
 
 [[package]]
@@ -1392,7 +1407,7 @@ version = "0.12.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
 dependencies = [
- "hashbrown",
+ "hashbrown 0.15.1",
 ]
 
 [[package]]
@@ -1887,6 +1902,7 @@ dependencies = [
  "clap",
  "csv",
  "flate2",
+ "fs-err",
  "futures-util",
  "hex",
  "md-5",
@@ -1899,6 +1915,7 @@ dependencies = [
  "thiserror",
  "time",
  "tokio",
+ "tokio-util",
 ]
 
 [[package]]
@@ -2316,6 +2333,8 @@ dependencies = [
  "bytes",
  "futures-core",
  "futures-sink",
+ "futures-util",
+ "hashbrown 0.14.5",
  "pin-project-lite",
  "tokio",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 93aeea3..5782bb1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,6 +23,7 @@ aws-smithy-runtime-api = "1.7.3"
 clap = { version = "4.5.21", default-features = false, features = ["derive", "error-context", "help", "std", "suggestions", "usage", "wrap_help"] }
 csv = "1.3.1"
 flate2 = "1.0.35"
+fs-err = "3.0.0"
 futures-util = "0.3.31"
 hex = "0.4.3"
 md-5 = "0.10.6"
@@ -34,6 +35,7 @@ tempfile = "3.14.0"
 thiserror = "2.0.3"
 time = { version = "0.3.36", features = ["macros", "parsing", "serde"] }
 tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
+tokio-util = { version = "0.7.12", features = ["rt"] }
 
 [dev-dependencies]
 rstest = { version = "0.23.0", default-features = false }
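Two dependency notes before the code: `fs-err` is a drop-in wrapper around `std::fs` whose errors name the path that failed, and `tokio-util` supplies the `CancellationToken` used for graceful shutdown below. A quick illustration of the `fs-err` difference (a sketch of mine, not part of the patch):

```rust
// Sketch, not from the patch: fs-err mirrors the std::fs API, but its
// io::Error messages include the offending path.
fn read_config() -> std::io::Result<String> {
    // std::fs::read_to_string("missing.toml") fails with just
    // "No such file or directory (os error 2)"; fs-err's version reports
    // roughly: failed to open file `missing.toml`: No such file or directory
    fs_err::read_to_string("missing.toml")
}
```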
diff --git a/src/asyncutil/lsg.rs b/src/asyncutil/lsg.rs
new file mode 100644
index 0000000..6a0d5d2
--- /dev/null
+++ b/src/asyncutil/lsg.rs
@@ -0,0 +1,85 @@
+use futures_util::Stream;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::{Arc, Mutex, PoisonError};
+use std::task::{Context, Poll};
+use tokio::sync::{
+    mpsc::{channel, Receiver, Sender},
+    Semaphore,
+};
+use tokio_util::sync::CancellationToken;
+
+/// A task group with the following properties:
+///
+/// - No more than a certain number of tasks are ever active at once.
+///
+/// - Each task is passed a `CancellationToken` that can be used for graceful
+///   shutdown.
+///
+/// - `LimitedShutdownGroup<T>` is a `Stream` of the return values of the tasks
+///   (which must all be `T`).
+///
+/// - `shutdown()` cancels the cancellation token and prevents any further
+///   pending tasks from running.
+#[derive(Debug)]
+pub(crate) struct LimitedShutdownGroup<T> {
+    sender: Mutex<Option<Sender<T>>>,
+    receiver: Receiver<T>,
+    semaphore: Arc<Semaphore>,
+    token: CancellationToken,
+}
+
+impl<T: Send + 'static> LimitedShutdownGroup<T> {
+    pub(crate) fn new(limit: usize) -> Self {
+        let (sender, receiver) = channel(limit);
+        LimitedShutdownGroup {
+            sender: Mutex::new(Some(sender)),
+            receiver,
+            semaphore: Arc::new(Semaphore::new(limit)),
+            token: CancellationToken::new(),
+        }
+    }
+
+    pub(crate) fn spawn<F, Fut>(&self, func: F)
+    where
+        F: FnOnce(CancellationToken) -> Fut,
+        Fut: Future<Output = T> + Send + 'static,
+    {
+        let sender = {
+            let s = self.sender.lock().unwrap_or_else(PoisonError::into_inner);
+            (*s).clone()
+        };
+        if let Some(sender) = sender {
+            let future = func(self.token.clone());
+            let sem = Arc::clone(&self.semaphore);
+            tokio::spawn(async move {
+                if let Ok(_permit) = sem.acquire().await {
+                    let _ = sender.send(future.await).await;
+                }
+            });
+        }
+    }
+
+    /// Stop accepting new tasks without cancelling running ones; the stream
+    /// then ends once all previously spawned tasks have finished.  Without
+    /// this, the stream never yields `None` and callers waiting on it hang.
+    pub(crate) fn close(&self) {
+        let mut s = self.sender.lock().unwrap_or_else(PoisonError::into_inner);
+        *s = None;
+    }
+
+    pub(crate) fn shutdown(&self) {
+        self.close();
+        self.semaphore.close();
+        self.token.cancel();
+    }
+}
+
+impl<T> Stream for LimitedShutdownGroup<T> {
+    type Item = T;
+
+    /// Poll for one of the tasks in the group to complete and return its
+    /// return value.
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
+        self.receiver.poll_recv(cx)
+    }
+}
diff --git a/src/asyncutil/mod.rs b/src/asyncutil/mod.rs
new file mode 100644
index 0000000..be98698
--- /dev/null
+++ b/src/asyncutil/mod.rs
@@ -0,0 +1,2 @@
+mod lsg;
+pub(crate) use self::lsg::*;
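To make the group's semantics concrete, here is a minimal driver (hypothetical, not part of the patch; it relies on the `close()` helper above):

```rust
use futures_util::StreamExt; // for .next() on the group

async fn demo() {
    // At most 2 of the 5 tasks are ever running at once.
    let mut group = LimitedShutdownGroup::<u32>::new(2);
    for i in 0..5 {
        group.spawn(move |_token| async move {
            // A real task would check `_token.is_cancelled()` and bail early.
            i * 10
        });
    }
    group.close(); // no more spawns; the stream ends when the tasks do
    while let Some(n) = group.next().await {
        println!("task returned {n}");
    }
}
```

Calling `shutdown()` instead of `close()` would additionally cancel the token and keep not-yet-started tasks from ever acquiring the semaphore.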
diff --git a/src/inventory.rs b/src/inventory/item.rs
similarity index 92%
rename from src/inventory.rs
rename to src/inventory/item.rs
index 96013e3..1a467d6 100644
--- a/src/inventory.rs
+++ b/src/inventory/item.rs
@@ -1,3 +1,4 @@
+use crate::s3::S3Location;
 use serde::{de, Deserialize};
 use std::fmt;
 use thiserror::Error;
@@ -6,12 +7,23 @@ use time::OffsetDateTime;
 #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
 #[serde(try_from = "RawInventoryItem")]
 pub(crate) struct InventoryItem {
-    bucket: String,
-    key: String,
-    version_id: String,
-    is_latest: bool,
-    last_modified_date: OffsetDateTime,
-    details: ItemDetails,
+    pub(crate) bucket: String,
+    pub(crate) key: String,
+    pub(crate) version_id: String,
+    pub(crate) is_latest: bool,
+    pub(crate) last_modified_date: OffsetDateTime,
+    pub(crate) details: ItemDetails,
+}
+
+impl InventoryItem {
+    pub(crate) fn url(&self) -> S3Location {
+        S3Location::new(self.bucket.clone(), self.key.clone())
+            .with_version_id(self.version_id.clone())
+    }
+
+    pub(crate) fn is_deleted(&self) -> bool {
+        self.details == ItemDetails::Deleted
+    }
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
diff --git a/src/inventory/list.rs b/src/inventory/list.rs
new file mode 100644
index 0000000..77276e9
--- /dev/null
+++ b/src/inventory/list.rs
@@ -0,0 +1,41 @@
+use super::item::InventoryItem;
+use crate::s3::S3Location;
+use flate2::bufread::GzDecoder;
+use std::fs::File;
+use std::io::BufReader;
+use thiserror::Error;
+
+pub(crate) struct InventoryList {
+    url: S3Location,
+    inner: csv::DeserializeRecordsIntoIter<GzDecoder<BufReader<File>>, InventoryItem>,
+}
+
+impl InventoryList {
+    pub(crate) fn from_gzip_csv_file(url: S3Location, f: File) -> InventoryList {
+        InventoryList {
+            url,
+            inner: csv::ReaderBuilder::new()
+                .has_headers(false)
+                .from_reader(GzDecoder::new(BufReader::new(f)))
+                .into_deserialize(),
+        }
+    }
+}
+
+impl Iterator for InventoryList {
+    type Item = Result<InventoryItem, InventoryListError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        Some(self.inner.next()?.map_err(|source| InventoryListError {
+            url: self.url.clone(),
+            source,
+        }))
+    }
+}
+
+#[derive(Debug, Error)]
+#[error("failed to read entry from inventory list at {url}")]
+pub(crate) struct InventoryListError {
+    url: S3Location,
+    source: csv::Error,
+}
diff --git a/src/inventory/mod.rs b/src/inventory/mod.rs
new file mode 100644
index 0000000..e3760aa
--- /dev/null
+++ b/src/inventory/mod.rs
@@ -0,0 +1,4 @@
+mod item;
+mod list;
+pub(crate) use self::item::*;
+pub(crate) use self::list::*;
diff --git a/src/main.rs b/src/main.rs
index 17ea27e..7322989 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,13 +1,14 @@
-#![allow(dead_code)] // XXX
-#![allow(unused_imports)] // XXX
-#![allow(clippy::todo)] // XXX
+mod asyncutil;
 mod inventory;
 mod manifest;
 mod s3;
+mod syncer;
 mod timestamps;
 use crate::s3::{get_bucket_region, S3Client, S3Location};
+use crate::syncer::Syncer;
 use crate::timestamps::DateMaybeHM;
 use clap::Parser;
+use std::num::NonZeroUsize;
 use std::path::PathBuf;
 
 #[derive(Clone, Debug, Eq, Parser, PartialEq)]
@@ -16,6 +17,12 @@ struct Arguments {
     #[arg(short, long)]
     date: Option<DateMaybeHM>,
 
+    #[arg(short = 'I', long, default_value = "20")]
+    inventory_jobs: NonZeroUsize,
+
+    #[arg(short = 'O', long, default_value = "20")]
+    object_jobs: NonZeroUsize,
+
     inventory_base: S3Location,
 
     outdir: PathBuf,
@@ -27,15 +34,7 @@ async fn main() -> anyhow::Result<()> {
     let region = get_bucket_region(args.inventory_base.bucket()).await?;
     let client = S3Client::new(region, args.inventory_base).await?;
     let manifest = client.get_manifest_for_date(args.date).await?;
-    for fspec in &manifest.files {
-        // TODO: Add to pool of concurrent download tasks?
-        client.download_inventory_csv(fspec).await?;
-        // TODO: For each entry in CSV:
-        //  - Download object (in a task pool)
-        //  - Manage object metadata and old versions
-        //  - Handle concurrent downloads of the same key
-        todo!()
-    }
-    // TODO: Handle error recovery and Ctrl-C
+    let syncer = Syncer::new(client, args.outdir, args.inventory_jobs, args.object_jobs);
+    syncer.run(manifest).await?;
     Ok(())
 }
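Putting the inventory types to work: a hypothetical within-crate helper (not in the patch) that consumes one downloaded list, matching the `download_inventory_csv` signature updated below:

```rust
// Hypothetical sketch: print the latest, still-present objects from a
// single inventory list file.
async fn dump_list(client: &S3Client, fspec: FileSpec) -> anyhow::Result<()> {
    let list = client.download_inventory_csv(fspec).await?;
    for item in list {
        let item = item?; // an InventoryListError names the list's S3 URL
        if item.is_latest && !item.is_deleted() {
            println!("{}", item.url()); // s3://bucket/key?versionId=...
        }
    }
    Ok(())
}
```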
diff --git a/src/s3/location.rs b/src/s3/location.rs
index b45d5da..9ce1332 100644
--- a/src/s3/location.rs
+++ b/src/s3/location.rs
@@ -6,11 +6,16 @@ use thiserror::Error;
 pub(crate) struct S3Location {
     bucket: String,
     key: String,
+    version_id: Option<String>,
 }
 
 impl S3Location {
     pub(crate) fn new(bucket: String, key: String) -> S3Location {
-        S3Location { bucket, key }
+        S3Location {
+            bucket,
+            key,
+            version_id: None,
+        }
     }
 
     pub(crate) fn bucket(&self) -> &str {
@@ -21,8 +26,13 @@
         &self.key
     }
 
+    pub(crate) fn version_id(&self) -> Option<&str> {
+        self.version_id.as_deref()
+    }
+
     pub(crate) fn join(&self, suffix: &str) -> S3Location {
         let mut joined = self.clone();
+        joined.version_id = None;
         if !joined.key.ends_with('/') {
             joined.key.push('/');
         }
@@ -34,14 +44,26 @@
         S3Location {
             bucket: self.bucket.clone(),
             key: key.into(),
+            version_id: None,
+        }
+    }
+
+    pub(crate) fn with_version_id<S: Into<String>>(&self, version_id: S) -> S3Location {
+        S3Location {
+            bucket: self.bucket.clone(),
+            key: self.key.clone(),
+            version_id: Some(version_id.into()),
         }
     }
 }
 
 impl fmt::Display for S3Location {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        // TODO: Should the key be percent-encoded?
-        write!(f, "s3://{}/{}", self.bucket, self.key)
+        write!(f, "s3://{}/{}", self.bucket, self.key)?;
+        if let Some(ref v) = self.version_id {
+            write!(f, "?versionId={v}")?;
+        }
+        Ok(())
     }
 }
 
@@ -63,10 +85,10 @@
         if bucket.is_empty() || !bucket.chars().all(is_bucket_char) {
             return Err(S3LocationError::BadBucket);
         }
-        // TODO: Does the key need to be percent-decoded?
         Ok(S3Location {
             bucket: bucket.to_owned(),
             key: key.to_owned(),
+            version_id: None,
         })
     }
 }
diff --git a/src/s3/mod.rs b/src/s3/mod.rs
index e2fd337..bb00f0c 100644
--- a/src/s3/mod.rs
+++ b/src/s3/mod.rs
@@ -2,6 +2,7 @@ mod location;
 mod streams;
 pub(crate) use self::location::S3Location;
 use self::streams::{ListManifestDates, ListObjectsError};
+use crate::inventory::InventoryList;
 use crate::manifest::CsvManifest;
 use crate::manifest::FileSpec;
 use crate::timestamps::{Date, DateHM, DateMaybeHM};
@@ -11,7 +12,6 @@ use aws_sdk_s3::{
     Client,
 };
 use aws_smithy_runtime_api::client::{orchestrator::HttpResponse, result::SdkError};
-use flate2::bufread::GzDecoder;
 use futures_util::TryStreamExt;
 use md5::{Digest, Md5};
 use std::fs::File;
@@ -19,8 +19,6 @@ use std::io::{BufReader, BufWriter, Seek, Write};
 use std::path::{Path, PathBuf};
 use thiserror::Error;
 
-type CsvReader = csv::Reader<GzDecoder<BufReader<File>>>;
-
 #[derive(Debug)]
 pub(crate) struct S3Client {
     inner: Client,
@@ -106,16 +104,14 @@
     }
 
     async fn get_object(&self, url: &S3Location) -> Result<GetObjectOutput, GetError> {
-        self.inner
-            .get_object()
-            .bucket(url.bucket())
-            .key(url.key())
-            .send()
-            .await
-            .map_err(|source| GetError {
-                url: url.to_owned(),
-                source,
-            })
+        let mut op = self.inner.get_object().bucket(url.bucket()).key(url.key());
+        if let Some(v) = url.version_id() {
+            op = op.version_id(v);
+        }
+        op.send().await.map_err(|source| GetError {
+            url: url.to_owned(),
+            source,
+        })
     }
 
     pub(crate) async fn get_manifest(&self, when: DateHM) -> Result<CsvManifest, GetManifestError> {
@@ -161,8 +157,8 @@
 
     pub(crate) async fn download_inventory_csv(
         &self,
-        fspec: &FileSpec,
-    ) -> Result<CsvReader, CsvDownloadError> {
+        fspec: FileSpec,
+    ) -> Result<InventoryList, CsvDownloadError> {
         let fname = fspec
             .key
             .rsplit_once('/')
@@ -172,13 +168,10 @@
             self.make_dl_tempfile(&PathBuf::from(format!("data/{fname}.csv.gz")), &url)?;
         self.download_object(&url, Some(&fspec.md5_checksum), &outfile)
             .await?;
-        // TODO: Verify file size?
-        Ok(csv::ReaderBuilder::new()
-            .has_headers(false)
-            .from_reader(GzDecoder::new(BufReader::new(outfile))))
+        Ok(InventoryList::from_gzip_csv_file(url, outfile))
    }
 
-    async fn download_object(
+    pub(crate) async fn download_object(
         &self,
         url: &S3Location,
         // `md5_digest` must be a 32-character lowercase hexadecimal string
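The new `version_id` plumbing is easy to pin down with a test; a sketch (hypothetical, not in the patch):

```rust
#[test]
fn version_id_display() {
    let loc = S3Location::new("my-bucket".into(), "path/to/file".into())
        .with_version_id("abc123");
    assert_eq!(loc.version_id(), Some("abc123"));
    assert_eq!(loc.to_string(), "s3://my-bucket/path/to/file?versionId=abc123");
    // join() deliberately clears the version ID: it belongs to the exact
    // key it was attached to, not to children of that key.
    assert_eq!(loc.join("sub").to_string(), "s3://my-bucket/path/to/file/sub");
}
```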
diff --git a/src/syncer.rs b/src/syncer.rs
new file mode 100644
index 0000000..21e768a
--- /dev/null
+++ b/src/syncer.rs
@@ -0,0 +1,198 @@
+use crate::asyncutil::LimitedShutdownGroup;
+use crate::inventory::InventoryItem;
+use crate::manifest::CsvManifest;
+use crate::s3::S3Client;
+use anyhow::Context;
+use futures_util::StreamExt;
+use std::fmt;
+use std::fs::File;
+use std::num::NonZeroUsize;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+use tokio::sync::mpsc::channel;
+use tokio_util::sync::CancellationToken;
+
+#[derive(Debug)]
+pub(crate) struct Syncer {
+    client: Arc<S3Client>,
+    outdir: PathBuf,
+    inventory_jobs: NonZeroUsize,
+    object_jobs: NonZeroUsize,
+}
+
+impl Syncer {
+    pub(crate) fn new(
+        client: S3Client,
+        outdir: PathBuf,
+        inventory_jobs: NonZeroUsize,
+        object_jobs: NonZeroUsize,
+    ) -> Arc<Syncer> {
+        Arc::new(Syncer {
+            client: Arc::new(client),
+            outdir,
+            inventory_jobs,
+            object_jobs,
+        })
+    }
+
+    pub(crate) async fn run(self: &Arc<Self>, manifest: CsvManifest) -> Result<(), MultiError> {
+        let mut inventory_dl_pool = LimitedShutdownGroup::new(self.inventory_jobs.get());
+        let mut object_dl_pool = LimitedShutdownGroup::new(self.object_jobs.get());
+        let (obj_sender, mut obj_receiver) = channel(self.inventory_jobs.get());
+
+        for fspec in manifest.files {
+            let clnt = self.client.clone();
+            let sender = obj_sender.clone();
+            inventory_dl_pool.spawn(move |_| async move {
+                let itemlist = clnt.download_inventory_csv(fspec).await?;
+                for item in itemlist {
+                    let _ = sender.send(item?).await;
+                }
+                Ok(())
+            });
+        }
+        // Everything has been spawned into the inventory pool, so close it
+        // (letting its stream end once the tasks finish) and drop our copy of
+        // `obj_sender` (so that `obj_receiver.recv()` eventually yields None).
+        inventory_dl_pool.close();
+        drop(obj_sender);
+
+        let mut errors = Vec::new();
+        let mut inventory_pool_closed = false;
+        let mut object_pool_closed = false;
+        let mut all_objects_txed = false;
+        loop {
+            tokio::select! {
+                r = inventory_dl_pool.next(), if !inventory_pool_closed => {
+                    match r {
+                        Some(Ok(())) => (),
+                        Some(Err(e)) => {
+                            // On the first error, tell both pools to shut down.
+                            if errors.is_empty() {
+                                inventory_dl_pool.shutdown();
+                                object_dl_pool.shutdown();
+                            }
+                            errors.push(e);
+                        }
+                        None => inventory_pool_closed = true,
+                    }
+                }
+                r = object_dl_pool.next(), if !object_pool_closed => {
+                    match r {
+                        Some(Ok(())) => (),
+                        Some(Err(e)) => {
+                            if errors.is_empty() {
+                                inventory_dl_pool.shutdown();
+                                object_dl_pool.shutdown();
+                            }
+                            errors.push(e);
+                        }
+                        None => object_pool_closed = true,
+                    }
+                }
+                r = obj_receiver.recv(), if !all_objects_txed => {
+                    if let Some(item) = r {
+                        let this = self.clone();
+                        object_dl_pool
+                            .spawn(move |token| async move { this.process_item(item, token).await });
+                    } else {
+                        all_objects_txed = true;
+                        object_dl_pool.close();
+                    }
+                }
+                else => break,
+            }
+        }
+
+        if !errors.is_empty() {
+            Err(MultiError(errors))
+        } else {
+            Ok(())
+        }
+    }
+
+    async fn process_item(
+        self: &Arc<Self>,
+        item: InventoryItem,
+        token: CancellationToken,
+    ) -> anyhow::Result<()> {
+        if token.is_cancelled() {
+            return Ok(());
+        }
+        if item.is_deleted() || !item.is_latest {
+            // TODO
+            return Ok(());
+        }
+        let url = item.url();
+        let outpath = self.outdir.join(&item.key);
+        if let Some(p) = outpath.parent() {
+            fs_err::create_dir_all(p)?;
+        }
+        // TODO: Download to a temp file and then move
+        let outfile = File::create(&outpath)
+            .with_context(|| format!("failed to open output file {}", outpath.display()))?;
+        match token
+            .run_until_cancelled(self.client.download_object(
+                &url,
+                item.details.md5_digest(),
+                &outfile,
+            ))
+            .await
+        {
+            Some(Ok(())) => Ok(()),
+            Some(Err(e)) => {
+                // TODO: Warn on failure?
+                let _ = self.cleanup_download_path(&outpath);
+                Err(e.into())
+            }
+            None => {
+                self.cleanup_download_path(&outpath)?;
+                Ok(())
+            }
+        }
+        // TODO: Manage object metadata and old versions
+        // TODO: Handle concurrent downloads of the same key
+    }
+
+    fn cleanup_download_path(&self, dlfile: &Path) -> std::io::Result<()> {
+        fs_err::remove_file(dlfile)?;
+        // Walk up from the deleted file, pruning now-empty directories, and
+        // stop at the output root.
+        let mut p = dlfile.parent();
+        while let Some(pp) = p {
+            if pp == self.outdir {
+                break;
+            }
+            if is_empty_dir(pp)? {
+                fs_err::remove_dir(pp)?;
+            }
+            p = pp.parent();
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct MultiError(Vec<anyhow::Error>);
+
+impl fmt::Display for MultiError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let mut first = true;
+        for e in &self.0 {
+            if !std::mem::replace(&mut first, false) {
+                writeln!(f, "\n---")?;
+            }
+            write!(f, "{e:?}")?;
+        }
+        Ok(())
+    }
+}
+
+impl std::error::Error for MultiError {}
+
+fn is_empty_dir(p: &Path) -> std::io::Result<bool> {
+    let mut iter = fs_err::read_dir(p)?;
+    match iter.next() {
+        None => Ok(true),
+        Some(Ok(_)) => Ok(false),
+        Some(Err(e)) => Err(e),
+    }
+}
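The heart of `run()` is the `select!` loop: each branch is disabled by its `, if` precondition once its source is exhausted, and the `else` arm fires only when every branch is disabled. A self-contained sketch of that drain pattern (hypothetical, not from the patch):

```rust
use futures_util::{stream, StreamExt};

#[tokio::main]
async fn main() {
    let mut tasks = stream::iter(1..=3);
    let (tx, mut rx) = tokio::sync::mpsc::channel::<u32>(4);
    tokio::spawn(async move {
        for n in [10, 20] {
            let _ = tx.send(n).await;
        }
        // `tx` is dropped here, so `rx.recv()` will eventually yield None.
    });
    let (mut stream_done, mut chan_done) = (false, false);
    loop {
        tokio::select! {
            r = tasks.next(), if !stream_done => match r {
                Some(n) => println!("stream: {n}"),
                None => stream_done = true, // disable this branch
            },
            r = rx.recv(), if !chan_done => match r {
                Some(n) => println!("channel: {n}"),
                None => chan_done = true, // disable this branch
            },
            else => break, // every branch disabled: all sources drained
        }
    }
}
```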
diff --git a/src/timestamps/util.rs b/src/timestamps/util.rs
index 6d1423a..5e475a7 100644
--- a/src/timestamps/util.rs
+++ b/src/timestamps/util.rs
@@ -1,6 +1,3 @@
-use std::num::ParseIntError;
-use thiserror::Error;
-
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub(super) struct Scanner<'a, E> {
     s: &'a str,
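One closing note on `syncer.rs` above: `MultiError` is only constructed with a non-empty error list, and its `Display` joins the errors' `Debug` renderings with `---` separator lines. A within-crate sketch (hypothetical, and assuming backtrace capture is off, since `anyhow`'s `Debug` output appends a backtrace when one is captured):

```rust
#[test]
fn multi_error_display() {
    let errs = MultiError(vec![
        anyhow::anyhow!("inventory download failed"),
        anyhow::anyhow!("object download failed"),
    ]);
    assert_eq!(
        errs.to_string(),
        "inventory download failed\n---\nobject download failed"
    );
}
```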