Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement concurrent downloads #23

Merged
merged 10 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ aws-smithy-runtime-api = "1.7.3"
clap = { version = "4.5.21", default-features = false, features = ["derive", "error-context", "help", "std", "suggestions", "usage", "wrap_help"] }
csv = "1.3.1"
flate2 = "1.0.35"
fs-err = "3.0.0"
futures-util = "0.3.31"
hex = "0.4.3"
md-5 = "0.10.6"
Expand All @@ -34,6 +35,7 @@ tempfile = "3.14.0"
thiserror = "2.0.3"
time = { version = "0.3.36", features = ["macros", "parsing", "serde"] }
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tokio-util = { version = "0.7.12", features = ["rt"] }

[dev-dependencies]
rstest = { version = "0.23.0", default-features = false }
Expand Down
81 changes: 81 additions & 0 deletions src/asyncutil/lsg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use futures_util::Stream;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll};
use tokio::sync::{
mpsc::{channel, Receiver, Sender},
Semaphore,
};
use tokio_util::sync::CancellationToken;

/// A task group with the following properties:
///
/// - No more than a certain number of tasks are ever active at once.
///
/// - Each task is passed a `CancellationToken` that can be used for graceful
/// shutdown.
///
/// - `LimitedShutdownGroup<T>` is a `Stream` of the return values of the tasks
/// (which must all be `T`).
///
/// - `shutdown()` cancels the cancellation token and prevents any further
/// pending tasks from running.
#[derive(Debug)]
pub(crate) struct LimitedShutdownGroup<T> {
sender: Mutex<Option<Sender<T>>>,
receiver: Receiver<T>,
semaphore: Arc<Semaphore>,
token: CancellationToken,
}

impl<T: Send + 'static> LimitedShutdownGroup<T> {
pub(crate) fn new(limit: usize) -> Self {
let (sender, receiver) = channel(limit);
LimitedShutdownGroup {
sender: Mutex::new(Some(sender)),
receiver,
semaphore: Arc::new(Semaphore::new(limit)),
token: CancellationToken::new(),
}
}

Check warning on line 41 in src/asyncutil/lsg.rs

View check run for this annotation

Codecov / codecov/patch

src/asyncutil/lsg.rs#L33-L41

Added lines #L33 - L41 were not covered by tests

pub(crate) fn spawn<F, Fut>(&self, func: F)
where
F: FnOnce(CancellationToken) -> Fut,
Fut: Future<Output = T> + Send + 'static,
{
let sender = {
let s = self.sender.lock().unwrap_or_else(PoisonError::into_inner);
(*s).clone()

Check warning on line 50 in src/asyncutil/lsg.rs

View check run for this annotation

Codecov / codecov/patch

src/asyncutil/lsg.rs#L43-L50

Added lines #L43 - L50 were not covered by tests
};
if let Some(sender) = sender {
let future = func(self.token.clone());
let sem = Arc::clone(&self.semaphore);
tokio::spawn(async move {
if let Ok(_permit) = sem.acquire().await {
let _ = sender.send(future.await).await;
}
});
}
}

Check warning on line 61 in src/asyncutil/lsg.rs

View check run for this annotation

Codecov / codecov/patch

src/asyncutil/lsg.rs#L52-L61

Added lines #L52 - L61 were not covered by tests

pub(crate) fn shutdown(&self) {
{
let mut s = self.sender.lock().unwrap_or_else(PoisonError::into_inner);
*s = None;
}
self.semaphore.close();
self.token.cancel();
}

Check warning on line 70 in src/asyncutil/lsg.rs

View check run for this annotation

Codecov / codecov/patch

src/asyncutil/lsg.rs#L63-L70

Added lines #L63 - L70 were not covered by tests
}

impl<T: Send + 'static> Stream for LimitedShutdownGroup<T> {
type Item = T;

/// Poll for one of the tasks in the group to complete and return its
/// return value.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.receiver.poll_recv(cx)
}

Check warning on line 80 in src/asyncutil/lsg.rs

View check run for this annotation

Codecov / codecov/patch

src/asyncutil/lsg.rs#L78-L80

