From 133ec0e79b93a5cba0c886a97d0f73a930345aa5 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Fri, 12 Dec 2025 02:31:39 +0000
Subject: [PATCH 1/6] save

---
 bergson/__main__.py    | 13 +++++++----
 bergson/config.py      |  3 +++
 bergson/distributed.py | 53 ++++++++++++++++++++----------------------
 bergson/reduce.py      | 23 +++++++++++++++---
 4 files changed, 56 insertions(+), 36 deletions(-)

diff --git a/bergson/__main__.py b/bergson/__main__.py
index 6187ea7f..70421b45 100644
--- a/bergson/__main__.py
+++ b/bergson/__main__.py
@@ -12,8 +12,11 @@
 from .score.score import score_dataset
 
 
-def validate_run_path(run_path: Path):
+def validate_run_path(run_path: Path, overwrite: bool):
     """Validate the run path."""
+    if overwrite:
+        return
+
     if run_path.exists():
         print(f"Run path {run_path} already exists.")
         response = input("Do you want to overwrite the existing run path? (y/n): ")
@@ -36,8 +39,8 @@ def execute(self):
         run_path = Path(self.index_cfg.run_path)
         partial_run_path = Path(self.index_cfg.partial_run_path)
 
-        validate_run_path(run_path)
-        validate_run_path(partial_run_path)
+        validate_run_path(run_path, self.index_cfg.overwrite)
+        validate_run_path(partial_run_path, self.index_cfg.overwrite)
 
         build(self.index_cfg)
 
@@ -60,8 +63,8 @@ def execute(self):
         run_path = Path(self.index_cfg.run_path)
         partial_run_path = Path(self.index_cfg.partial_run_path)
 
-        validate_run_path(run_path)
-        validate_run_path(partial_run_path)
+        validate_run_path(run_path, self.index_cfg.overwrite)
+        validate_run_path(partial_run_path, self.index_cfg.overwrite)
 
         reduce(self.index_cfg, self.reduce_cfg)
 
diff --git a/bergson/config.py b/bergson/config.py
index 9a70912d..c704f214 100644
--- a/bergson/config.py
+++ b/bergson/config.py
@@ -123,6 +123,9 @@ class IndexConfig:
     """Configuration for each attention module to be split into head matrices.
     Used for attention modules specified in `split_attention_modules`."""
 
+    overwrite: bool = False
+    """Whether to overwrite an existing index without asking for confirmation."""
+
     @property
     def partial_run_path(self) -> Path:
         """Temporary path to use while writing build artifacts."""
diff --git a/bergson/distributed.py b/bergson/distributed.py
index 222754d0..80c88239 100644
--- a/bergson/distributed.py
+++ b/bergson/distributed.py
@@ -1,5 +1,6 @@
 import socket
 from typing import Any, Callable
+import os
 
 import torch
 import torch.distributed as dist
