diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 14a270defd..fe691be4e3 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -292,6 +292,7 @@ def __init__(self, config, vecenv, policy, logger=None): self.vecenv = vecenv self.epoch = 0 self.global_step = 0 + self.agent_steps = 0 self.last_log_step = 0 self.last_log_time = time.time() self.start_time = time.time() @@ -548,7 +549,7 @@ def train(self): policy=self.uncompiled_policy, env_name=self.config["env"], logger=self.logger, - global_step=self.global_step, + global_step=self.agent_steps, ) return logs @@ -814,6 +815,7 @@ def mean_and_log(self): device = config["device"] agent_steps = int(dist_sum(self.global_step, device)) + self.agent_steps = agent_steps logs = { "SPS": dist_sum(self.sps, device), "agent_steps": agent_steps, @@ -1437,17 +1439,22 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop model.forward_eval = policy.forward_eval policy = model.to(local_rank) - if args["neptune"]: - logger = NeptuneLogger(args) - elif args["wandb"]: - logger = WandbLogger(args) - elif args["tb"]: - date_time = datetime.now().strftime("%Y%m%d-%H%M%S") - experiment_dir = os.path.join(args["train"]["data_dir"], rf"{env_name}_" + date_time) - logger = TensorBoardLogger( - run_id=date_time, - experiment_dir=experiment_dir, - ) + # Under DDP only rank 0 owns the run logger; other ranks keep logger=None, + # which PuffeRL wraps in a NoLogger. Without this gate every rank calls + # wandb.init()/NeptuneLogger and you get world_size duplicate runs. + is_rank0 = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0 + if is_rank0: + if args["neptune"]: + logger = NeptuneLogger(args) + elif args["wandb"]: + logger = WandbLogger(args) + elif args["tb"]: + date_time = datetime.now().strftime("%Y%m%d-%H%M%S") + experiment_dir = os.path.join(args["train"]["data_dir"], rf"{env_name}_" + date_time) + logger = TensorBoardLogger( + run_id=date_time, + experiment_dir=experiment_dir, + ) train_config = dict(**args["train"], env=env_name, eval=args.get("eval", {})) pufferl = PuffeRL(train_config, vecenv, policy, logger) @@ -1546,7 +1553,7 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop policy=pufferl.uncompiled_policy, env_name=pufferl.config["env"], logger=pufferl.logger, - global_step=pufferl.global_step, + global_step=pufferl.agent_steps, force=True, )