Added lines #L78 - L80 were not covered by tests
}
2 changes: 2 additions & 0 deletions src/asyncutil/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod lsg;
pub(crate) use self::lsg::*;
24 changes: 18 additions & 6 deletions src/inventory.rs → src/inventory/item.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::s3::S3Location;
use serde::{de, Deserialize};
use std::fmt;
use thiserror::Error;
Expand All @@ -6,12 +7,23 @@
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(try_from = "RawInventoryItem")]
pub(crate) struct InventoryItem {
bucket: String,
key: String,
version_id: String,
is_latest: bool,
last_modified_date: OffsetDateTime,
details: ItemDetails,
pub(crate) bucket: String,
pub(crate) key: String,
pub(crate) version_id: String,
pub(crate) is_latest: bool,
pub(crate) last_modified_date: OffsetDateTime,
pub(crate) details: ItemDetails,
}

impl InventoryItem {
pub(crate) fn url(&self) -> S3Location {
S3Location::new(self.bucket.clone(), self.key.clone())
.with_version_id(self.version_id.clone())
}

Check warning on line 22 in src/inventory/item.rs

View check run for this annotation

Codecov / codecov/patch

src/inventory/item.rs#L19-L22

Added lines #L19 - L22 were not covered by tests

pub(crate) fn is_deleted(&self) -> bool {
self.details == ItemDetails::Deleted
}

Check warning on line 26 in src/inventory/item.rs

View check run for this annotation

Codecov / codecov/patch

src/inventory/item.rs#L24-L26

Added lines #L24 - L26 were not covered by tests
}

#[derive(Clone, Debug, Eq, PartialEq)]
Expand Down
41 changes: 41 additions & 0 deletions src/inventory/list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use super::item::InventoryItem;
use crate::s3::S3Location;
use flate2::bufread::GzDecoder;
use std::fs::File;
use std::io::BufReader;
use thiserror::Error;

pub(crate) struct InventoryList {
url: S3Location,
inner: csv::DeserializeRecordsIntoIter<GzDecoder<BufReader<File>>, InventoryItem>,
}

impl InventoryList {
pub(crate) fn from_gzip_csv_file(url: S3Location, f: File) -> InventoryList {
InventoryList {
url,
inner: csv::ReaderBuilder::new()
.has_headers(false)
.from_reader(GzDecoder::new(BufReader::new(f)))
.into_deserialize(),
}
}

Check warning on line 22 in src/inventory/list.rs

View check run for this annotation

Codecov / codecov/patch

src/inventory/list.rs#L14-L22

Added lines #L14 - L22 were not covered by tests
}

impl Iterator for InventoryList {
type Item = Result<InventoryItem, InventoryListError>;

fn next(&mut self) -> Option<Self::Item> {
Some(self.inner.next()?.map_err(|source| InventoryListError {
url: self.url.clone(),
source,
}))
}

Check warning on line 33 in src/inventory/list.rs

View check run for this annotation

Codecov / codecov/patch

src/inventory/list.rs#L28-L33

Added lines #L28 - L33 were not covered by tests
}

#[derive(Debug, Error)]
#[error("failed to read entry from inventory list at {url}")]
pub(crate) struct InventoryListError {
url: S3Location,
source: csv::Error,
}
4 changes: 4 additions & 0 deletions src/inventory/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod item;
mod list;
pub(crate) use self::item::*;
pub(crate) use self::list::*;
25 changes: 12 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#![allow(dead_code)] // XXX
#![allow(unused_imports)] // XXX
#![allow(clippy::todo)] // XXX
mod asyncutil;
mod inventory;
mod manifest;
mod s3;
mod syncer;
mod timestamps;
use crate::s3::{get_bucket_region, S3Client, S3Location};
use crate::syncer::Syncer;
use crate::timestamps::DateMaybeHM;
use clap::Parser;
use std::num::NonZeroUsize;
use std::path::PathBuf;

#[derive(Clone, Debug, Eq, Parser, PartialEq)]
Expand All @@ -16,6 +17,12 @@
#[arg(short, long)]
date: Option<DateMaybeHM>,

#[arg(short = 'I', long, default_value = "20")]
inventory_jobs: NonZeroUsize,

Check warning on line 21 in src/main.rs

View check run for this annotation

Codecov / codecov/patch

src/main.rs#L21

Added line #L21 was not covered by tests