@@ -25,54 +26,50 @@ def dist_worker(
 
 
 def launch_distributed_run(process_name: str, worker, const_worker_args: list[Any]):
-    """
-    Launch a distributed multi-process job over all visible CUDA devices.
+    local_world_size = torch.cuda.device_count()
+
+    # Check for multi-node environment
+    if "WORLD_SIZE" in os.environ:
+        world_size = int(os.environ["WORLD_SIZE"])
+        node_rank = int(os.environ.get("RANK", os.environ.get("NODE_RANK", 0)))
+        master_addr = os.environ["MASTER_ADDR"]
+        master_port = os.environ.get("MASTER_PORT", "29500")
+    else:
+        world_size = local_world_size
+        node_rank = 0
+        master_addr = "localhost"
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            _, master_port = s.getsockname()
+            master_port = str(master_port)
 
-    Parameters
-    ----------
-    process_name : str
-        Label used by Torch Elastic to tag logs and processes.
-    worker : Callable
-        Function that will be executed on every spawned process. It must accept
-        ``(rank, world_size, *const_worker_args)`` in that order.
-    const_worker_args : list
-        Arguments passed verbatim to every worker invocation after ``rank`` and
-        ``world_size``. These are typically configuration or shared datasets.
-    """
-    world_size = torch.cuda.device_count()
     if world_size <= 1:
-        # Run the worker directly if no distributed training is needed. This is great
-        # for debugging purposes.
         worker(0, 1, *const_worker_args)
     else:
-        # Set up multiprocessing and distributed training
         mp.set_sharing_strategy("file_system")
 
-        # Find an available port for distributed training
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            _, port = s.getsockname()
-
         ctx = None
         try:
             ctx = start_processes(
                 process_name,
                 dist_worker,
                 args={
-                    i: (worker, i, world_size, *const_worker_args)
-                    for i in range(world_size)
+                    i: (worker, node_rank * local_world_size + i, world_size, *const_worker_args)
+                    for i in range(local_world_size)
                 },
                 envs={
                     i: {
                         "LOCAL_RANK": str(i),
-                        "MASTER_ADDR": "localhost",
-                        "MASTER_PORT": str(port),
+                        "RANK": str(node_rank * local_world_size + i),
+                        "WORLD_SIZE": str(world_size),
+                        "MASTER_ADDR": master_addr,
+                        "MASTER_PORT": master_port,
                     }
-                    for i in range(world_size)
+                    for i in range(local_world_size)
                 },
                 logs_specs=DefaultLogsSpecs(),
             )
             ctx.wait()
         finally:
             if ctx is not None:
-                ctx.close()  # Kill any processes that are still running
+                ctx.close()
\ No newline at end of file
diff --git a/bergson/reduce.py b/bergson/reduce.py
index 149cb994..21e26716 100644
--- a/bergson/reduce.py
+++ b/bergson/reduce.py
@@ -42,7 +42,9 @@
     ds : Dataset | IterableDataset
         The entire dataset to be indexed. A subset is assigned to each worker.
     """
-    torch.cuda.set_device(rank)
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
+    # torch.cuda.set_device(rank)
 
     # These should be set by the main process
     if world_size > 1:
@@ -52,7 +54,7 @@
         dist.init_process_group(
             "nccl",
             init_method=f"tcp://{addr}:{port}",
-            device_id=torch.device(f"cuda:{rank}"),
+            device_id=torch.device(f"cuda:{local_rank}"),
             rank=rank,
             timeout=timedelta(hours=1),
             world_size=world_size,
@@ -128,4 +130,19 @@ def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig):
 
     launch_distributed_run("reduce", reduce_worker, [index_cfg, reduce_cfg, ds])
 
-    shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
+    local_world_size = torch.cuda.device_count()
+    if "WORLD_SIZE" in os.environ:
+        # Multi-node setup: determine node rank
+        node_rank = int(os.environ.get("NODE_RANK", 0))
+        # If NODE_RANK is not set but RANK is, calculate from RANK
+        if "NODE_RANK" not in os.environ and "RANK" in os.environ and local_world_size > 0:
+            node_rank = int(os.environ["RANK"]) // local_world_size
+    else:
+        # Single-node setup
+        node_rank = 0
+
+    # Only node 0 should perform the final move
+    if node_rank == 0:
+        shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
+
+    # shutil.move(index_cfg.partial_run_path, index_cfg.run_path)

From 063b1985edc79afd08d99a4865d563d295c30c95 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Fri, 12 Dec 2025 04:59:44 +0000
Subject: [PATCH 2/6] Enable FSDP across nodes with START_RANK

---
 bergson/__main__.py     |  26 ++---
 bergson/build.py        |   7 +-
 bergson/collection.py   | 241 +++++++++++++++++++++++++++++-------
 bergson/data.py         |   9 +-
 bergson/distributed.py  |  18 +--
 bergson/reduce.py       |  27 ++---
 bergson/score/score.py  |   7 +-
 bergson/worker_utils.py |   4 +-
 8 files changed, 253 insertions(+), 86 deletions(-)

diff --git a/bergson/__main__.py b/bergson/__main__.py
index 70421b45..d1e37a3d 100644
--- a/bergson/__main__.py
+++ b/bergson/__main__.py
@@ -1,3 +1,4 @@
+import os
 import shutil
 from dataclasses import dataclass
 from pathlib import Path
@@ -14,16 +15,18 @@
 
 def validate_run_path(run_path: Path, overwrite: bool):
     """Validate the run path."""
-    if overwrite:
+    start_rank = int(os.environ.get("START_RANK", 0))
+    rank = start_rank + int(os.environ.get("RANK", 0))
+
+    if rank != 0 or not run_path.exists():
         return
 
-    if run_path.exists():
-        print(f"Run path {run_path} already exists.")
-        response = input("Do you want to overwrite the existing run path? (y/n): ")
-        if response.lower() != "y":
-            exit()
-        else:
-            shutil.rmtree(run_path)
+    if overwrite:
+        shutil.rmtree(run_path)
+    else:
+        raise FileExistsError(
+            f"Run path {run_path} already exists. Use --overwrite to overwrite it."
+        )
 
 
 @dataclass
@@ -55,14 +58,9 @@ class Reduce:
     def execute(self):
         """Reduce a gradient index."""
-        if self.index_cfg.projection_dim != 0:
-            print(
-                "Warning: projection_dim is not 0. "
-                "Compressed gradients will be reduced."
-            )
-
         run_path = Path(self.index_cfg.run_path)
         partial_run_path = Path(self.index_cfg.partial_run_path)
+
         validate_run_path(run_path, self.index_cfg.overwrite)
         validate_run_path(partial_run_path, self.index_cfg.overwrite)
 
diff --git a/bergson/build.py b/bergson/build.py
index 58e0e945..e9a89a2f 100644
--- a/bergson/build.py
+++ b/bergson/build.py
@@ -99,6 +99,9 @@ def flush(kwargs):
     if rank == 0:
         processor.save(cfg.partial_run_path)
 
+    if dist.is_initialized():
+        dist.barrier()
+
 
 def build(index_cfg: IndexConfig):
     """
@@ -118,4 +121,6 @@ def build(index_cfg: IndexConfig):
 
     launch_distributed_run("build", build_worker, [index_cfg, ds])
 
-    shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
+    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
+    if rank == 0:
+        shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
diff --git a/bergson/collection.py b/bergson/collection.py
index dab1ce11..e8b3f86e 100644
--- a/bergson/collection.py
+++ b/bergson/collection.py
@@ -40,6 +40,11 @@ def collect_gradients(
     if batches is None:
         batches = [[idx] for idx in range(len(data))]
 
+    print(
+        f"Rank {rank} has {len(batches)} batches and thinks world "
+        f"size is {dist.get_world_size()}."
+    )
+
     # Mutable state for the GradientCollector callback
     mod_grads = {}
     preconditioners = processor.preconditioners
@@ -49,22 +54,18 @@ def callback(name: str, g: torch.Tensor):
     lo = torch.finfo(dtype).min
     hi = torch.finfo(dtype).max
 
+    owned_modules: set[str] = set()
+    module_to_rank: dict[str, int] = {}
+
     def callback(name: str, g: torch.Tensor):
         g = g.flatten(1).clamp_(lo, hi)
-        if save_index:
-            # Asynchronously move the gradient to CPU and convert to the final dtype
-            mod_grads[name] = g.to(device="cpu", dtype=dtype, non_blocking=True)
-        else:
-            mod_grads[name] = g.to(dtype=dtype)
-
-        # Compute the outer product of the flattened gradient
-        if not cfg.skip_preconditioners:
-            g = g.float()
-            preconditioner = preconditioners.get(name, None)
-            if preconditioner is None:
-                preconditioners[name] = g.mT @ g
+        # Keep gradients in original dtype for preconditioner computation
+        mod_grads[name] = g
+        if cfg.skip_preconditioners:
+            if save_index:
+                mod_grads[name] = g.to(dtype=dtype, device="cpu", non_blocking=True)
             else:
-                preconditioner.addmm_(g.mT, g)
+                mod_grads[name] = g.to(dtype=dtype)
 
     collector = GradientCollector(
         model.base_model,
@@ -74,6 +75,33 @@ def callback(name: str, g: torch.Tensor):
         attention_cfgs=attention_cfgs or {},
     )
 
+    # Determine which modules this rank owns for preconditioner computation
+    if dist.is_initialized():
+        num_devices = dist.get_world_size()
+        # This list is sorted.
+        available_modules = list(collector.shapes().keys())
+
+        num_modules = len(available_modules)
+        base, remainder = divmod(num_modules, num_devices)
+
+        assert base > 0, "Each rank must own at least one module"
+
+        start_idx = rank * base + min(rank, remainder)
+        end_idx = start_idx + base + (1 if rank < remainder else 0)
+        owned_modules = set(available_modules[start_idx:end_idx])
+
+        for i, module_name in enumerate(available_modules):
+            # Inverse of the start_idx formula
+            module_to_rank[module_name] = (
+                min(i // (base + 1), remainder - 1)
+                if i < remainder * (base + 1)
+                else remainder + (i - remainder * (base + 1)) // base
+            )
+
+        print(f"Rank {rank} owns {len(owned_modules)} modules")
+    else:
+        owned_modules = set(collector.shapes().keys())
+
     # Allocate space ahead of time for the gradients
     grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
     builder = (
@@ -89,7 +117,8 @@ def callback(name: str, g: torch.Tensor):
         fill_value=0.0,
     )
 
-    for indices in tqdm(batches, disable=rank != 0, desc="Building index"):
+    # rank != 0
+    for indices in tqdm(batches, disable=False, desc="Building index"):
         batch = data[indices]
         x, y = pad_and_tensor(
             batch["input_ids"],  # type: ignore
@@ -132,6 +161,22 @@ def callback(name: str, g: torch.Tensor):
 
         model.zero_grad()
 
+        # Send gradients to owning ranks and compute outer products there
+        if not cfg.skip_preconditioners:
+            exchange_preconditioner_gradients(
+                mod_grads, preconditioners, module_to_rank, owned_modules, rank
+            )
+
+        # Convert mod_grads to the right dtype for save_index logic
+        if save_index:
+            for name in mod_grads:
+                mod_grads[name] = mod_grads[name].to(
+                    device="cpu", dtype=dtype, non_blocking=True
+                )
+        else:
+            for name in mod_grads:
+                mod_grads[name] = mod_grads[name].to(dtype=dtype)
+
         if builder is not None:
             builder(indices, mod_grads)
 
@@ -141,7 +186,8 @@ def callback(name: str, g: torch.Tensor):
         mod_grads.clear()
         per_doc_losses[indices] = losses.detach().type_as(per_doc_losses)
 
-    process_preconditioners(processor, preconditioners, len(data))
+    if not cfg.skip_preconditioners:
+        process_preconditioners(processor, preconditioners, len(data), grad_sizes, rank)
 
     if dist.is_initialized():
         dist.reduce(per_doc_losses, dst=0)
@@ -266,58 +312,175 @@ def dist_reduce(self):
             self.in_memory_grad_buffer.cpu().numpy().astype(self.grad_buffer.dtype)
         )
 
+        self.in_memory_grad_buffer = self.in_memory_grad_buffer.cpu()
+
+
+def exchange_preconditioner_gradients(
+    mod_grads: dict[str, torch.Tensor],
+    preconditioners: dict[str, torch.Tensor],
+    module_to_rank: dict[str, int],
+    owned_modules: set[str],
+    rank: int,
+):
+    """
+    Send gradients to the ranks that own their preconditioners, and accumulate
+    outer products on the owning ranks.
+    Each rank sends gradients for modules it doesn't own to the owning ranks,
+    and receives gradients for modules it owns to compute outer products.
+ """ + # Process current rank data for owned modules + for name, g in mod_grads.items(): + if name not in owned_modules: + continue + + g = g.float() + if name in preconditioners: + preconditioners[name].addmm_(g.mT, g) + else: + preconditioners[name] = g.mT @ g + + if not dist.is_initialized(): + return + + world_size = dist.get_world_size() + device = next(iter(mod_grads.values())).device + + module_names = list(mod_grads.keys()) + module_numel = {n: int(mod_grads[n].numel()) for n in module_names} + + current_rank_chunk = torch.empty(0, device=device, dtype=torch.float32) + + # Flatten batch dimension: all to all works on contiguous 1-D tensors + send_chunks = [ + ( + current_rank_chunk + if dest == rank + else torch.cat( + [ + mod_grads[name].flatten() + for name in module_names + if module_to_rank[name] == dest + ] + ) + ) + for dest in range(world_size) + ] + + # --- collective exchange of gradient sizes in order of mod_grads --- + send_sizes = torch.tensor( + [t.numel() for t in send_chunks], device=device, dtype=torch.int64 + ) + recv_sizes = torch.empty_like(send_sizes) + + dist.all_to_all_single(recv_sizes, send_sizes) + + # --- collective exchange of gradient in order of mod_grads --- + send_buf = torch.cat(send_chunks) + recv_buf = torch.empty( + int(recv_sizes.sum().item()), device=device, dtype=torch.float32 + ) + + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_sizes.tolist(), + input_split_sizes=send_sizes.tolist(), + ) + + # Unpack gradients in src-rank order + # Within each src partition, modules are in fixed order. + offset = 0 + for src_rank in range(world_size): + part_len = int(recv_sizes[src_rank].item()) + part = recv_buf[offset : offset + part_len] + offset += part_len + + if part_len == 0 or src_rank == rank: + continue + + p = 0 + for name in owned_modules: + n = module_numel[name] + flat = part[p : p + n] + p += n + + feature_dim = mod_grads[name].shape[-1] + g = flat.to(device, non_blocking=True).view(-1, feature_dim).float() + + if name in preconditioners: + preconditioners[name].addmm_(g.mT, g) + else: + preconditioners[name] = g.mT @ g + def process_preconditioners( processor: GradientProcessor, preconditioners: dict[str, torch.Tensor], len_data: int, + grad_sizes: dict[str, int], + rank: int, ): """ Aggregate preconditioners across ranks and compute their eigen decomposition distributed across all ranks. 
""" - - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 preconditioners_eigen = {} + + device = next(iter(preconditioners.values())).device + dtype = next(iter(preconditioners.values())).dtype + if rank == 0: print("Saving preconditioners...") - for name, prec in preconditioners.items(): - if dist.is_initialized(): - dist.all_reduce(prec) - preconditioners[name] = prec / len_data - - processor.preconditioners = preconditioners + for name, prec in preconditioners.items(): + preconditioners[name] = (prec / len_data).cpu() if rank == 0: print("Computing preconditioner eigen decompositions...") - names = list(preconditioners.keys()) - names_per_rank = names[rank::world_size] - for name in names_per_rank: - original_dtype = preconditioners[name].dtype - prec = preconditioners[name].to(dtype=torch.float64) + for name in preconditioners.keys(): + prec = preconditioners[name].to(dtype=torch.float64, device=device) eigvals, eigvecs = torch.linalg.eigh(prec) preconditioners_eigen[name] = ( - eigvals.to(dtype=original_dtype).contiguous(), - eigvecs.to(dtype=original_dtype).contiguous(), + eigvals.to(dtype=dtype).contiguous().cpu(), + eigvecs.to(dtype=dtype).contiguous().cpu(), ) if rank == 0: - print("Gathering and saving preconditioner eigen decompositions...") + print("Gathering preconditioners...") + + cpu_group = dist.new_group(backend="gloo") + + for name, grad_size in grad_sizes.items(): + if name in preconditioners: + local_prec = preconditioners[name] + del preconditioners[name] + else: + local_prec = torch.zeros([grad_size, grad_size], dtype=dtype, device="cpu") + + dist.reduce(local_prec, dst=0, op=dist.ReduceOp.SUM, group=cpu_group) - for name in names: - prec = preconditioners[name] + if rank == 0: + preconditioners[name] = local_prec + + if rank == 0: + processor.preconditioners = preconditioners + + print("Gathering eigen decompositions...") + + for name, grad_size in grad_sizes.items(): + prec_size = torch.Size([grad_size, grad_size]) if name not in preconditioners_eigen: - eigval = torch.zeros(prec.size(0), dtype=prec.dtype, device=prec.device) - eigvec = torch.zeros_like(prec) + eigval = torch.zeros(prec_size[0], dtype=dtype) + eigvec = torch.zeros(prec_size, dtype=dtype) else: eigval, eigvec = preconditioners_eigen[name] - dist.all_reduce(eigval, op=dist.ReduceOp.SUM) if dist.is_initialized() else None - dist.all_reduce(eigvec, op=dist.ReduceOp.SUM) if dist.is_initialized() else None + dist.reduce(eigval, dst=0, op=dist.ReduceOp.SUM, group=cpu_group) + dist.reduce(eigvec, dst=0, op=dist.ReduceOp.SUM, group=cpu_group) + + if rank == 0: + preconditioners_eigen[name] = (eigval, eigvec) - preconditioners_eigen[name] = (eigval, eigvec) if rank == 0: processor.preconditioners_eigen = preconditioners_eigen diff --git a/bergson/data.py b/bergson/data.py index fab66a8e..b913d36f 100644 --- a/bergson/data.py +++ b/bergson/data.py @@ -28,7 +28,11 @@ def ceildiv(a: int, b: int) -> int: return -(-a // b) # Equivalent to math.ceil(a / b) but faster for integers -def allocate_batches(doc_lengths: list[int], N: int, seed: int = 42) -> list[list[int]]: +def allocate_batches( + doc_lengths: list[int], + N: int, + seed: int = 42, +) -> list[list[int]]: """ Allocate documents into batches that are then distributed evenly across a fixed number of workers. 
@@ -41,7 +45,8 @@
     N : int
         Hard memory budget per *batch*, expressed as
        ``max(length in batch) * (# docs in batch) ≤ N``.
-
+    seed : int
+        Random seed for shuffling batches within each worker's allocation.
     Returns
     -------
     list[list[int]]
diff --git a/bergson/distributed.py b/bergson/distributed.py
index 80c88239..e80962cb 100644
--- a/bergson/distributed.py
+++ b/bergson/distributed.py
@@ -1,6 +1,6 @@
+import os
 import socket
 from typing import Any, Callable
-import os
 
 import torch
 import torch.distributed as dist
@@ -27,16 +27,18 @@ def dist_worker(
 
 
 def launch_distributed_run(process_name: str, worker, const_worker_args: list[Any]):
     local_world_size = torch.cuda.device_count()
-
-    # Check for multi-node environment
+
+    # Multi-node environment
     if "WORLD_SIZE" in os.environ:
         world_size = int(os.environ["WORLD_SIZE"])
-        node_rank = int(os.environ.get("RANK", os.environ.get("NODE_RANK", 0)))
+        # Starting rank for this node
+        start_rank = int(os.environ["START_RANK"])
         master_addr = os.environ["MASTER_ADDR"]
         master_port = os.environ.get("MASTER_PORT", "29500")
     else:
         world_size = local_world_size
-        node_rank = 0
+        # Starting rank for this node
+        start_rank = 0
         master_addr = "localhost"
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             s.bind(("", 0))
@@ -54,13 +56,13 @@ def launch_distributed_run(process_name: str, worker, const_worker_args: list[An
                 process_name,
                 dist_worker,
                 args={
-                    i: (worker, node_rank * local_world_size + i, world_size, *const_worker_args)
+                    i: (worker, start_rank + i, i, world_size, *const_worker_args)
                     for i in range(local_world_size)
                 },
                 envs={
                     i: {
                         "LOCAL_RANK": str(i),
-                        "RANK": str(node_rank * local_world_size + i),
+                        "RANK": str(start_rank + i),
                         "WORLD_SIZE": str(world_size),
                         "MASTER_ADDR": master_addr,
                         "MASTER_PORT": master_port,
                     }
@@ -72,4 +74,4 @@ def launch_distributed_run(process_name: str, worker, const_worker_args: list[An
             ctx.wait()
         finally:
             if ctx is not None:
-                ctx.close()
\ No newline at end of file
+                ctx.close()  # Kill any processes that are still running
diff --git a/bergson/reduce.py b/bergson/reduce.py
index 21e26716..8667887d 100644
--- a/bergson/reduce.py
+++ b/bergson/reduce.py
@@ -21,6 +21,7 @@
 
 def reduce_worker(
     rank: int,
+    local_rank: int,
     world_size: int,
     index_cfg: IndexConfig,
     reduce_cfg: ReduceConfig,
@@ -42,9 +43,7 @@ def reduce_worker(
     ds : Dataset | IterableDataset
         The entire dataset to be indexed. A subset is assigned to each worker.
""" - local_rank = int(os.environ.get("LOCAL_RANK", rank)) torch.cuda.set_device(local_rank) - # torch.cuda.set_device(rank) # These should be set by the main process if world_size > 1: @@ -56,11 +55,11 @@ def reduce_worker( init_method=f"tcp://{addr}:{port}", device_id=torch.device(f"cuda:{local_rank}"), rank=rank, - timeout=timedelta(hours=1), + timeout=timedelta(minutes=30), world_size=world_size, ) - model, target_modules = setup_model_and_peft(index_cfg, rank) + model, target_modules = setup_model_and_peft(index_cfg, local_rank) processor = create_processor(index_cfg, rank) attention_cfgs = { @@ -109,6 +108,9 @@ def flush(kwargs): if rank == 0: processor.save(index_cfg.partial_run_path) + if dist.is_initialized(): + dist.barrier() + def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig): """ @@ -130,19 +132,6 @@ def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig): launch_distributed_run("reduce", reduce_worker, [index_cfg, reduce_cfg, ds]) - local_world_size = torch.cuda.device_count() - if "WORLD_SIZE" in os.environ: - # Multi-node setup: determine node rank - node_rank = int(os.environ.get("NODE_RANK", 0)) - # If NODE_RANK is not set but RANK is, calculate from RANK - if "NODE_RANK" not in os.environ and "RANK" in os.environ and local_world_size > 0: - node_rank = int(os.environ["RANK"]) // local_world_size - else: - # Single-node setup - node_rank = 0 - - # Only node 0 should perform the final move - if node_rank == 0: + rank = int(os.environ.get("START_RANK", os.environ.get("RANK", 0))) + if rank == 0: shutil.move(index_cfg.partial_run_path, index_cfg.run_path) - - # shutil.move(index_cfg.partial_run_path, index_cfg.run_path) diff --git a/bergson/score/score.py b/bergson/score/score.py index 2382bab3..1bb84a02 100644 --- a/bergson/score/score.py +++ b/bergson/score/score.py @@ -316,6 +316,9 @@ def flush(kwargs): if rank == 0: processor.save(index_cfg.partial_run_path) + if dist.is_initialized(): + dist.barrier() + def score_dataset(index_cfg: IndexConfig, score_cfg: ScoreConfig): """ @@ -352,4 +355,6 @@ def score_dataset(index_cfg: IndexConfig, score_cfg: ScoreConfig): "score", score_worker, [index_cfg, score_cfg, ds, query_grads] ) - shutil.move(index_cfg.partial_run_path, index_cfg.run_path) + rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0))) + if rank == 0: + shutil.move(index_cfg.partial_run_path, index_cfg.run_path) diff --git a/bergson/worker_utils.py b/bergson/worker_utils.py index 8898b821..d7bcd3b0 100644 --- a/bergson/worker_utils.py +++ b/bergson/worker_utils.py @@ -49,7 +49,7 @@ def create_processor( def setup_model_and_peft( cfg: IndexConfig, - rank: int, + local_rank: int, ) -> tuple[AutoModelForCausalLM, set | None]: """Handle model loading, quantization, FSDP, and PEFT detection""" @@ -68,7 +68,7 @@ def setup_model_and_peft( raise ValueError(f"Unsupported precision: {other}") # Common configuration - device_map = {"": f"cuda:{rank}"} if not cfg.fsdp else "cpu" + device_map = {"": f"cuda:{local_rank}"} if not cfg.fsdp else "cpu" quantization_config = None if cfg.precision in ("int4", "int8"): quantization_config = BitsAndBytesConfig( From 51e1b08ca13730b8d85c4dc11fb0a0a5c587ed59 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 15 Dec 2025 22:27:56 +0000 Subject: [PATCH 3/6] Remove final dist barrier --- bergson/build.py | 3 --- bergson/collection.py | 11 ++++++----- bergson/reduce.py | 3 --- bergson/score/score.py | 3 --- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/bergson/build.py b/bergson/build.py index 
index e9a89a2f..c5a51e8c 100644
--- a/bergson/build.py
+++ b/bergson/build.py
@@ -99,9 +99,6 @@ def flush(kwargs):
     if rank == 0:
         processor.save(cfg.partial_run_path)
 
-    if dist.is_initialized():
-        dist.barrier()
-
 
 def build(index_cfg: IndexConfig):
     """
diff --git a/bergson/collection.py b/bergson/collection.py
index e8b3f86e..c5be3ee9 100644
--- a/bergson/collection.py
+++ b/bergson/collection.py
@@ -192,6 +192,12 @@ def callback(name: str, g: torch.Tensor):
     if dist.is_initialized():
         dist.reduce(per_doc_losses, dst=0)
 
+    # Finalize disk IO
+    if builder is not None:
+        builder.flush()
+        # Final collective operation
+        builder.dist_reduce()
+
     if rank == 0:
         if cfg.drop_columns:
             data = data.remove_columns(["input_ids"])
@@ -207,11 +213,6 @@ def callback(name: str, g: torch.Tensor):
 
         processor.save(cfg.partial_run_path)
 
-    # Make sure the gradients are written to disk
-    if builder is not None:
-        builder.flush()
-        builder.dist_reduce()
-
diff --git a/bergson/reduce.py b/bergson/reduce.py
index 8667887d..ae8aeee5 100644
--- a/bergson/reduce.py
+++ b/bergson/reduce.py
@@ -108,9 +108,6 @@ def flush(kwargs):
     if rank == 0:
         processor.save(index_cfg.partial_run_path)
 
-    if dist.is_initialized():
-        dist.barrier()
-
 
 def reduce(index_cfg: IndexConfig, reduce_cfg: ReduceConfig):
     """
diff --git a/bergson/score/score.py b/bergson/score/score.py
index 1bb84a02..7557773c 100644
--- a/bergson/score/score.py
+++ b/bergson/score/score.py
@@ -316,9 +316,6 @@ def flush(kwargs):
     if rank == 0:
         processor.save(index_cfg.partial_run_path)
 
-    if dist.is_initialized():
-        dist.barrier()
-
 
 def score_dataset(index_cfg: IndexConfig, score_cfg: ScoreConfig):
     """

From 88a11a07a060c2782e308d936885d9444909d0cb Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Tue, 16 Dec 2025 00:15:49 +0000
Subject: [PATCH 4/6] fix tests

---
 bergson/build.py        |  9 +++++----
 bergson/collection.py   | 14 ++++++++++----
 bergson/distributed.py  |  2 +-
 bergson/reduce.py       |  2 +-
 bergson/score/score.py  |  9 +++++----
 bergson/worker_utils.py |  5 +++--
 6 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/bergson/build.py b/bergson/build.py
index c5a51e8c..252f0933 100644
--- a/bergson/build.py
+++ b/bergson/build.py
@@ -21,6 +21,7 @@
 
 def build_worker(
     rank: int,
+    local_rank: int,
     world_size: int,
     cfg: IndexConfig,
     ds: Dataset | IterableDataset,
@@ -39,7 +40,7 @@ def build_worker(
     ds : Dataset | IterableDataset
         The entire dataset to be indexed. A subset is assigned to each worker.
     """
-    torch.cuda.set_device(rank)
+    torch.cuda.set_device(local_rank)
 
     # These should be set by the main process
     if world_size > 1:
@@ -49,14 +50,14 @@ def build_worker(
         dist.init_process_group(
             "nccl",
             init_method=f"tcp://{addr}:{port}",
-            device_id=torch.device(f"cuda:{rank}"),
+            device_id=torch.device(f"cuda:{local_rank}"),
             rank=rank,
             timeout=timedelta(hours=1),
             world_size=world_size,
         )
 
-    model, target_modules = setup_model_and_peft(cfg, rank)
-    processor = create_processor(cfg, rank)
+    model, target_modules = setup_model_and_peft(cfg, local_rank)
+    processor = create_processor(cfg, local_rank, rank)
 
     attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}
 
diff --git a/bergson/collection.py b/bergson/collection.py
index c5be3ee9..84254e69 100644
--- a/bergson/collection.py
+++ b/bergson/collection.py
@@ -40,10 +40,11 @@ def collect_gradients(
     if batches is None:
         batches = [[idx] for idx in range(len(data))]
 
-    print(
-        f"Rank {rank} has {len(batches)} batches and thinks world "
-        f"size is {dist.get_world_size()}."
-    )
+    if dist.is_initialized():
+        print(
+            f"Rank {rank} has {len(batches)} batches and world size "
+            f"{dist.get_world_size()}."
+        )
 
     # Mutable state for the GradientCollector callback
     mod_grads = {}
     preconditioners = processor.preconditioners
@@ -447,6 +448,11 @@ def process_preconditioners(
             eigvecs.to(dtype=dtype).contiguous().cpu(),
         )
 
+    if not dist.is_initialized():
+        processor.preconditioners = preconditioners
+        processor.preconditioners_eigen = preconditioners_eigen
+        return
+
     if rank == 0:
         print("Gathering preconditioners...")
 
diff --git a/bergson/distributed.py b/bergson/distributed.py
index e80962cb..2a34990e 100644
--- a/bergson/distributed.py
+++ b/bergson/distributed.py
@@ -46,7 +46,7 @@ def launch_distributed_run(process_name: str, worker, const_worker_args: list[An
             master_port = str(master_port)
 
     if world_size <= 1:
-        worker(0, 1, *const_worker_args)
+        worker(0, 0, 1, *const_worker_args)
     else:
         mp.set_sharing_strategy("file_system")
 
diff --git a/bergson/reduce.py b/bergson/reduce.py
index ae8aeee5..1c790acb 100644
--- a/bergson/reduce.py
+++ b/bergson/reduce.py
@@ -60,7 +60,7 @@ def reduce_worker(
         )
 
     model, target_modules = setup_model_and_peft(index_cfg, local_rank)
-    processor = create_processor(index_cfg, rank)
+    processor = create_processor(index_cfg, local_rank, rank)
 
     attention_cfgs = {
         module: index_cfg.attention for module in index_cfg.split_attention_modules
diff --git a/bergson/score/score.py b/bergson/score/score.py
index 7557773c..29048fe1 100644
--- a/bergson/score/score.py
+++ b/bergson/score/score.py
@@ -208,6 +208,7 @@ def precondition(batch):
 
 def score_worker(
     rank: int,
+    local_rank: int,
     world_size: int,
     index_cfg: IndexConfig,
     score_cfg: ScoreConfig,
@@ -233,7 +234,7 @@ def score_worker(
     query_grads : dict[str, torch.Tensor]
        Preprocessed query gradient tensors (often [1, grad_dim]) keyed by module name.
""" - torch.cuda.set_device(rank) + torch.cuda.set_device(local_rank) # These should be set by the main process if world_size > 1: @@ -243,15 +244,15 @@ def score_worker( dist.init_process_group( "nccl", init_method=f"tcp://{addr}:{port}", - device_id=torch.device(f"cuda:{rank}"), + device_id=torch.device(f"cuda:{local_rank}"), rank=rank, timeout=timedelta(hours=1), world_size=world_size, ) - model, target_modules = setup_model_and_peft(index_cfg, rank) + model, target_modules = setup_model_and_peft(index_cfg, local_rank) model = cast(PreTrainedModel, model) - processor = create_processor(index_cfg, rank) + processor = create_processor(index_cfg, local_rank, rank) attention_cfgs = { module: index_cfg.attention for module in index_cfg.split_attention_modules diff --git a/bergson/worker_utils.py b/bergson/worker_utils.py index d7bcd3b0..530ae4af 100644 --- a/bergson/worker_utils.py +++ b/bergson/worker_utils.py @@ -21,17 +21,18 @@ def create_processor( cfg: IndexConfig, + local_rank: int, rank: int, ) -> GradientProcessor: """Handle processor creation and normalizer fitting""" processor_path = Path(cfg.processor_path) if (processor_path / "processor_config.json").exists(): - if rank == 0: + if local_rank == 0: print(f"Loading processor from '{cfg.processor_path}'") processor = GradientProcessor.load( processor_path, - map_location=f"cuda:{rank}", + map_location=f"cuda:{local_rank}", ) else: processor = GradientProcessor( From ef4c58cf78c4583854881c5d88da8da92f4c2761 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Fri, 19 Dec 2025 01:22:42 +0000 Subject: [PATCH 5/6] add --- examples/assemble_query.py | 255 +++++++++++++++++++++++++++++++++++++ 1 file changed, 255 insertions(+) create mode 100644 examples/assemble_query.py diff --git a/examples/assemble_query.py b/examples/assemble_query.py new file mode 100644 index 00000000..debbd839 --- /dev/null +++ b/examples/assemble_query.py @@ -0,0 +1,255 @@ +# TODAY +# Assemble dataset!! 6 queries +# Try multi node generation +# I believe the MCQA and Cloze setups are pulled from the same eval and are +# both roughly 1k rows, like the original wmdp-bio as a whole. + +import os +import shutil +import subprocess +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +from datasets import ( + Dataset, + concatenate_datasets, + get_dataset_config_names, + load_dataset, +) +from numpy.lib.recfunctions import ( + structured_to_unstructured, + unstructured_to_structured, +) + +from bergson import DataConfig, IndexConfig, ReduceConfig, load_gradients +from bergson.data import create_index, load_gradient_dataset, tokenize +from bergson.utils import assert_type + + +def lm_eval_harness_format(x): + """Format the MCQA as they are for models without a chat template + in LM Eval Harness, but with the answer appended after a space.'""" + + question = x["question"] + choices = x["choices"] + + prompt = ( + f"{question.strip()}\nA. {choices[0]}\nB. {choices[1]}\n" + f"C. {choices[2]}\nD. 
+    )
+
+    return {
+        "text": prompt,
+        "subset": x["subset"],
+    }
+
+
+def load_mcqa_dataset(dataset_name="EleutherAI/wmdp_bio_robust_mcqa"):
+    subsets = get_dataset_config_names(dataset_name)
+    mcqa_datasets = []
+    for subset in subsets:
+        ds = assert_type(Dataset, load_dataset(dataset_name, subset, split="robust"))
+        ds = ds.add_column("subset", [subset] * len(ds))
+        mcqa_datasets.append(ds)
+
+    return concatenate_datasets(mcqa_datasets)
+
+
+def tokenize_mcqa(
+    batch: dict,
+    *,
+    tokenizer,
+    args: DataConfig,
+    answer_marker: str = "Answer:",
+):
+    """
+    Custom tokenizer for this MCQA experiment that only keeps labels on the
+    final answer span so gradient collection ignores the rest of the prompt.
+
+    Codex wrote this.
+    """
+    # TODO integrate custom masking into tokenize if necessary
+    return tokenize(batch, args=args, tokenizer=tokenizer, apply_chat_template=False)
+
+
+def create_query_index(
+    query_ds: Dataset, run_path: str, assembled_dataset_path: str, index_dtype: np.dtype
+):
+    structured_mmap = load_gradients(run_path)
+    mmap_dtype = structured_mmap.dtype
+
+    # Copy into memory
+    gradient_tensor = torch.tensor(structured_to_unstructured(structured_mmap)).to(
+        torch.float32
+    )
+
+    print("mmap sum", gradient_tensor.sum())
+    print("mmap sum", gradient_tensor.abs().sum())
+
+    # Group mmap gradient rows by the subset they came from
+    subset_gradients = defaultdict(list)
+    for grads_row, ds_row in zip(gradient_tensor, query_ds):
+        subset_gradients[ds_row["subset"]].append(grads_row)
+
+    subset_mean_gradients = {"overall": gradient_tensor.mean(dim=0)}
+    for subset, gradients in subset_gradients.items():
+        mean_gradient = torch.stack(gradients).mean(dim=0)
+        subset_mean_gradients[subset] = mean_gradient
+
+    # Copy everything from the origin run path to the new path
+    # except gradients.bin and data.hf
+    os.makedirs(assembled_dataset_path, exist_ok=True)
+    for item in os.listdir(run_path):
+        if item not in ["gradients.bin", "data.hf"]:
+            dest = Path(assembled_dataset_path) / item
+            shutil.copy(Path(run_path) / item, dest)
+
+    if (Path(assembled_dataset_path) / "data.hf").exists():
+        if (Path(assembled_dataset_path) / "data.hf").is_file():
+            (Path(assembled_dataset_path) / "data.hf").unlink()
+        else:
+            shutil.rmtree(Path(assembled_dataset_path) / "data.hf")
+
+    # Write structured mean queries to data.hf
+    np_mean_grads = np.stack(
+        [item.numpy() for item in list(subset_mean_gradients.values())], axis=0
+    )
+    # structured_np_mean_grads = unstructured_to_structured(np_mean_grads, mmap_dtype)
+    # data = [
+    #     {
+    #         name: structured_np_mean_grads[name][i].tolist()
+    #         for name in mmap_dtype.names
+    #     }
+    #     for i in range(structured_np_mean_grads.shape[0])
+    # ]
+
+    means_dataset = Dataset.from_dict(
+        {
+            "scores": [0.0] * len(subset_mean_gradients),
+        }
+    )
+    means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf")
+
+    mean_grad_stack = torch.stack(list(subset_mean_gradients.values()))
+    first_query_grad = gradient_tensor[0].unsqueeze(0).expand_as(mean_grad_stack)
+    cosine_sims = torch.nn.functional.cosine_similarity(
+        mean_grad_stack, first_query_grad, dim=1
+    )
+
+    # Assemble grad sizes
+    grad_sizes = {}
+    for name in mmap_dtype.names:
+        field_dtype = mmap_dtype.fields[name][0]
+        subdtype = field_dtype.subdtype
+        assert subdtype is not None
+
+        _, shape = subdtype
+        grad_sizes[name] = int(np.prod(shape))
+
+    # Create and populate the index
+    index_grads = create_index(
+        str(assembled_dataset_path), len(subset_mean_gradients), grad_sizes, index_dtype
+    )
+    index_grads[:] = unstructured_to_structured(
+        np_mean_grads.astype(index_dtype), mmap_dtype
+    )
+    index_grads.flush()
+
+    load_gradient_dataset(assembled_dataset_path)
+
+    mean_grad_stack = torch.stack(list(subset_mean_gradients.values()))
+    first_query_grad = gradient_tensor[1].unsqueeze(0).expand_as(mean_grad_stack)
+    cosine_sims = torch.nn.functional.cosine_similarity(
+        mean_grad_stack, first_query_grad, dim=1
+    )
+    if torch.any(cosine_sims <= 0.09):
+        raise ValueError(
+            f"Cosine similarity between mean gradients and the first query gradient "
+            f"is not greater than 0.09. Cosine sims: {cosine_sims}"
+        )
+    else:
+        print(f"Cosine sims: {cosine_sims}")
+
+
+def main():
+    # TODO migrate to a larger model
+    model_name = "EleutherAI/deep_ignorance_pretraining_baseline_small"
+    ds_path = "runs/ds_wmdp_bio_robust_mcqa"
+    projection_dim = 256
+
+    # Spend all day on getting a setup without FSDP working.
+    index_path = f"runs/wmdp_bio_robust_mcqa_query_{projection_dim}"
+
+    mcqa_ds = assert_type(Dataset, load_mcqa_dataset())
+    mcqa_ds = mcqa_ds.map(
+        lm_eval_harness_format, remove_columns=["choices", "answer", "question"]
+    )
+    mcqa_ds.save_to_disk(ds_path)
+
+    # Add chat template following whatever the original deep ignorance project did
+    data_config = DataConfig(
+        dataset=ds_path,
+        prompt_column="text",
+    )
+
+    cfg = IndexConfig(
+        run_path=index_path,
+        # precision="fp16",
+        data=data_config,
+        # fsdp=True,
+        model=model_name,
+        projection_dim=projection_dim,
+        skip_preconditioners=False,
+        token_batch_size=800,
+        precision="fp16",
+    )
+    reduce_cfg = ReduceConfig(
+        method="mean",
+        unit_normalize=True,
+    )
+
+    cmd = [
+        "bergson",
+        "reduce",
+        index_path,
+        "--dataset",
+        cfg.data.dataset,
+        "--prompt_column",
+        cfg.data.prompt_column,
+        "--model",
+        cfg.model,
+        "--projection_dim",
+        str(cfg.projection_dim),
+        "--token_batch_size",
+        str(cfg.token_batch_size),
+        "--method",
+        reduce_cfg.method,
+        "--unit_normalize",
+        str(reduce_cfg.unit_normalize),
+        "--precision",
+        cfg.precision,
+        "--fsdp",  # Need more memory available when computing the preconditioner
+    ]
+
+    print(" ".join(cmd))
+    exit()
+
+    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
+    print(result.stdout)
+    print(result.stderr)
+
+    # Trackstar uses 2**16 with an 8B model
+    # We are collecting gradients for a ~2.7B model
+    # We are using ~2**13 I think
+    modules = set(load_gradients(cfg.run_path).dtype.names)
+    print(
+        f"Full projection dim: {len(modules) * cfg.projection_dim * cfg.projection_dim}"
+    )
+
+    exit()
+
+
+if __name__ == "__main__":
+    main()

From b222a3389f1db72175e8e2a0cf9b6d1a06f29750 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Fri, 19 Dec 2025 01:34:37 +0000
Subject: [PATCH 6/6] Comment out unused code

---
 examples/assemble_query.py | 212 ++++++++++++++++++-------------------
 1 file changed, 104 insertions(+), 108 deletions(-)

diff --git a/examples/assemble_query.py b/examples/assemble_query.py
index debbd839..150cb2dc 100644
--- a/examples/assemble_query.py
+++ b/examples/assemble_query.py
@@ -4,27 +4,17 @@
 # I believe the MCQA and Cloze setups are pulled from the same eval and are
 # both roughly 1k rows, like the original wmdp-bio as a whole.
 
-import os
-import shutil
 import subprocess
-from collections import defaultdict
-from pathlib import Path
 
-import numpy as np
-import torch
 from datasets import (
     Dataset,
     concatenate_datasets,
     get_dataset_config_names,
     load_dataset,
 )
-from numpy.lib.recfunctions import (
-    structured_to_unstructured,
-    unstructured_to_structured,
-)
 
 from bergson import DataConfig, IndexConfig, ReduceConfig, load_gradients
-from bergson.data import create_index, load_gradient_dataset, tokenize
+from bergson.data import tokenize
 from bergson.utils import assert_type
@@ -74,103 +64,109 @@ def tokenize_mcqa(
     return tokenize(batch, args=args, tokenizer=tokenizer, apply_chat_template=False)
 
 
-def create_query_index(
-    query_ds: Dataset, run_path: str, assembled_dataset_path: str, index_dtype: np.dtype
-):
-    structured_mmap = load_gradients(run_path)
-    mmap_dtype = structured_mmap.dtype
-
-    # Copy into memory
-    gradient_tensor = torch.tensor(structured_to_unstructured(structured_mmap)).to(
-        torch.float32
-    )
-
-    print("mmap sum", gradient_tensor.sum())
-    print("mmap sum", gradient_tensor.abs().sum())
-
-    # Group mmap gradient rows by the subset they came from
-    subset_gradients = defaultdict(list)
-    for grads_row, ds_row in zip(gradient_tensor, query_ds):
-        subset_gradients[ds_row["subset"]].append(grads_row)
-
-    subset_mean_gradients = {"overall": gradient_tensor.mean(dim=0)}
-    for subset, gradients in subset_gradients.items():
-        mean_gradient = torch.stack(gradients).mean(dim=0)
-        subset_mean_gradients[subset] = mean_gradient
-
-    # Copy everything from the origin run path to the new path
-    # except gradients.bin and data.hf
-    os.makedirs(assembled_dataset_path, exist_ok=True)
-    for item in os.listdir(run_path):
-        if item not in ["gradients.bin", "data.hf"]:
-            dest = Path(assembled_dataset_path) / item
-            shutil.copy(Path(run_path) / item, dest)
-
-    if (Path(assembled_dataset_path) / "data.hf").exists():
-        if (Path(assembled_dataset_path) / "data.hf").is_file():
-            (Path(assembled_dataset_path) / "data.hf").unlink()
-        else:
-            shutil.rmtree(Path(assembled_dataset_path) / "data.hf")
-
-    # Write structured mean queries to data.hf
-    np_mean_grads = np.stack(
-        [item.numpy() for item in list(subset_mean_gradients.values())], axis=0
-    )
-    # structured_np_mean_grads = unstructured_to_structured(np_mean_grads, mmap_dtype)
-    # data = [
-    #     {
-    #         name: structured_np_mean_grads[name][i].tolist()
-    #         for name in mmap_dtype.names
-    #     }
-    #     for i in range(structured_np_mean_grads.shape[0])
-    # ]
-
-    means_dataset = Dataset.from_dict(
-        {
-            "scores": [0.0] * len(subset_mean_gradients),
-        }
-    )
-    means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf")
-
-    mean_grad_stack = torch.stack(list(subset_mean_gradients.values()))
-    first_query_grad = gradient_tensor[0].unsqueeze(0).expand_as(mean_grad_stack)
-    cosine_sims = torch.nn.functional.cosine_similarity(
-        mean_grad_stack, first_query_grad, dim=1
-    )
-
-    # Assemble grad sizes
-    grad_sizes = {}
-    for name in mmap_dtype.names:
-        field_dtype = mmap_dtype.fields[name][0]
-        subdtype = field_dtype.subdtype
-        assert subdtype is not None
-
-        _, shape = subdtype
-        grad_sizes[name] = int(np.prod(shape))
-
-    # Create and populate the index
-    index_grads = create_index(
-        str(assembled_dataset_path), len(subset_mean_gradients), grad_sizes, index_dtype
-    )
-    index_grads[:] = unstructured_to_structured(
-        np_mean_grads.astype(index_dtype), mmap_dtype
-    )
-    index_grads.flush()
-
-    load_gradient_dataset(assembled_dataset_path)
-
-    mean_grad_stack = torch.stack(list(subset_mean_gradients.values()))
-    first_query_grad = gradient_tensor[1].unsqueeze(0).expand_as(mean_grad_stack)
-    cosine_sims = torch.nn.functional.cosine_similarity(
-        mean_grad_stack, first_query_grad, dim=1
-    )
-    if torch.any(cosine_sims <= 0.09):
-        raise ValueError(
-            f"Cosine similarity between mean gradients and the first query gradient "
-            f"is not greater than 0.09. Cosine sims: {cosine_sims}"
-        )
-    else:
-        print(f"Cosine sims: {cosine_sims}")
+# def create_query_index(
+#     query_ds: Dataset,
+#     run_path: str,
+#     assembled_dataset_path: str,
+#     index_dtype: np.dtype
+# ):
+#     structured_mmap = load_gradients(run_path)
+#     mmap_dtype = structured_mmap.dtype
+
+#     # Copy into memory
+#     gradient_tensor = torch.tensor(structured_to_unstructured(structured_mmap)).to(
+#         torch.float32
+#     )
+
+#     print("mmap sum", gradient_tensor.sum())
+#     print("mmap sum", gradient_tensor.abs().sum())
+
+#     # Group mmap gradient rows by the subset they came from
+#     subset_gradients = defaultdict(list)
+#     for grads_row, ds_row in zip(gradient_tensor, query_ds):
+#         subset_gradients[ds_row["subset"]].append(grads_row)
+
+#     subset_mean_gradients = {"overall": gradient_tensor.mean(dim=0)}
+#     for subset, gradients in subset_gradients.items():
+#         mean_gradient = torch.stack(gradients).mean(dim=0)
+#         subset_mean_gradients[subset] = mean_gradient
+
+#     # Copy everything from the origin run path to the new path
+#     # except gradients.bin and data.hf
+#     os.makedirs(assembled_dataset_path, exist_ok=True)
+#     for item in os.listdir(run_path):
+#         if item not in ["gradients.bin", "data.hf"]:
+#             dest = Path(assembled_dataset_path) / item
+#             shutil.copy(Path(run_path) / item, dest)
+
+#     if (Path(assembled_dataset_path) / "data.hf").exists():
+#         if (Path(assembled_dataset_path) / "data.hf").is_file():
+#             (Path(assembled_dataset_path) / "data.hf").unlink()
+#         else:
+#             shutil.rmtree(Path(assembled_dataset_path) / "data.hf")
+
+#     # Write structured mean queries to data.hf
+#     np_mean_grads = np.stack(
+#         [item.numpy() for item in list(subset_mean_gradients.values())], axis=0
+#     )
+#     # structured_np_mean_grads = unstructured_to_structured(np_mean_grads, mmap_dtype)
+#     # data = [
+#     #     {
+#     #         name: structured_np_mean_grads[name][i].tolist()
+#     #         for name in mmap_dtype.names
+#     #     }
+#     #     for i in range(structured_np_mean_grads.shape[0])
+#     # ]
+
+#     means_dataset = Dataset.from_dict(
+#         {
+#             "scores": [0.0] * len(subset_mean_gradients),
+#         }
+#     )
+#     means_dataset.save_to_disk(Path(assembled_dataset_path) / "data.hf")
+
+#     mean_grad_stack = torch.stack(list(subset_mean_gradients.values()))
+#     first_query_grad = gradient_tensor[0].unsqueeze(0).expand_as(mean_grad_stack)
+#     cosine_sims = torch.nn.functional.cosine_similarity(
+#         mean_grad_stack, first_query_grad, dim=1
+#     )
+
+#     # Assemble grad sizes
+#     grad_sizes = {}
+#     for name in mmap_dtype.names:
+#         field_dtype = mmap_dtype.fields[name][0]
+#         subdtype = field_dtype.subdtype
+#         assert subdtype is not None
+
+#         _, shape = subdtype
+#         grad_sizes[name] = int(np.prod(shape))
+
+#     # Create and populate the index
+#     index_grads = create_index(
+#         str(assembled_dataset_path),
+#         len(subset_mean_gradients),
+#         grad_sizes,
+#         index_dtype
+#     )
+#     index_grads[:] = unstructured_to_structured(
+#         np_mean_grads.astype(index_dtype), mmap_dtype
+#     )
+#     index_grads.flush()
+
+#     load_gradient_dataset(assembled_dataset_path)
+
+#     mean_grad_stack = torch.stack(list(subset_mean_gradients.values()))
+#     first_query_grad = gradient_tensor[1].unsqueeze(0).expand_as(mean_grad_stack)
+#     cosine_sims = torch.nn.functional.cosine_similarity(
+#         mean_grad_stack, first_query_grad, dim=1
+#     )
+#     if torch.any(cosine_sims <= 0.09):
+#         raise ValueError(
+#             f"Cosine similarity between mean gradients and the first query gradient "
+#             f"is not greater than 0.09. Cosine sims: {cosine_sims}"
+#         )
+#     else:
+#         print(f"Cosine sims: {cosine_sims}")
 
 
 def main():