-
-
Notifications
You must be signed in to change notification settings - Fork 1
Draft: Download manager #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
dbbe919
748106a
c5479b8
9c29890
ad41c23
447be8e
395ced4
1bf7bfd
9adb239
b9911b8
9f25867
c313200
926d2fc
3fe1e92
49d4f60
7973137
31f5218
d96688d
3a14a0d
5566d67
2aac716
9251524
0670468
ce05710
633b752
09b47b6
4803c51
9e9b298
987beaa
c550480
8354cc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
use super::Status; | ||
use crate::Error; | ||
use tokio::{ | ||
fs::File, | ||
sync::{oneshot, watch}, | ||
}; | ||
use tokio_util::sync::CancellationToken; | ||
|
||
#[derive(Debug)] | ||
pub struct DownloadHandle { | ||
result: oneshot::Receiver<Result<File, Error>>, | ||
status: watch::Receiver<Status>, | ||
cancel: CancellationToken, | ||
} | ||
|
||
impl DownloadHandle { | ||
pub fn new( | ||
result: oneshot::Receiver<Result<File, Error>>, | ||
status: watch::Receiver<Status>, | ||
cancel: CancellationToken, | ||
) -> Self { | ||
Self { | ||
result, | ||
status, | ||
cancel, | ||
} | ||
} | ||
} | ||
|
||
impl std::future::Future for DownloadHandle { | ||
type Output = Result<tokio::fs::File, Error>; | ||
|
||
fn poll( | ||
mut self: std::pin::Pin<&mut Self>, | ||
cx: &mut std::task::Context<'_>, | ||
) -> std::task::Poll<Self::Output> { | ||
use std::pin::Pin; | ||
use std::task::Poll; | ||
match Pin::new(&mut self.result).poll(cx) { | ||
Poll::Ready(Ok(result)) => Poll::Ready(result), | ||
Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Oneshot(e))), | ||
Poll::Pending => Poll::Pending, | ||
} | ||
} | ||
} | ||
|
||
impl DownloadHandle { | ||
pub fn status(&self) -> Status { | ||
*self.status.borrow() | ||
} | ||
|
||
pub async fn wait_for_status_update(&mut self) -> Result<(), watch::error::RecvError> { | ||
self.status.changed().await | ||
} | ||
|
||
pub fn cancel(&self) { | ||
self.cancel.cancel(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
use reqwest::{Client, Url}; | ||
use std::{path::PathBuf, sync::Arc}; | ||
use tokio::sync::{mpsc, oneshot, watch, Semaphore}; | ||
use tokio_util::sync::CancellationToken; | ||
|
||
use super::{download_thread, DownloadHandle, DownloadRequest, Status}; | ||
|
||
const QUEUE_SIZE: usize = 100; | ||
|
||
#[derive(Debug)] | ||
pub struct DownloadManager { | ||
queue: mpsc::Sender<DownloadRequest>, | ||
semaphore: Arc<Semaphore>, | ||
cancel: CancellationToken, | ||
} | ||
|
||
impl Drop for DownloadManager { | ||
fn drop(&mut self) { | ||
// Need to manually close the semaphore to make sure dispatcher_thread stops waiting for permits | ||
self.semaphore.close(); | ||
} | ||
} | ||
|
||
impl DownloadManager { | ||
pub fn new(limit: usize) -> Self { | ||
let (tx, rx) = mpsc::channel(QUEUE_SIZE); | ||
let client = Client::new(); | ||
let semaphore = Arc::new(Semaphore::new(limit)); | ||
let manager = Self { | ||
queue: tx, | ||
semaphore: semaphore.clone(), | ||
cancel: CancellationToken::new(), | ||
}; | ||
// Spawn the dispatcher thread to handle download requests | ||
tokio::spawn(async move { dispatcher_thread(client, rx, semaphore).await }); | ||
manager | ||
} | ||
|
||
pub fn set_max_parallel_downloads(&self, limit: usize) { | ||
let current = self.semaphore.available_permits(); | ||
if limit > current { | ||
self.semaphore.add_permits(limit - current); | ||
} else if limit < current { | ||
let to_remove = current - limit; | ||
for _ in 0..to_remove { | ||
let _ = self.semaphore.try_acquire(); | ||
} | ||
} | ||
} | ||
|
||
pub fn add_request(&self, url: Url, destination: PathBuf) -> DownloadHandle { | ||
let (result_tx, result_rx) = oneshot::channel(); | ||
let (status_tx, status_rx) = watch::channel(Status::Queued); | ||
let cancel = self.cancel.child_token(); | ||
|
||
let req = DownloadRequest { | ||
url, | ||
destination, | ||
result: result_tx, | ||
status: status_tx, | ||
cancel: cancel.clone(), | ||
}; | ||
|
||
let _ = self.queue.try_send(req); | ||
|
||
DownloadHandle::new(result_rx, status_rx, cancel) | ||
} | ||
|
||
pub fn cancel_all(&self) { | ||
self.cancel.cancel(); | ||
} | ||
} | ||
|
||
async fn dispatcher_thread( | ||
client: Client, | ||
mut rx: mpsc::Receiver<DownloadRequest>, | ||
sem: Arc<Semaphore>, | ||
) { | ||
while let Some(request) = rx.recv().await { | ||
let permit = match sem.clone().acquire_owned().await { | ||
Ok(permit) => permit, | ||
Err(_) => break, | ||
}; | ||
let client = client.clone(); | ||
tokio::spawn(async move { | ||
// Move the permit into the worker thread so it's automatically released when the thread finishes | ||
let _permit = permit; | ||
download_thread(client.clone(), request).await; | ||
}); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
mod handle; | ||
mod manager; | ||
mod types; | ||
mod worker; | ||
|
||
pub use handle::DownloadHandle; | ||
pub use manager::*; | ||
pub use types::*; | ||
pub(self) use worker::download_thread; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
use crate::Error; | ||
use reqwest::Url; | ||
use std::path::{Path, PathBuf}; | ||
use tokio::{ | ||
fs::File, | ||
sync::{oneshot, watch}, | ||
}; | ||
use tokio_util::sync::CancellationToken; | ||
|
||
#[derive(Debug)] | ||
pub(crate) struct DownloadRequest { | ||
pub url: Url, | ||
pub destination: PathBuf, | ||
pub result: oneshot::Sender<Result<File, Error>>, | ||
pub status: watch::Sender<Status>, | ||
pub cancel: CancellationToken, | ||
} | ||
|
||
/// Snapshot of how far a download has progressed.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct DownloadProgress {
    /// Number of bytes written so far.
    pub bytes_downloaded: u64,
    /// Total size of the download, or `None` when it is not known.
    pub total_bytes: Option<u64>,
}
|
||
#[derive(Debug, Copy, Clone, PartialEq, Eq)] | ||
pub enum Status { | ||
Queued, | ||
InProgress(DownloadProgress), | ||
Retrying, | ||
Completed, | ||
Failed, | ||
Cancelled, | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
use super::{DownloadProgress, DownloadRequest}; | ||
use crate::{downloader::Status, error::DownloadError, Error}; | ||
use reqwest::Client; | ||
use std::time::Duration; | ||
use tokio::{fs::File, io::AsyncWriteExt}; | ||
|
||
const MAX_RETRIES: usize = 3; | ||
|
||
pub(super) async fn download_thread(client: Client, mut req: DownloadRequest) { | ||
fn should_retry(e: &Error) -> bool { | ||
match e { | ||
Error::Reqwest(network_err) => { | ||
network_err.is_timeout() | ||
|| network_err.is_connect() | ||
|| network_err.is_request() | ||
|| network_err | ||
.status() | ||
.map(|status_code| status_code.is_server_error()) | ||
.unwrap_or(true) | ||
} | ||
Error::Download(DownloadError::Cancelled) | Error::Io(_) => false, | ||
_ => false, | ||
} | ||
} | ||
|
||
let mut last_error = None; | ||
for attempt in 0..=(MAX_RETRIES + 1) { | ||
if attempt > MAX_RETRIES { | ||
req.status.send(Status::Failed).ok(); | ||
req.result | ||
.send(Err(Error::Download(DownloadError::RetriesExhausted { | ||
last_error_msg: last_error | ||
.as_ref() | ||
.map(ToString::to_string) | ||
.unwrap_or_else(|| "Unknown Error".to_string()), | ||
}))) | ||
.ok(); | ||
return; | ||
} | ||
|
||
if attempt > 0 { | ||
req.status.send(Status::Retrying).ok(); | ||
// Basic exponential backoff | ||
let delay_ms = 1000 * 2u64.pow(attempt as u32 - 1); | ||
let delay = Duration::from_millis(delay_ms); | ||
|
||
tokio::select! { | ||
_ = tokio::time::sleep(delay) => {}, | ||
_ = req.cancel.cancelled() => { | ||
req.status.send(Status::Failed).ok(); | ||
req.result.send(Err(Error::Download(DownloadError::Cancelled))).ok(); | ||
return; | ||
} | ||
} | ||
} | ||
|
||
match download(client.clone(), &mut req).await { | ||
Ok(file) => { | ||
req.status.send(Status::Completed).ok(); | ||
req.result.send(Ok(file)).ok(); | ||
return; | ||
} | ||
Err(e) => { | ||
if should_retry(&e) { | ||
last_error = Some(e); | ||
continue; | ||
} | ||
|
||
let status = if matches!(e, Error::Download(DownloadError::Cancelled)) { | ||
Status::Cancelled | ||
} else { | ||
Status::Failed | ||
}; | ||
req.status.send(status).ok(); | ||
req.result.send(Err(e)).ok(); | ||
return; | ||
} | ||
} | ||
} | ||
} | ||
|
||
async fn download(client: Client, req: &mut DownloadRequest) -> Result<File, Error> { | ||
let update_progress = |bytes_downloaded: u64, total_bytes: Option<u64>| { | ||
req.status | ||
.send(Status::InProgress(DownloadProgress { | ||
bytes_downloaded, | ||
total_bytes, | ||
})) | ||
.ok(); | ||
}; | ||
|
||
let mut response = client | ||
.get(req.url.as_ref()) | ||
.send() | ||
.await? | ||
.error_for_status()?; | ||
let total_bytes = response.content_length(); | ||
let mut bytes_downloaded = 0u64; | ||
|
||
// Create the destination directory if it doesn't exist | ||
if let Some(parent) = req.destination.parent() { | ||
tokio::fs::create_dir_all(parent).await?; | ||
} | ||
let mut file = File::create(&req.destination).await?; | ||
|
||
update_progress(bytes_downloaded, total_bytes); | ||
loop { | ||
tokio::select! { | ||
_ = req.cancel.cancelled() => { | ||
drop(file); // Manually drop the file handle to ensure that deletion doesn't fail | ||
tokio::fs::remove_file(&req.destination).await?; | ||
return Err(Error::Download(DownloadError::Cancelled)); | ||
} | ||
chunk = response.chunk() => { | ||
match chunk { | ||
Ok(Some(chunk)) => { | ||
file.write_all(&chunk).await?; | ||
bytes_downloaded += chunk.len() as u64; | ||
update_progress(bytes_downloaded, total_bytes); | ||
} | ||
Ok(None) => break, | ||
Err(e) => { | ||
drop(file); // Manually drop the file handle to ensure that deletion doesn't fail | ||
tokio::fs::remove_file(&req.destination).await?; | ||
return Err(Error::Reqwest(e)) | ||
}, | ||
} | ||
} | ||
} | ||
} | ||
mirkobrombin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
update_progress(bytes_downloaded, total_bytes); | ||
|
||
// Ensure the data is written to disk | ||
file.sync_all().await?; | ||
// Open a new file handle with RO permissions | ||
let file = File::options().read(true).open(&req.destination).await?; | ||
Ok(file) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,4 +6,18 @@ pub enum Error { | |
Io(#[from] std::io::Error), | ||
#[error("Serde: {0}")] | ||
Serde(#[from] serde_json::Error), | ||
#[error("Reqwest: {0}")] | ||
Reqwest(#[from] reqwest::Error), | ||
#[error("Oneshot: {0}")] | ||
Oneshot(#[from] tokio::sync::oneshot::error::RecvError), | ||
Comment on lines
+12
to
+13
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't like having to add this specific error type. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe, not sure, There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess we can add |
||
#[error("Download: {0}")] | ||
Download(#[from] DownloadError), | ||
} | ||
|
||
#[derive(Error, Debug, Clone)] | ||
pub enum DownloadError { | ||
#[error("Download was cancelled")] | ||
Cancelled, | ||
#[error("Retry limit exceeded")] | ||
RetriesExhausted { last_error_msg: String }, | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
pub mod downloader; | ||
mod error; | ||
pub mod runner; | ||
|
||
pub use error::Error; | ||
|
||
pub mod proto { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe cap the exponential backoff and sprinkle in some jitter so retries don’t wait forever.