#[arg(short = 'O', long, default_value = "20")]
object_jobs: NonZeroUsize,

Check warning on line 24 in src/main.rs

View check run for this annotation

Codecov / codecov/patch

src/main.rs#L24

Added line #L24 was not covered by tests

inventory_base: S3Location,

outdir: PathBuf,
Expand All @@ -27,15 +34,7 @@
let region = get_bucket_region(args.inventory_base.bucket()).await?;
let client = S3Client::new(region, args.inventory_base).await?;
let manifest = client.get_manifest_for_date(args.date).await?;
for fspec in &manifest.files {
// TODO: Add to pool of concurrent download tasks?
client.download_inventory_csv(fspec).await?;
// TODO: For each entry in CSV:
// - Download object (in a task pool)
// - Manage object metadata and old versions
// - Handle concurrent downloads of the same key
todo!()
}
// TODO: Handle error recovery and Ctrl-C
let syncer = Syncer::new(client, args.outdir, args.inventory_jobs, args.object_jobs);
syncer.run(manifest).await?;

Check warning on line 38 in src/main.rs

View check run for this annotation

Codecov / codecov/patch

src/main.rs#L37-L38

Added lines #L37 - L38 were not covered by tests
Ok(())
}
30 changes: 26 additions & 4 deletions src/s3/location.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@
pub(crate) struct S3Location {
bucket: String,
key: String,
version_id: Option<String>,
}

impl S3Location {
pub(crate) fn new(bucket: String, key: String) -> S3Location {
S3Location { bucket, key }
S3Location {
bucket,
key,
version_id: None,
}

Check warning on line 18 in src/s3/location.rs

View check run for this annotation

Codecov / codecov/patch

src/s3/location.rs#L14-L18

Added lines #L14 - L18 were not covered by tests
}

pub(crate) fn bucket(&self) -> &str {
Expand All @@ -21,8 +26,13 @@
&self.key
}

pub(crate) fn version_id(&self) -> Option<&str> {
self.version_id.as_deref()
}

Check warning on line 31 in src/s3/location.rs

View check run for this annotation

Codecov / codecov/patch

src/s3/location.rs#L29-L31

Added lines #L29 - L31 were not covered by tests

pub(crate) fn join(&self, suffix: &str) -> S3Location {
let mut joined = self.clone();
joined.version_id = None;

Check warning on line 35 in src/s3/location.rs

View check run for this annotation

Codecov / codecov/patch

src/s3/location.rs#L35

Added line #L35 was not covered by tests
if !joined.key.ends_with('/') {
joined.key.push('/');
}
Expand All @@ -34,14 +44,26 @@
S3Location {
bucket: self.bucket.clone(),
key: key.into(),
version_id: None,
}
}

Check warning on line 49 in src/s3/location.rs

View check run for this annotation

Codecov / codecov/patch

src/s3/location.rs#L47-L49

Added lines #L47 - L49 were not covered by tests

pub(crate) fn with_version_id<S: Into<String>>(&self, version_id: S) -> S3Location {
S3Location {
bucket: self.bucket.clone(),
key: self.key.clone(),
version_id: Some(version_id.into()),

Check warning on line 55 in src/s3/location.rs

View check run for this annotation

Codecov / codecov/patch

src/s3/location.rs#L51-L55

Added lines #L51 - L55 were not covered by tests
}
}
}

impl fmt::Display for S3Location {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// TODO: Should the key be percent-encoded?
write!(f, "s3://{}/{}", self.bucket, self.key)
write!(f, "s3://{}/{}", self.bucket, self.key)?;
if let Some(ref v) = self.version_id {
write!(f, "?versionId={v}")?;

Check warning on line 64 in src/s3/location.rs

View check run for this annotation

Codecov / codecov/patch

src/s3/location.rs#L64

Added line #L64 was not covered by tests
}
Ok(())
}
}

Expand All @@ -63,10 +85,10 @@
if bucket.is_empty() || !bucket.chars().all(is_bucket_char) {
return Err(S3LocationError::BadBucket);
}
// TODO: Does the key need to be percent-decoded?
Ok(S3Location {
bucket: bucket.to_owned(),
key: key.to_owned(),
version_id: None,
})
}
}
Expand Down
Loading