37 changes: 19 additions & 18 deletions bergson/__main__.py

@@ -1,3 +1,4 @@
+import os
 import shutil
 from dataclasses import dataclass
 from pathlib import Path
@@ -12,15 +13,20 @@
 from .score.score import score_dataset


-def validate_run_path(run_path: Path):
+def validate_run_path(run_path: Path, overwrite: bool):
     """Validate the run path."""
-    if run_path.exists():
-        print(f"Run path {run_path} already exists.")
-        response = input("Do you want to overwrite the existing run path? (y/n): ")
-        if response.lower() != "y":
-            exit()
-        else:
-            shutil.rmtree(run_path)
+    start_rank = int(os.environ.get("START_RANK", 0))
+    rank = start_rank + int(os.environ.get("RANK", 0))
+
+    if rank != 0 or not run_path.exists():
+        return
+
+    if overwrite:
+        shutil.rmtree(run_path)
+    else:
+        raise FileExistsError(
+            f"Run path {run_path} already exists. Use --overwrite to overwrite it."
+        )


 @dataclass
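
The rewritten validate_run_path drops the interactive y/n prompt, which would hang a multi-process launch waiting on stdin, and instead lets only the global rank-0 process touch the directory. A minimal standalone sketch of the same pattern, assuming a torchrun-style launcher that sets RANK per process and a per-node START_RANK offset; the helper name below is hypothetical, not part of bergson:

import os
import shutil
from pathlib import Path


def prepare_run_dir(path: Path, overwrite: bool) -> None:
    # Hypothetical helper for illustration.
    # Global rank = per-node offset + rank within this launch;
    # both default to 0, so single-process runs behave the same.
    rank = int(os.environ.get("START_RANK", 0)) + int(os.environ.get("RANK", 0))

    # Only rank 0 mutates the filesystem; every other rank is a no-op.
    if rank != 0 or not path.exists():
        return

    if overwrite:
        shutil.rmtree(path)  # non-interactive: nothing blocks on stdin
    else:
        raise FileExistsError(f"{path} already exists; pass overwrite=True.")

Raising FileExistsError instead of prompting also makes the failure visible in batch logs rather than silently stalling the job.
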
@@ -36,8 +42,8 @@ def execute(self):

         run_path = Path(self.index_cfg.run_path)
         partial_run_path = Path(self.index_cfg.partial_run_path)
-        validate_run_path(run_path)
-        validate_run_path(partial_run_path)
+        validate_run_path(run_path, self.index_cfg.overwrite)
+        validate_run_path(partial_run_path, self.index_cfg.overwrite)

         build(self.index_cfg)

@@ -52,16 +58,11 @@ class Reduce:

     def execute(self):
         """Reduce a gradient index."""
-        if self.index_cfg.projection_dim != 0:
-            print(
-                "Warning: projection_dim is not 0. "
-                "Compressed gradients will be reduced."
-            )
-
         run_path = Path(self.index_cfg.run_path)
         partial_run_path = Path(self.index_cfg.partial_run_path)
-        validate_run_path(run_path)
-        validate_run_path(partial_run_path)
-
+        validate_run_path(run_path, self.index_cfg.overwrite)
+        validate_run_path(partial_run_path, self.index_cfg.overwrite)

         reduce(self.index_cfg, self.reduce_cfg)

13 changes: 8 additions & 5 deletions bergson/build.py

@@ -21,6 +21,7 @@

 def build_worker(
     rank: int,
+    local_rank: int,
     world_size: int,
     cfg: IndexConfig,
     ds: Dataset | IterableDataset,
@@ -39,7 +40,7 @@ def build_worker(
     ds : Dataset | IterableDataset
         The entire dataset to be indexed. A subset is assigned to each worker.
     """
-    torch.cuda.set_device(rank)
+    torch.cuda.set_device(local_rank)

     # These should be set by the main process
     if world_size > 1:
@@ -49,14 +50,14 @@
         dist.init_process_group(
             "nccl",
             init_method=f"tcp://{addr}:{port}",
-            device_id=torch.device(f"cuda:{rank}"),
+            device_id=torch.device(f"cuda:{local_rank}"),
             rank=rank,
             timeout=timedelta(hours=1),
             world_size=world_size,
         )

-    model, target_modules = setup_model_and_peft(cfg, rank)
-    processor = create_processor(cfg, rank)
+    model, target_modules = setup_model_and_peft(cfg, local_rank)
+    processor = create_processor(cfg, local_rank, rank)

     attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}
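
Threading both ranks through matters because they answer different questions in a multi-node job: the global rank identifies a process across the whole cluster (for the process group and for deciding who writes files), while the local rank indexes the GPU on the process's own node, since CUDA devices are numbered per node. On a single node the two coincide; they diverge as soon as a job spans more than one node. A short sketch of the distinction, assuming a torchrun-style environment that sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT:

import os

import torch
import torch.distributed as dist

# Assumed to be set by the launcher (e.g. torchrun):
rank = int(os.environ["RANK"])              # global: unique across all nodes
local_rank = int(os.environ["LOCAL_RANK"])  # per node: restarts at 0 on each node
world_size = int(os.environ["WORLD_SIZE"])

# GPUs are numbered per node, so device selection uses the local rank...
torch.cuda.set_device(local_rank)

# ...while the process group registers the process under its global rank.
# MASTER_ADDR/MASTER_PORT are read from the environment by default.
dist.init_process_group("nccl", rank=rank, world_size=world_size)
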

@@ -118,4 +119,6 @@ def build(index_cfg: IndexConfig):

     launch_distributed_run("build", build_worker, [index_cfg, ds])

-    shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
+    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
+    if rank == 0:
+        shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
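
Gating the final move on rank 0 means the promotion from the partial directory to the run directory happens exactly once instead of racing across worker processes. A sketch of this finalization pattern in isolation; the dist.barrier() is an assumption added here to show how one would wait for all workers to finish writing before promoting the directory, not something the diff itself contains, and the helper name is hypothetical:

import os
import shutil

import torch.distributed as dist


def finalize_run(partial_path: str, run_path: str) -> None:
    # Hypothetical helper for illustration.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # assumed sync point: all workers done writing
    rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    if rank == 0:
        shutil.move(partial_path, run_path)  # promote exactly once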