@@ -6,10 +6,10 @@ mod tests;
66
77use std:: path:: Path ;
88
9- use anyhow:: { Context , Result , anyhow, bail} ;
9+ use anyhow:: { Context , Error , Result , anyhow, bail} ;
1010use futures_util:: stream:: StreamExt ;
1111use std:: sync:: Arc ;
12- use tokio:: sync:: Semaphore ;
12+ use tokio:: sync:: { Semaphore , mpsc } ;
1313use tracing:: { info, warn} ;
1414use url:: Url ;
1515
@@ -156,8 +156,7 @@ impl Manifestation {
156156 let altered = tmp_cx. dist_server != DEFAULT_DIST_SERVER ;
157157
158158 // Download component packages and validate hashes
159- let mut things_to_install = Vec :: new ( ) ;
160- let mut things_downloaded = Vec :: new ( ) ;
159+ let mut things_downloaded: Vec < String > = Vec :: new ( ) ;
161160 let components = update
162161 . components_urls_and_hashes ( new_manifest)
163162 . map ( |res| {
@@ -168,7 +167,6 @@ impl Manifestation {
168167 } )
169168 } )
170169 . collect :: < Result < Vec < _ > > > ( ) ?;
171-
172170 let components_len = components. len ( ) ;
173171 const DEFAULT_CONCURRENT_DOWNLOADS : usize = 2 ;
174172 let concurrent_downloads = download_cfg
@@ -184,46 +182,15 @@ impl Manifestation {
184182 . and_then ( |s| s. parse ( ) . ok ( ) )
185183 . unwrap_or ( DEFAULT_MAX_RETRIES ) ;
186184
187- info ! ( "downloading component(s)" ) ;
188- let semaphore = Arc :: new ( Semaphore :: new ( concurrent_downloads) ) ;
189- let component_stream = tokio_stream:: iter ( components. into_iter ( ) ) . map ( |bin| {
190- let sem = semaphore. clone ( ) ;
191- async move {
192- let _permit = sem. acquire ( ) . await . unwrap ( ) ;
193- let url = if altered {
194- utils:: parse_url (
195- & bin. binary
196- . url
197- . replace ( DEFAULT_DIST_SERVER , tmp_cx. dist_server . as_str ( ) ) ,
198- ) ?
199- } else {
200- utils:: parse_url ( & bin. binary . url ) ?
201- } ;
202-
203- bin. download ( & url, download_cfg, max_retries, new_manifest)
204- . await
205- . map ( |downloaded| ( bin, downloaded) )
206- }
207- } ) ;
208- if components_len > 0 {
209- let results = component_stream
210- . buffered ( components_len)
211- . collect :: < Vec < _ > > ( )
212- . await ;
213- for result in results {
214- let ( bin, downloaded_file) = result?;
215- things_downloaded. push ( bin. binary . hash . clone ( ) ) ;
216- things_to_install. push ( ( bin, downloaded_file) ) ;
217- }
218- }
219-
220- // Begin transaction
185+ // Begin transaction before the downloads, as installations are interleaved with those
221186 let mut tx = Transaction :: new ( prefix. clone ( ) , tmp_cx, download_cfg. process ) ;
222187
223188 // If the previous installation was from a v1 manifest we need
224189 // to uninstall it first.
225190 tx = self . maybe_handle_v2_upgrade ( & config, tx, download_cfg. process ) ?;
226191
192+ info ! ( "downloading component(s)" ) ;
193+
227194 // Uninstall components
228195 for component in & update. components_to_uninstall {
229196 match ( implicit_modify, & component. target ) {
@@ -255,15 +222,76 @@ impl Manifestation {
255222 tx = self . uninstall_component ( component, new_manifest, tx, download_cfg. process ) ?;
256223 }
257224
258- // Install components
259- for ( component_bin, installer_file) in things_to_install {
260- tx = self . install_component (
261- component_bin,
262- installer_file,
263- download_cfg,
264- new_manifest,
265- tx,
266- ) ?;
225+ if components_len > 0 {
226+ // Create a channel to communicate whenever a download is done and the component can be installed
227+ // The `mpsc` channel was used as we need to send many messages from one producer (download's thread) to one consumer (install's thread)
228+ // This is recommended in the official docs: https://docs.rs/tokio/latest/tokio/sync/index.html#mpsc-channel
229+ let total_components = components. len ( ) ;
230+ let ( download_tx, mut download_rx) =
231+ mpsc:: channel :: < Result < ( ComponentBinary < ' _ > , File ) > > ( total_components) ;
232+
233+ let semaphore = Arc :: new ( Semaphore :: new ( concurrent_downloads) ) ;
234+ let component_stream = tokio_stream:: iter ( components. into_iter ( ) ) . map ( |bin| {
235+ let sem = semaphore. clone ( ) ;
236+ let download_tx = download_tx. clone ( ) ;
237+ async move {
238+ let _permit = sem. acquire ( ) . await . unwrap ( ) ;
239+ let url = if altered {
240+ utils:: parse_url (
241+ & bin. binary
242+ . url
243+ . replace ( DEFAULT_DIST_SERVER , tmp_cx. dist_server . as_str ( ) ) ,
244+ ) ?
245+ } else {
246+ utils:: parse_url ( & bin. binary . url ) ?
247+ } ;
248+
249+ let installer_file = bin
250+ . download ( & url, download_cfg, max_retries, new_manifest)
251+ . await ?;
252+ let hash = bin. binary . hash . clone ( ) ;
253+ let _ = download_tx. send ( Ok ( ( bin, installer_file) ) ) . await ;
254+ Ok ( hash)
255+ }
256+ } ) ;
257+
258+ let mut stream = component_stream. buffered ( components_len) ;
259+ let download_handle = async {
260+ let mut hashes = Vec :: new ( ) ;
261+ while let Some ( result) = stream. next ( ) . await {
262+ match result {
263+ Ok ( hash) => {
264+ hashes. push ( hash) ;
265+ }
266+ Err ( e) => {
267+ let _ = download_tx. send ( Err ( e) ) . await ;
268+ }
269+ }
270+ }
271+ hashes
272+ } ;
273+ let install_handle = async {
274+ let mut current_tx = tx;
275+ let mut counter = 0 ;
276+ while counter < total_components
277+ && let Some ( message) = download_rx. recv ( ) . await
278+ {
279+ let ( component_bin, installer_file) = message?;
280+ current_tx = self . install_component (
281+ component_bin,
282+ installer_file,
283+ download_cfg,
284+ new_manifest,
285+ current_tx,
286+ ) ?;
287+ counter += 1 ;
288+ }
289+ Ok :: < _ , Error > ( current_tx)
290+ } ;
291+
292+ let ( download_results, install_result) = tokio:: join!( download_handle, install_handle) ;
293+ things_downloaded = download_results;
294+ tx = install_result?;
267295 }
268296
269297 // Install new distribution manifest
@@ -759,7 +787,7 @@ impl<'a> ComponentBinary<'a> {
759787 let downloaded_file = RetryIf :: spawn (
760788 FixedInterval :: from_millis ( 0 ) . take ( max_retries) ,
761789 || download_cfg. download ( url, & self . binary . hash , & self . status ) ,
762- |e : & anyhow :: Error | {
790+ |e : & Error | {
763791 // retry only known retriable cases
764792 match e. downcast_ref :: < RustupError > ( ) {
765793 Some ( RustupError :: BrokenPartialFile )
0 commit comments