
Commit 063b198

Enable FSDP across nodes with START_RANK
1 parent 133ec0e commit 063b198

8 files changed: +253 -86 lines changed


bergson/__main__.py

Lines changed: 12 additions & 14 deletions
@@ -1,3 +1,4 @@
+import os
 import shutil
 from dataclasses import dataclass
 from pathlib import Path
@@ -14,16 +15,18 @@

 def validate_run_path(run_path: Path, overwrite: bool):
     """Validate the run path."""
-    if overwrite:
+    start_rank = int(os.environ.get("START_RANK", 0))
+    rank = start_rank + int(os.environ.get("RANK", 0))
+
+    if rank != 0 or not run_path.exists():
         return

-    if run_path.exists():
-        print(f"Run path {run_path} already exists.")
-        response = input("Do you want to overwrite the existing run path? (y/n): ")
-        if response.lower() != "y":
-            exit()
-        else:
-            shutil.rmtree(run_path)
+    if overwrite:
+        shutil.rmtree(run_path)
+    else:
+        raise FileExistsError(
+            f"Run path {run_path} already exists. Use --overwrite to overwrite it."
+        )


 @dataclass
@@ -55,14 +58,9 @@ class Reduce:

     def execute(self):
         """Reduce a gradient index."""
-        if self.index_cfg.projection_dim != 0:
-            print(
-                "Warning: projection_dim is not 0. "
-                "Compressed gradients will be reduced."
-            )
-
         run_path = Path(self.index_cfg.run_path)
         partial_run_path = Path(self.index_cfg.partial_run_path)
+
         validate_run_path(run_path, self.index_cfg.overwrite)
         validate_run_path(partial_run_path, self.index_cfg.overwrite)

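Note: the START_RANK change above is easiest to see in isolation. The standalone Python sketch below is not part of the commit; global_rank and the two-node numbers are illustrative assumptions about how each node's launcher numbers its local ranks from 0 while START_RANK supplies the node's offset.

import os


def global_rank() -> int:
    """Combine the launcher-local RANK with the node's START_RANK offset.

    Hypothetical helper mirroring validate_run_path: with two 8-GPU nodes,
    node 0 exports START_RANK=0 and node 1 exports START_RANK=8, so local
    ranks 0-7 on node 1 become global ranks 8-15.
    """
    start_rank = int(os.environ.get("START_RANK", 0))
    return start_rank + int(os.environ.get("RANK", 0))


if __name__ == "__main__":
    # Only the true global rank 0 should create or delete the run path on
    # shared storage; every other rank returns early, as in the diff above.
    if global_rank() == 0:
        print("rank 0: safe to validate and overwrite the run directory")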

bergson/build.py

Lines changed: 6 additions & 1 deletion
@@ -99,6 +99,9 @@ def flush(kwargs):
     if rank == 0:
         processor.save(cfg.partial_run_path)

+    if dist.is_initialized():
+        dist.barrier()
+

 def build(index_cfg: IndexConfig):
     """
@@ -118,4 +121,6 @@ def build(index_cfg: IndexConfig):

     launch_distributed_run("build", build_worker, [index_cfg, ds])

-    shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
+    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
+    if rank == 0:
+        shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
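Note: the dist.barrier() added above makes every rank finish saving its shard before rank 0 renames the partial run directory. Below is a minimal sketch of that pattern, collapsing the two changed call sites into one hypothetical helper (finalize_run and its path arguments are placeholders, not bergson API):

import os
import shutil

import torch.distributed as dist


def finalize_run(partial_path: str, final_path: str) -> None:
    """Wait for all writers, then let a single rank publish the result."""
    if dist.is_initialized():
        # No rank proceeds until every rank has flushed its shard to disk.
        dist.barrier()

    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    if rank == 0:
        # One rename on one rank avoids races on shared storage.
        shutil.move(partial_path, final_path)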

bergson/collection.py

Lines changed: 202 additions & 39 deletions
@@ -40,6 +40,11 @@ def collect_gradients(
     if batches is None:
         batches = [[idx] for idx in range(len(data))]

+    print(
+        f"Rank {rank} has {len(batches)} batches and thinks world "
+        f"size is {dist.get_world_size()}."
+    )
+
     # Mutable state for the GradientCollector callback
     mod_grads = {}
     preconditioners = processor.preconditioners
@@ -49,22 +54,18 @@ def collect_gradients(
     lo = torch.finfo(dtype).min
     hi = torch.finfo(dtype).max

+    owned_modules: set[str] = set()
+    module_to_rank: dict[str, int] = {}
+
     def callback(name: str, g: torch.Tensor):
         g = g.flatten(1).clamp_(lo, hi)
-        if save_index:
-            # Asynchronously move the gradient to CPU and convert to the final dtype
-            mod_grads[name] = g.to(device="cpu", dtype=dtype, non_blocking=True)
-        else:
-            mod_grads[name] = g.to(dtype=dtype)
-
-        # Compute the outer product of the flattened gradient
-        if not cfg.skip_preconditioners:
-            g = g.float()
-            preconditioner = preconditioners.get(name, None)
-            if preconditioner is None:
-                preconditioners[name] = g.mT @ g
+        # Keep gradients in original dtype for preconditioner computation
+        mod_grads[name] = g
+        if cfg.skip_preconditioners:
+            if save_index:
+                mod_grads[name] = g.to(dtype=dtype, device="cpu", non_blocking=True)
             else:
-                preconditioner.addmm_(g.mT, g)
+                mod_grads[name] = g.to(dtype=dtype)

     collector = GradientCollector(
         model.base_model,
@@ -74,6 +75,33 @@ def callback(name: str, g: torch.Tensor):
         attention_cfgs=attention_cfgs or {},
     )

+    # Determine which modules this rank owns for preconditioner computation
+    if dist.is_initialized():
+        num_devices = dist.get_world_size()
+        # This list is sorted.
+        available_modules = list(collector.shapes().keys())
+
+        num_modules = len(available_modules)
+        base, remainder = divmod(num_modules, num_devices)
+
+        assert base > 0, "Each rank must own at least one module"
+
+        start_idx = rank * base + min(rank, remainder)
+        end_idx = start_idx + base + (1 if rank < remainder else 0)
+        owned_modules = set(available_modules[start_idx:end_idx])
+
+        for i, module_name in enumerate(available_modules):
+            # Inverse of the start_idx formula
+            module_to_rank[module_name] = (
+                min(i // (base + 1), remainder - 1)
+                if i < remainder * (base + 1)
+                else remainder + (i - remainder * (base + 1)) // base
+            )
+
+        print(f"Rank {rank} owns {len(owned_modules)} modules")
+    else:
+        owned_modules = set(collector.shapes().keys())
+
     # Allocate space ahead of time for the gradients
     grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
     builder = (
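Note: to make the ownership split above concrete, here is a standalone check (toy module names, not from the repository) of the divmod partition and the inverse module_to_rank formula. With 10 modules on 4 ranks, base=2 and remainder=2, so ranks 0 and 1 own 3 modules each and ranks 2 and 3 own 2 each.

# Toy reproduction of the ownership math; runs without torch.
num_modules, num_devices = 10, 4
modules = [f"layers.{i}.proj" for i in range(num_modules)]  # placeholder names

base, remainder = divmod(num_modules, num_devices)

owned = {}
for rank in range(num_devices):
    start = rank * base + min(rank, remainder)
    end = start + base + (1 if rank < remainder else 0)
    owned[rank] = modules[start:end]

# Inverse mapping, same expression as in collect_gradients above.
module_to_rank = {
    name: (
        min(i // (base + 1), remainder - 1)
        if i < remainder * (base + 1)
        else remainder + (i - remainder * (base + 1)) // base
    )
    for i, name in enumerate(modules)
}

# Every module maps back to the rank whose slice contains it.
for rank, names in owned.items():
    assert all(module_to_rank[n] == rank for n in names)
print({rank: len(names) for rank, names in owned.items()})  # {0: 3, 1: 3, 2: 2, 3: 2}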
@@ -89,7 +117,8 @@ def callback(name: str, g: torch.Tensor):
         fill_value=0.0,
     )

-    for indices in tqdm(batches, disable=rank != 0, desc="Building index"):
+    # rank != 0
+    for indices in tqdm(batches, disable=False, desc="Building index"):
         batch = data[indices]
         x, y = pad_and_tensor(
             batch["input_ids"],  # type: ignore
@@ -132,6 +161,22 @@ def callback(name: str, g: torch.Tensor):

         model.zero_grad()

+        # Send gradients to owning ranks and compute outer products there
+        if not cfg.skip_preconditioners:
+            exchange_preconditioner_gradients(
+                mod_grads, preconditioners, module_to_rank, owned_modules, rank
+            )
+
+        # Convert mod_grads to the right dtype for save_index logic
+        if save_index:
+            for name in mod_grads:
+                mod_grads[name] = mod_grads[name].to(
+                    device="cpu", dtype=dtype, non_blocking=True
+                )
+        else:
+            for name in mod_grads:
+                mod_grads[name] = mod_grads[name].to(dtype=dtype)
+
         if builder is not None:
             builder(indices, mod_grads)


@@ -141,7 +186,8 @@ def callback(name: str, g: torch.Tensor):
         mod_grads.clear()
         per_doc_losses[indices] = losses.detach().type_as(per_doc_losses)

-    process_preconditioners(processor, preconditioners, len(data))
+    if not cfg.skip_preconditioners:
+        process_preconditioners(processor, preconditioners, len(data), grad_sizes, rank)

     if dist.is_initialized():
         dist.reduce(per_doc_losses, dst=0)
@@ -266,58 +312,175 @@ def dist_reduce(self):
             self.in_memory_grad_buffer.cpu().numpy().astype(self.grad_buffer.dtype)
         )

+        self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu()
+
+
+def exchange_preconditioner_gradients(
+    mod_grads: dict[str, torch.Tensor],
+    preconditioners: dict[str, torch.Tensor],
+    module_to_rank: dict[str, int],
+    owned_modules: set[str],
+    rank: int,
+):
+    """
+    Send gradients to the ranks that own their preconditioners, and accumulate
+    outer products on the owning ranks.
+    Each rank sends gradients for modules it doesn't own to the owning ranks,
+    and receives gradients for modules it owns to compute outer products.
+    """
+    # Process current rank data for owned modules
+    for name, g in mod_grads.items():
+        if name not in owned_modules:
+            continue
+
+        g = g.float()
+        if name in preconditioners:
+            preconditioners[name].addmm_(g.mT, g)
+        else:
+            preconditioners[name] = g.mT @ g
+
+    if not dist.is_initialized():
+        return
+
+    world_size = dist.get_world_size()
+    device = next(iter(mod_grads.values())).device
+
+    module_names = list(mod_grads.keys())
+    module_numel = {n: int(mod_grads[n].numel()) for n in module_names}
+
+    current_rank_chunk = torch.empty(0, device=device, dtype=torch.float32)
+
+    # Flatten batch dimension: all to all works on contiguous 1-D tensors
+    send_chunks = [
+        (
+            current_rank_chunk
+            if dest == rank
+            else torch.cat(
+                [
+                    mod_grads[name].flatten()
+                    for name in module_names
+                    if module_to_rank[name] == dest
+                ]
+            )
+        )
+        for dest in range(world_size)
+    ]
+
+    # --- collective exchange of gradient sizes in order of mod_grads ---
+    send_sizes = torch.tensor(
+        [t.numel() for t in send_chunks], device=device, dtype=torch.int64
+    )
+    recv_sizes = torch.empty_like(send_sizes)
+
+    dist.all_to_all_single(recv_sizes, send_sizes)
+
+    # --- collective exchange of gradient in order of mod_grads ---
+    send_buf = torch.cat(send_chunks)
+    recv_buf = torch.empty(
+        int(recv_sizes.sum().item()), device=device, dtype=torch.float32
+    )
+
+    dist.all_to_all_single(
+        recv_buf,
+        send_buf,
+        output_split_sizes=recv_sizes.tolist(),
+        input_split_sizes=send_sizes.tolist(),
+    )
+
+    # Unpack gradients in src-rank order
+    # Within each src partition, modules are in fixed order.
+    offset = 0
+    for src_rank in range(world_size):
+        part_len = int(recv_sizes[src_rank].item())
+        part = recv_buf[offset : offset + part_len]
+        offset += part_len
+
+        if part_len == 0 or src_rank == rank:
+            continue
+
+        p = 0
+        for name in owned_modules:
+            n = module_numel[name]
+            flat = part[p : p + n]
+            p += n
+
+            feature_dim = mod_grads[name].shape[-1]
+            g = flat.to(device, non_blocking=True).view(-1, feature_dim).float()
+
+            if name in preconditioners:
+                preconditioners[name].addmm_(g.mT, g)
+            else:
+                preconditioners[name] = g.mT @ g
+

 def process_preconditioners(
     processor: GradientProcessor,
     preconditioners: dict[str, torch.Tensor],
     len_data: int,
+    grad_sizes: dict[str, int],
+    rank: int,
 ):
     """
     Aggregate preconditioners across ranks and compute their eigen decomposition
     distributed across all ranks.
     """
-
-    rank = dist.get_rank() if dist.is_initialized() else 0
-    world_size = dist.get_world_size() if dist.is_initialized() else 1
     preconditioners_eigen = {}
+
+    device = next(iter(preconditioners.values())).device
+    dtype = next(iter(preconditioners.values())).dtype
+
     if rank == 0:
         print("Saving preconditioners...")
-    for name, prec in preconditioners.items():
-        if dist.is_initialized():
-            dist.all_reduce(prec)

-        preconditioners[name] = prec / len_data
-
-    processor.preconditioners = preconditioners
+    for name, prec in preconditioners.items():
+        preconditioners[name] = (prec / len_data).cpu()

     if rank == 0:
         print("Computing preconditioner eigen decompositions...")
-    names = list(preconditioners.keys())
-    names_per_rank = names[rank::world_size]

-    for name in names_per_rank:
-        original_dtype = preconditioners[name].dtype
-        prec = preconditioners[name].to(dtype=torch.float64)
+    for name in preconditioners.keys():
+        prec = preconditioners[name].to(dtype=torch.float64, device=device)
         eigvals, eigvecs = torch.linalg.eigh(prec)
         preconditioners_eigen[name] = (
-            eigvals.to(dtype=original_dtype).contiguous(),
-            eigvecs.to(dtype=original_dtype).contiguous(),
+            eigvals.to(dtype=dtype).contiguous().cpu(),
+            eigvecs.to(dtype=dtype).contiguous().cpu(),
         )

     if rank == 0:
-        print("Gathering and saving preconditioner eigen decompositions...")
+        print("Gathering preconditioners...")
+
+    cpu_group = dist.new_group(backend="gloo")
+
+    for name, grad_size in grad_sizes.items():
+        if name in preconditioners:
+            local_prec = preconditioners[name]
+            del preconditioners[name]
+        else:
+            local_prec = torch.zeros([grad_size, grad_size], dtype=dtype, device="cpu")
+
+        dist.reduce(local_prec, dst=0, op=dist.ReduceOp.SUM, group=cpu_group)

-    for name in names:
-        prec = preconditioners[name]
+        if rank == 0:
+            preconditioners[name] = local_prec
+
+    if rank == 0:
+        processor.preconditioners = preconditioners
+
+        print("Gathering eigen decompositions...")
+
+    for name, grad_size in grad_sizes.items():
+        prec_size = torch.Size([grad_size, grad_size])
         if name not in preconditioners_eigen:
-            eigval = torch.zeros(prec.size(0), dtype=prec.dtype, device=prec.device)
-            eigvec = torch.zeros_like(prec)
+            eigval = torch.zeros(prec_size[0], dtype=dtype)
+            eigvec = torch.zeros(prec_size, dtype=dtype)
         else:
             eigval, eigvec = preconditioners_eigen[name]

-        dist.all_reduce(eigval, op=dist.ReduceOp.SUM) if dist.is_initialized() else None
-        dist.all_reduce(eigvec, op=dist.ReduceOp.SUM) if dist.is_initialized() else None
+        dist.reduce(eigval, dst=0, op=dist.ReduceOp.SUM, group=cpu_group)
+        dist.reduce(eigvec, dst=0, op=dist.ReduceOp.SUM, group=cpu_group)
+
+        if rank == 0:
+            preconditioners_eigen[name] = (eigval, eigvec)

-        preconditioners_eigen[name] = (eigval, eigvec)
     if rank == 0:
         processor.preconditioners_eigen = preconditioners_eigen
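Note: exchange_preconditioner_gradients above uses a two-round collective: ranks first exchange per-destination chunk sizes, then exchange the flattened payload with explicit split sizes. The standalone sketch below isolates that pattern with made-up chunk sizes; it assumes torchrun sets LOCAL_RANK and a backend that implements all_to_all_single, such as NCCL on GPUs.

import os

import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))

# Each rank sends a differently sized chunk to every destination.
send_chunks = [
    torch.full((rank + dest + 1,), float(rank), device=device)
    for dest in range(world_size)
]

# Round 1: exchange chunk lengths so receivers can size their buffers.
send_sizes = torch.tensor(
    [c.numel() for c in send_chunks], device=device, dtype=torch.int64
)
recv_sizes = torch.empty_like(send_sizes)
dist.all_to_all_single(recv_sizes, send_sizes)

# Round 2: exchange the flattened payload with explicit split sizes.
recv_buf = torch.empty(int(recv_sizes.sum().item()), device=device)
dist.all_to_all_single(
    recv_buf,
    torch.cat(send_chunks),
    output_split_sizes=recv_sizes.tolist(),
    input_split_sizes=send_sizes.tolist(),
)
print(f"rank {rank} received splits {recv_sizes.tolist()}")
dist.destroy_process_group()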
