use async_trait::async_trait; use crossterm::style::Stylize; use futures::{stream, StreamExt}; use reqwest::{Client, StatusCode}; use std::{path::PathBuf, sync::Arc}; use thiserror::Error; use tokio::{fs::File, io::AsyncWriteExt}; use url::Url; /// A struct that can download multiple files over http parallely using tokio /// and reqwest pub struct Downloader { /// A callback object defined by the user containing methods that will be /// called when progress is made or the process completes pub callback: C, /// The files the downloader will download pub files: Vec, /// The maximum amount of files the downloader will download at once pub parellel_count: usize, /// The HTTP client to use to download files pub client: Arc, } impl Downloader { /// Starts the downloading process and returns the value that is returned /// from the callback in the on_completed method pub async fn download(self) -> C::EndRes { let Self { mut callback, files, parellel_count: parallel_count, client, } = self; let it = files .into_iter() .map(|f| Self::download_one(Arc::clone(&client), f.url, f.target)); let mut stream = stream::iter(it).buffer_unordered(parallel_count); let mut stop_info = None; while let Some(res) = stream.next().await { match callback.on_download_complete(res).await { CallbackStatus::Stop(i) => { stop_info = Some(i); break; }, CallbackStatus::Continue => {}, } } callback.on_completed(stop_info).await } /// Downloads a single file. Used intenally by the Downloader async fn download_one( client: Arc, url: Url, target: PathBuf, ) -> Result { if let Some(parent) = target.parent() { tokio::fs::create_dir_all(parent).await?; } let mut file = File::create(&target).await?; let res = client.get(url.clone()).send().await?; let status = res.status(); let mut stream = res.bytes_stream(); if let Some(b) = stream.next().await { file.write_all_buf(&mut b?).await?; } Ok(DownloadInfo { from: url, to: target, status, }) } } /// A file to be downloaded by a Downloader pub struct FileToDownload { /// The url to HTTP GET the file from pub url: Url, /// The file to save to pub target: PathBuf, } /// A struct containing information about a file that is passed to the callback /// once it has finished downloading pub struct DownloadInfo { /// the URL the file has been downloaded from pub from: Url, /// The path the file has been saved to pub to: PathBuf, /// The HTTP status code returned by the server pub status: StatusCode, } impl DownloadInfo { pub fn to_colored_text(&self) -> String { format!( "{} {} => {}", self.status.as_str().red(), self.from.as_str().cyan().bold(), self.to.to_string_lossy().cyan().bold() ) } } /// An error that can occur while a file is being downloaded #[derive(Debug, Error)] pub enum DownloadError { #[error("HTTP Error: {0}")] HttpError(#[from] reqwest::Error), #[error("Filesystem error: {0}")] FilesystemError(#[from] std::io::Error), } /// Returned by the on_download_complete method in the callback to either /// continue or stop the downloader. If the downloader is stopped, a StopInfo /// object is also sent which will be received by the on_completed function if /// the download was stopped. pub enum CallbackStatus { Stop(I), Continue, } /// A callback driven by a Downloader #[async_trait] pub trait Callback { /// The type returned by the file downloader from the on_completed callback /// function. type EndRes; /// The type sent to the on_completed function when the downloader was /// interrupted by the on_download_complete function. type StopInfo; /// Called by the downloader once a file has been downloaded or failed to /// download /// /// * `res` - The result of the file download operation async fn on_download_complete( &mut self, res: Result, ) -> CallbackStatus; /// Called by the Downloader once it has completed. The return value of this /// function will also be returned by the file downloader. /// /// * `stop_info` - Normally `None`, unless the on_download_complete /// function stopped the /// downloader, in which case the value returned by it will be provided. async fn on_completed(self, stop_info: Option) -> Self::EndRes; }