@@ -40,6 +40,11 @@ def collect_gradients(
4040 if batches is None :
4141 batches = [[idx ] for idx in range (len (data ))]
4242
43+ print (
44+ f"Rank { rank } has { len (batches )} batches and thinks world "
45+ f"size is { dist .get_world_size ()} ."
46+ )
47+
4348 # Mutable state for the GradientCollector callback
4449 mod_grads = {}
4550 preconditioners = processor .preconditioners
@@ -49,22 +54,18 @@ def collect_gradients(
4954 lo = torch .finfo (dtype ).min
5055 hi = torch .finfo (dtype ).max
5156
57+ owned_modules : set [str ] = set ()
58+ module_to_rank : dict [str , int ] = {}
59+
5260 def callback (name : str , g : torch .Tensor ):
5361 g = g .flatten (1 ).clamp_ (lo , hi )
54- if save_index :
55- # Asynchronously move the gradient to CPU and convert to the final dtype
56- mod_grads [name ] = g .to (device = "cpu" , dtype = dtype , non_blocking = True )
57- else :
58- mod_grads [name ] = g .to (dtype = dtype )
59-
60- # Compute the outer product of the flattened gradient
61- if not cfg .skip_preconditioners :
62- g = g .float ()
63- preconditioner = preconditioners .get (name , None )
64- if preconditioner is None :
65- preconditioners [name ] = g .mT @ g
62+ # Keep gradients in original dtype for preconditioner computation
63+ mod_grads [name ] = g
64+ if cfg .skip_preconditioners :
65+ if save_index :
66+ mod_grads [name ] = g .to (dtype = dtype , device = "cpu" , non_blocking = True )
6667 else :
67- preconditioner . addmm_ ( g . mT , g )
68+ mod_grads [ name ] = g . to ( dtype = dtype )
6869
6970 collector = GradientCollector (
7071 model .base_model ,
@@ -74,6 +75,33 @@ def callback(name: str, g: torch.Tensor):
7475 attention_cfgs = attention_cfgs or {},
7576 )
7677
78+ # Determine which modules this rank owns for preconditioner computation
79+ if dist .is_initialized ():
80+ num_devices = dist .get_world_size ()
81+ # This list is sorted.
82+ available_modules = list (collector .shapes ().keys ())
83+
84+ num_modules = len (available_modules )
85+ base , remainder = divmod (num_modules , num_devices )
86+
87+ assert base > 0 , "Each rank must own at least one module"
88+
89+ start_idx = rank * base + min (rank , remainder )
90+ end_idx = start_idx + base + (1 if rank < remainder else 0 )
91+ owned_modules = set (available_modules [start_idx :end_idx ])
92+
93+ for i , module_name in enumerate (available_modules ):
94+ # Inverse of the start_idx formula
95+ module_to_rank [module_name ] = (
96+ min (i // (base + 1 ), remainder - 1 )
97+ if i < remainder * (base + 1 )
98+ else remainder + (i - remainder * (base + 1 )) // base
99+ )
100+
101+ print (f"Rank { rank } owns { len (owned_modules )} modules" )
102+ else :
103+ owned_modules = set (collector .shapes ().keys ())
104+
77105 # Allocate space ahead of time for the gradients
78106 grad_sizes = {name : math .prod (s ) for name , s in collector .shapes ().items ()}
79107 builder = (
@@ -89,7 +117,8 @@ def callback(name: str, g: torch.Tensor):
89117 fill_value = 0.0 ,
90118 )
91119
92- for indices in tqdm (batches , disable = rank != 0 , desc = "Building index" ):
120+ # rank != 0
121+ for indices in tqdm (batches , disable = False , desc = "Building index" ):
93122 batch = data [indices ]
94123 x , y = pad_and_tensor (
95124 batch ["input_ids" ], # type: ignore
@@ -132,6 +161,22 @@ def callback(name: str, g: torch.Tensor):
132161
133162 model .zero_grad ()
134163
164+ # Send gradients to owning ranks and compute outer products there
165+ if not cfg .skip_preconditioners :
166+ exchange_preconditioner_gradients (
167+ mod_grads , preconditioners , module_to_rank , owned_modules , rank
168+ )
169+
170+ # Convert mod_grads to the right dtype for save_index logic
171+ if save_index :
172+ for name in mod_grads :
173+ mod_grads [name ] = mod_grads [name ].to (
174+ device = "cpu" , dtype = dtype , non_blocking = True
175+ )
176+ else :
177+ for name in mod_grads :
178+ mod_grads [name ] = mod_grads [name ].to (dtype = dtype )
179+
135180 if builder is not None :
136181 builder (indices , mod_grads )
137182
@@ -141,7 +186,8 @@ def callback(name: str, g: torch.Tensor):
141186 mod_grads .clear ()
142187 per_doc_losses [indices ] = losses .detach ().type_as (per_doc_losses )
143188
144- process_preconditioners (processor , preconditioners , len (data ))
189+ if not cfg .skip_preconditioners :
190+ process_preconditioners (processor , preconditioners , len (data ), grad_sizes , rank )
145191
146192 if dist .is_initialized ():
147193 dist .reduce (per_doc_losses , dst = 0 )
@@ -266,58 +312,175 @@ def dist_reduce(self):
266312 self .in_memory_grad_buffer .cpu ().numpy ().astype (self .grad_buffer .dtype )
267313 )
268314
315+ self .in_memory_grad_buffer = self .in_memory_grad_buffer .cpu ()
316+
317+
318+ def exchange_preconditioner_gradients (
319+ mod_grads : dict [str , torch .Tensor ],
320+ preconditioners : dict [str , torch .Tensor ],
321+ module_to_rank : dict [str , int ],
322+ owned_modules : set [str ],
323+ rank : int ,
324+ ):
325+ """
326+ Send gradients to the ranks that own their preconditioners, and accumulate
327+ outer products on the owning ranks.
328+ Each rank sends gradients for modules it doesn't own to the owning ranks,
329+ and receives gradients for modules it owns to compute outer products.
330+ """
331+ # Process current rank data for owned modules
332+ for name , g in mod_grads .items ():
333+ if name not in owned_modules :
334+ continue
335+
336+ g = g .float ()
337+ if name in preconditioners :
338+ preconditioners [name ].addmm_ (g .mT , g )
339+ else :
340+ preconditioners [name ] = g .mT @ g
341+
342+ if not dist .is_initialized ():
343+ return
344+
345+ world_size = dist .get_world_size ()
346+ device = next (iter (mod_grads .values ())).device
347+
348+ module_names = list (mod_grads .keys ())
349+ module_numel = {n : int (mod_grads [n ].numel ()) for n in module_names }
350+
351+ current_rank_chunk = torch .empty (0 , device = device , dtype = torch .float32 )
352+
353+ # Flatten batch dimension: all to all works on contiguous 1-D tensors
354+ send_chunks = [
355+ (
356+ current_rank_chunk
357+ if dest == rank
358+ else torch .cat (
359+ [
360+ mod_grads [name ].flatten ()
361+ for name in module_names
362+ if module_to_rank [name ] == dest
363+ ]
364+ )
365+ )
366+ for dest in range (world_size )
367+ ]
368+
369+ # --- collective exchange of gradient sizes in order of mod_grads ---
370+ send_sizes = torch .tensor (
371+ [t .numel () for t in send_chunks ], device = device , dtype = torch .int64
372+ )
373+ recv_sizes = torch .empty_like (send_sizes )
374+
375+ dist .all_to_all_single (recv_sizes , send_sizes )
376+
377+ # --- collective exchange of gradient in order of mod_grads ---
378+ send_buf = torch .cat (send_chunks )
379+ recv_buf = torch .empty (
380+ int (recv_sizes .sum ().item ()), device = device , dtype = torch .float32
381+ )
382+
383+ dist .all_to_all_single (
384+ recv_buf ,
385+ send_buf ,
386+ output_split_sizes = recv_sizes .tolist (),
387+ input_split_sizes = send_sizes .tolist (),
388+ )
389+
390+ # Unpack gradients in src-rank order
391+ # Within each src partition, modules are in fixed order.
392+ offset = 0
393+ for src_rank in range (world_size ):
394+ part_len = int (recv_sizes [src_rank ].item ())
395+ part = recv_buf [offset : offset + part_len ]
396+ offset += part_len
397+
398+ if part_len == 0 or src_rank == rank :
399+ continue
400+
401+ p = 0
402+ for name in owned_modules :
403+ n = module_numel [name ]
404+ flat = part [p : p + n ]
405+ p += n
406+
407+ feature_dim = mod_grads [name ].shape [- 1 ]
408+ g = flat .to (device , non_blocking = True ).view (- 1 , feature_dim ).float ()
409+
410+ if name in preconditioners :
411+ preconditioners [name ].addmm_ (g .mT , g )
412+ else :
413+ preconditioners [name ] = g .mT @ g
414+
269415
270416def process_preconditioners (
271417 processor : GradientProcessor ,
272418 preconditioners : dict [str , torch .Tensor ],
273419 len_data : int ,
420+ grad_sizes : dict [str , int ],
421+ rank : int ,
274422):
275423 """
276424 Aggregate preconditioners across ranks and compute their eigen decomposition
277425 distributed across all ranks.
278426 """
279-
280- rank = dist .get_rank () if dist .is_initialized () else 0
281- world_size = dist .get_world_size () if dist .is_initialized () else 1
282427 preconditioners_eigen = {}
428+
429+ device = next (iter (preconditioners .values ())).device
430+ dtype = next (iter (preconditioners .values ())).dtype
431+
283432 if rank == 0 :
284433 print ("Saving preconditioners..." )
285- for name , prec in preconditioners .items ():
286- if dist .is_initialized ():
287- dist .all_reduce (prec )
288434
289- preconditioners [name ] = prec / len_data
290-
291- processor .preconditioners = preconditioners
435+ for name , prec in preconditioners .items ():
436+ preconditioners [name ] = (prec / len_data ).cpu ()
292437
293438 if rank == 0 :
294439 print ("Computing preconditioner eigen decompositions..." )
295- names = list (preconditioners .keys ())
296- names_per_rank = names [rank ::world_size ]
297440
298- for name in names_per_rank :
299- original_dtype = preconditioners [name ].dtype
300- prec = preconditioners [name ].to (dtype = torch .float64 )
441+ for name in preconditioners .keys ():
442+ prec = preconditioners [name ].to (dtype = torch .float64 , device = device )
301443 eigvals , eigvecs = torch .linalg .eigh (prec )
302444 preconditioners_eigen [name ] = (
303- eigvals .to (dtype = original_dtype ).contiguous (),
304- eigvecs .to (dtype = original_dtype ).contiguous (),
445+ eigvals .to (dtype = dtype ).contiguous (). cpu (),
446+ eigvecs .to (dtype = dtype ).contiguous (). cpu (),
305447 )
306448
307449 if rank == 0 :
308- print ("Gathering and saving preconditioner eigen decompositions..." )
450+ print ("Gathering preconditioners..." )
451+
452+ cpu_group = dist .new_group (backend = "gloo" )
453+
454+ for name , grad_size in grad_sizes .items ():
455+ if name in preconditioners :
456+ local_prec = preconditioners [name ]
457+ del preconditioners [name ]
458+ else :
459+ local_prec = torch .zeros ([grad_size , grad_size ], dtype = dtype , device = "cpu" )
460+
461+ dist .reduce (local_prec , dst = 0 , op = dist .ReduceOp .SUM , group = cpu_group )
309462
310- for name in names :
311- prec = preconditioners [name ]
463+ if rank == 0 :
464+ preconditioners [name ] = local_prec
465+
466+ if rank == 0 :
467+ processor .preconditioners = preconditioners
468+
469+ print ("Gathering eigen decompositions..." )
470+
471+ for name , grad_size in grad_sizes .items ():
472+ prec_size = torch .Size ([grad_size , grad_size ])
312473 if name not in preconditioners_eigen :
313- eigval = torch .zeros (prec . size ( 0 ) , dtype = prec . dtype , device = prec . device )
314- eigvec = torch .zeros_like ( prec )
474+ eigval = torch .zeros (prec_size [ 0 ] , dtype = dtype )
475+ eigvec = torch .zeros ( prec_size , dtype = dtype )
315476 else :
316477 eigval , eigvec = preconditioners_eigen [name ]
317478
318- dist .all_reduce (eigval , op = dist .ReduceOp .SUM ) if dist .is_initialized () else None
319- dist .all_reduce (eigvec , op = dist .ReduceOp .SUM ) if dist .is_initialized () else None
479+ dist .reduce (eigval , dst = 0 , op = dist .ReduceOp .SUM , group = cpu_group )
480+ dist .reduce (eigvec , dst = 0 , op = dist .ReduceOp .SUM , group = cpu_group )
481+
482+ if rank == 0 :
483+ preconditioners_eigen [name ] = (eigval , eigvec )
320484
321- preconditioners_eigen [name ] = (eigval , eigvec )
322485 if rank == 0 :
323486 processor .preconditioners_eigen = preconditioners_eigen
0 commit comments