use super::create_http_client_async_with_url; use hbb_common::{ bail, lazy_static::lazy_static, log, tokio::{ self, fs::File, io::AsyncWriteExt, sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, }, ResultType, }; use serde_derive::Serialize; use std::{collections::HashMap, path::PathBuf, sync::Mutex, time::Duration}; lazy_static! { static ref DOWNLOADERS: Mutex> = Default::default(); } /// This struct is used to return the download data to the caller. /// The caller should check if the file is downloaded successfully and remove the job from the map. /// If the file is not downloaded successfully, the `data` field will be empty. /// If the file is downloaded successfully, the `data` field will contain the downloaded data if `path` is None. #[derive(Serialize, Debug)] pub struct DownloadData { #[serde(skip_serializing_if = "Vec::is_empty")] pub data: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub path: Option, #[serde(skip_serializing_if = "Option::is_none")] pub total_size: Option, pub downloaded_size: u64, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } struct Downloader { data: Vec, path: Option, // Some file may be empty, so we use Option to indicate if the size is known total_size: Option, downloaded_size: u64, error: Option, finished: bool, tx_cancel: UnboundedSender<()>, } // The caller should check if the file is downloaded successfully and remove the job from the map. pub fn download_file( url: String, path: Option, auto_del_dur: Option, ) -> ResultType { let id = url.clone(); // First pass: if a non-error downloader exists for this URL, reuse it. // If an errored downloader exists, remove it so this call can retry. let mut stale_path = None; { let mut downloaders = DOWNLOADERS.lock().unwrap(); if let Some(downloader) = downloaders.get(&id) { if downloader.error.is_none() { return Ok(id); } stale_path = downloader.path.clone(); downloaders.remove(&id); } } if let Some(p) = stale_path { if p.exists() { if let Err(e) = std::fs::remove_file(&p) { log::warn!("Failed to remove stale download file {}: {}", p.display(), e); } } } if let Some(path) = path.as_ref() { if path.exists() { bail!("File {} already exists", path.display()); } if let Some(parent) = path.parent() { std::fs::create_dir_all(parent)?; } } let (tx, rx) = unbounded_channel(); let downloader = Downloader { data: Vec::new(), path: path.clone(), total_size: None, downloaded_size: 0, error: None, tx_cancel: tx, finished: false, }; // Second pass (atomic with insert) to avoid race with another concurrent caller. let mut stale_path_after_check = None; { let mut downloaders = DOWNLOADERS.lock().unwrap(); if let Some(existing) = downloaders.get(&id) { if existing.error.is_none() { return Ok(id); } stale_path_after_check = existing.path.clone(); downloaders.remove(&id); } downloaders.insert(id.clone(), downloader); } if let Some(p) = stale_path_after_check { if p.exists() { if let Err(e) = std::fs::remove_file(&p) { log::warn!("Failed to remove stale download file {}: {}", p.display(), e); } } } let id2 = id.clone(); std::thread::spawn( move || match do_download(&id2, url, path, auto_del_dur, rx) { Ok(is_all_downloaded) => { let mut downloaded_size = 0; let mut total_size = 0; DOWNLOADERS.lock().unwrap().get_mut(&id2).map(|downloader| { downloaded_size = downloader.downloaded_size; total_size = downloader.total_size.unwrap_or(0); }); log::info!( "Download {} end, {}/{}, {:.2} %", &id2, downloaded_size, total_size, if total_size == 0 { 0.0 } else { downloaded_size as f64 / total_size as f64 * 100.0 } ); let is_canceled = !is_all_downloaded; if is_canceled { if let Some(downloader) = DOWNLOADERS.lock().unwrap().remove(&id2) { if let Some(p) = downloader.path { if p.exists() { std::fs::remove_file(p).ok(); } } } } } Err(e) => { let err = e.to_string(); log::error!("Download {}, failed: {}", &id2, &err); DOWNLOADERS.lock().unwrap().get_mut(&id2).map(|downloader| { downloader.error = Some(err); }); } }, ); Ok(id) } #[tokio::main(flavor = "current_thread")] async fn do_download( id: &str, url: String, path: Option, auto_del_dur: Option, mut rx_cancel: UnboundedReceiver<()>, ) -> ResultType { let client = create_http_client_async_with_url(&url).await; let mut is_all_downloaded = false; tokio::select! { _ = rx_cancel.recv() => { return Ok(is_all_downloaded); } head_resp = client.head(&url).send() => { match head_resp { Ok(resp) => { if resp.status().is_success() { let total_size = resp .headers() .get(reqwest::header::CONTENT_LENGTH) .and_then(|ct_len| ct_len.to_str().ok()) .and_then(|ct_len| ct_len.parse::().ok()); let Some(total_size) = total_size else { bail!("Failed to get content length"); }; DOWNLOADERS.lock().unwrap().get_mut(id).map(|downloader| { downloader.total_size = Some(total_size); }); } else { bail!("Failed to get content length: {}", resp.status()); } } Err(e) => { return Err(e.into()); } } } } let mut response; tokio::select! { _ = rx_cancel.recv() => { return Ok(is_all_downloaded); } resp = client.get(url).send() => { response = resp?; } } let mut dest: Option = None; if let Some(p) = path { dest = Some(File::create(p).await?); } loop { tokio::select! { _ = rx_cancel.recv() => { break; } chunk = response.chunk() => { match chunk { Ok(Some(chunk)) => { match dest { Some(ref mut f) => { f.write_all(&chunk).await?; f.flush().await?; DOWNLOADERS.lock().unwrap().get_mut(id).map(|downloader| { downloader.downloaded_size += chunk.len() as u64; }); } None => { DOWNLOADERS.lock().unwrap().get_mut(id).map(|downloader| { downloader.data.extend_from_slice(&chunk); downloader.downloaded_size += chunk.len() as u64; }); } } } Ok(None) => { is_all_downloaded = true; break; }, Err(e) => { log::error!("Download {} failed: {}", id, e); return Err(e.into()); } } } } } if let Some(mut f) = dest.take() { f.flush().await?; } if let Some(ref mut downloader) = DOWNLOADERS.lock().unwrap().get_mut(id) { downloader.finished = true; } if is_all_downloaded { let id_del = id.to_string(); if let Some(dur) = auto_del_dur { tokio::spawn(async move { tokio::time::sleep(dur).await; DOWNLOADERS.lock().unwrap().remove(&id_del); }); } } Ok(is_all_downloaded) } pub fn get_download_data(id: &str) -> ResultType { let downloaders = DOWNLOADERS.lock().unwrap(); if let Some(downloader) = downloaders.get(id) { let downloaded_size = downloader.downloaded_size; let total_size = downloader.total_size.clone(); let error = downloader.error.clone(); let data = if total_size.unwrap_or(0) == downloaded_size && downloader.path.is_none() { downloader.data.clone() } else { Vec::new() }; let path = downloader.path.clone(); let download_data = DownloadData { data, path, total_size, downloaded_size, error, }; Ok(download_data) } else { bail!("Downloader not found") } } pub fn cancel(id: &str) { if let Some(downloader) = DOWNLOADERS.lock().unwrap().get(id) { // downloader.is_canceled.store(true, Ordering::SeqCst); // The receiver may not be able to receive the cancel signal, so we also set the atomic bool to true let _ = downloader.tx_cancel.send(()); } } pub fn remove(id: &str) { let _ = DOWNLOADERS.lock().unwrap().remove(id); }