From 5c925b9163413628134c70a410cbdef7570c4bb7 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Mon, 7 Apr 2025 15:47:13 +0000 Subject: [PATCH 01/26] early buffer prototype --- c_advantage.pyx | 48 +++++++++------ clean_pufferl.py | 139 ++++++++++++++++++++++++++------------------ demo.py | 48 +-------------- pufferlib/models.py | 24 ++++---- 4 files changed, 126 insertions(+), 133 deletions(-) diff --git a/c_advantage.pyx b/c_advantage.pyx index 4393899e7..a07886727 100644 --- a/c_advantage.pyx +++ b/c_advantage.pyx @@ -90,25 +90,35 @@ def fast_rewards_and_masks(float[:, :] reward_block, float[:, :] reward_mask, memcpy(&reward_block[i, 0], &rewards[i+1], h * sizeof(float)) -def compute_gae(cnp.ndarray dones, cnp.ndarray values, - cnp.ndarray rewards, float gamma, float gae_lambda): +def compute_gae(cnp.ndarray dones, float[:, :] values, + float[:, :] rewards, int[:] stored_idxs, + float gamma, float gae_lambda): '''Fast Cython implementation of Generalized Advantage Estimation (GAE)''' - cdef int num_steps = len(rewards) - cdef cnp.ndarray advantages = np.zeros(num_steps, dtype=np.float32) - cdef float[:] c_advantages = advantages - cdef float[:] c_dones = dones - cdef float[:] c_values = values - cdef float[:] c_rewards = rewards - - cdef float lastgaelam = 0 - cdef float nextnonterminal, delta - cdef int t, t_cur, t_next - for t in range(num_steps-1): - t_cur = num_steps - 2 - t - t_next = num_steps - 1 - t - nextnonterminal = 1.0 - c_dones[t_next] - delta = c_rewards[t_next] + gamma * c_values[t_next] * nextnonterminal - c_values[t_cur] - lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam - c_advantages[t_cur] = lastgaelam + cdef: + float[:, :] c_dones = dones + int num_rows = dones.shape[0] + int horizon = dones.shape[1] + float lastgaelam = 0 + float nextnonterminal, delta + int t, t_cur, t_next + cnp.ndarray advantages = np.zeros((num_rows, horizon), dtype=np.float32) + cnp.ndarray ep_adv = np.zeros(np.max(stored_idxs)+1, dtype=np.float32) + + cdef: + float[:, :] c_advantages = advantages + float[:] c_ep_adv = ep_adv + int agent_id + + for row in range(num_rows-1, -1, -1): + agent_id = stored_idxs[row] + lastgaelam = ep_adv[agent_id] + for t in range(horizon-2, -1, -1): + t_next = t + 1 + nextnonterminal = 1.0 - c_dones[row, t_next] + delta = rewards[row, t_next] + gamma*values[row, t_next]*nextnonterminal - values[row, t] + lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam + c_advantages[row, t] = lastgaelam + + c_ep_adv[agent_id] = lastgaelam return advantages diff --git a/clean_pufferl.py b/clean_pufferl.py index 211702ef4..8e57e9ed5 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -31,17 +31,19 @@ # Compile the CUDA kernel +''' cuda_module = load( name='advantage_kernel', sources=['pufferlib.cu'], verbose=True ) +''' def compute_advantages( - reward_block: torch.Tensor, # [num_steps, horizon] - reward_mask: torch.Tensor, # [num_steps, horizon] - values_mean: torch.Tensor, # [num_steps, horizon] - values_std: torch.Tensor, # [num_steps, horizon] + reward_block: torch.Tensor, # [num_steps, horizon] + reward_mask: torch.Tensor, # [num_steps, horizon] + values_mean: torch.Tensor, # [num_steps, horizon] + values_std: torch.Tensor, # [num_steps, horizon] buf: torch.Tensor, # [num_steps, horizon] dones: torch.Tensor, # [num_steps] rewards: torch.Tensor, # [num_steps] @@ -103,10 +105,9 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): atn_dtype = vecenv.single_action_space.dtype total_agents = vecenv.num_agents - lstm = policy.recurrent if hasattr(policy, 'recurrent') else None experience = Experience(config.batch_size, config.bptt_horizon, config.minibatch_size, config.max_minibatch_size, policy.hidden_size, obs_shape, obs_dtype, - atn_shape, atn_dtype, config.cpu_offload, config.device, lstm, total_agents, + atn_shape, atn_dtype, config.cpu_offload, config.device, policy, total_agents, use_e3b=config.use_e3b, e3b_coef=config.e3b_coef, e3b_lambda=config.e3b_lambda, use_diayn=config.use_diayn, diayn_archive=config.diayn_archive, diayn_coef=config.diayn_coef, use_p3o=config.use_p3o, p3o_horizon=config.p3o_horizon @@ -308,7 +309,7 @@ def evaluate(data): with profile.eval_copy: o = o if config.cpu_offload else o_device - actions = experience.store(state, o, o_device, value, action, logprob, r, d, env_id, mask) + actions = experience.store(state, o, o_device, value, action, logprob, r, d, gpu_env_id, mask) if config.device == 'cuda': torch.cuda.synchronize() @@ -352,9 +353,11 @@ def train(data): losses = data.losses with profile.train_copy: - idxs = experience.sort_training_data() - dones = experience.dones[idxs] - rewards = experience.rewards[idxs] + #idxs = experience.sort_training_data() + #dones = experience.dones[idxs] + #rewards = experience.rewards[idxs] + dones = experience.dones + rewards = experience.rewards with profile.train_misc: if config.use_p3o: @@ -404,12 +407,14 @@ def train(data): experience.flatten_batch(advantages, reward_block, mask_block) torch.cuda.synchronize() else: - values_np = experience.values[idxs].to('cpu', non_blocking=True).numpy() + #values_np = experience.values[idxs].to('cpu', non_blocking=True).numpy() + values_np = experience.values.to('cpu', non_blocking=True).numpy() dones_np = dones.to('cpu', non_blocking=True).numpy() rewards_np = rewards.to('cpu', non_blocking=True).numpy() + stored_idxs = experience.stored_indices.to('cpu', non_blocking=True).numpy() torch.cuda.synchronize() advantages_np = compute_gae(dones_np, values_np, - rewards_np, config.gamma, config.gae_lambda) + rewards_np, stored_idxs, config.gamma, config.gae_lambda) experience.flatten_batch(advantages_np) # Optimizing the policy and value network @@ -444,14 +449,14 @@ def train(data): rew_block = experience.b_reward_block[mb] mask_block = experience.b_mask_block[mb] else: - val = experience.b_values[mb] + val = experience.b_values[mb].flatten() if config.device == 'cuda': torch.cuda.synchronize() with data.amp_context: with profile.train_forward: - if not hasattr(data.policy, 'recurrent'): + if not isinstance(data.policy, torch.nn.LSTM): obs = obs.reshape(-1, *data.vecenv.single_observation_space.shape) logits, newvalue = data.policy.forward_train(obs, state) @@ -517,6 +522,7 @@ def train(data): #v_loss = v_loss[mask_block.bool()].mean() elif config.clip_vloss: newvalue = newvalue.flatten() + ret = ret.flatten() v_loss_unclipped = (newvalue - ret) ** 2 v_clipped = val + torch.clamp( newvalue - val, @@ -593,8 +599,8 @@ def train(data): y_pred = experience.values_mean y_true = experience.reward_block else: - y_pred = experience.values - y_true = experience.returns + y_pred = experience.values.flatten() + y_true = experience.b_returns.flatten() var_y = y_true.var() explained_var = torch.nan if var_y == 0 else 1 - (y_true - y_pred).var() / var_y @@ -806,24 +812,35 @@ class Experience: '''Flat tensor storage and array views for faster indexing''' def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, hidden_size, obs_shape, obs_dtype, atn_shape, atn_dtype, cpu_offload=False, - device='cuda', lstm=None, lstm_total_agents=0, + device='cuda', policy=None, lstm_total_agents=0, use_e3b=False, e3b_coef=0.1, e3b_lambda=10.0, use_diayn=False, diayn_archive=128, diayn_coef=0.1, use_p3o=False, p3o_horizon=32): if minibatch_size is None: minibatch_size = batch_size + num_rows = batch_size // bptt_horizon + self.num_rows = num_rows + obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_dtype] atn_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[atn_dtype] pin = device == 'cuda' and cpu_offload obs_device = device if not pin else 'cpu' - self.obs=torch.zeros(batch_size, *obs_shape, dtype=obs_dtype, + self.obs_shape = obs_shape + self.atn_shape = atn_shape + self.obs=torch.zeros(num_rows, bptt_horizon, *obs_shape, dtype=obs_dtype, pin_memory=pin, device=device if not pin else 'cpu') - self.actions=torch.zeros(batch_size, *atn_shape, dtype=atn_dtype, device=device) - self.logprobs=torch.zeros(batch_size, device=device) - self.rewards=torch.zeros(batch_size, device=device) - self.dones=torch.zeros(batch_size, device=device) - self.truncateds=torch.zeros(batch_size, device=device) + self.actions=torch.zeros(num_rows, bptt_horizon, *atn_shape, + dtype=atn_dtype, device=device) + self.logprobs=torch.zeros(num_rows, bptt_horizon, device=device) + self.rewards=torch.zeros(num_rows, bptt_horizon, device=device) + self.dones=torch.zeros(num_rows, bptt_horizon, device=device) + self.truncateds=torch.zeros(num_rows, bptt_horizon, device=device) + + self.ep_lengths = torch.zeros(lstm_total_agents, device=device, dtype=torch.int32) + self.ep_indices = torch.arange(lstm_total_agents, device=device, dtype=torch.int32) + self.stored_indices = torch.zeros(num_rows, device=device, dtype=torch.int32) + self.free_idx = 0 self.use_e3b = use_e3b if use_e3b: @@ -851,15 +868,15 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, self.bounds = torch.zeros(batch_size, dtype=torch.int32, device=device) self.vstd_max = 1.0 else: - self.values = torch.zeros(batch_size, device=device) + self.values = torch.zeros(num_rows, bptt_horizon, device=device) self.sort_keys = np.zeros((batch_size, 3), dtype=np.int32) self.sort_keys[:, 0] = np.arange(batch_size) self.lstm_h = self.lstm_c = None - if lstm is not None: + if isinstance(policy, torch.nn.LSTM): assert lstm_total_agents > 0 - shape = (lstm.num_layers, lstm_total_agents, lstm.hidden_size) + shape = (policy.num_layers, lstm_total_agents, policy.hidden_size) self.lstm_h = torch.zeros(shape).to(device) self.lstm_c = torch.zeros(shape).to(device) @@ -884,7 +901,7 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, @property def full(self): - return self.ptr >= self.batch_size + return self.free_idx >= self.num_rows def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, env_id, mask): # Mask learner and Ensure indices do not exceed batch size @@ -895,16 +912,36 @@ def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, e dst = slice(ptr, end) # Zero-copy indexing for contiguous env_id + ''' if num_indices == mask.size and isinstance(env_id, slice): gpu_inds = cpu_inds = slice(0, min(self.batch_size - ptr, num_indices)) else: cpu_inds = indices[:self.batch_size - ptr] gpu_inds = torch.as_tensor(cpu_inds).to(self.obs.device, non_blocking=True) + ''' + + batch_rows = self.ep_indices[env_id] + l = self.ep_lengths[env_id] if self.obs.device.type == 'cuda': - self.obs[dst] = gpu_obs[gpu_inds] + self.obs[batch_rows, l] = gpu_obs else: - self.obs[dst] = cpu_obs[cpu_inds] + self.obs[batch_rows, l] = cpu_obs + + if isinstance(env_id, slice): + self.stored_indices[batch_rows] = torch.arange(env_id.stop - env_id.start, device=self.device).int() + env_id.start + else: + self.stored_indices[batch_rows] = env_id + + l += 1 + self.ep_lengths[env_id] = l + full = l >= self.bptt_horizon + num_full = full.sum() + if num_full > 0: + self.ep_lengths[full] = 0 + self.ep_indices[full] = self.free_idx + torch.arange(num_full, device=self.device).int() + self.free_idx += num_full + if self.use_diayn: self.diayn_batch[dst] = state.diayn_z_idxs[gpu_inds] @@ -913,22 +950,13 @@ def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, e self.values_mean[dst] = value.mean[gpu_inds] self.values_std[dst] = value.std[gpu_inds] else: - self.values[dst] = value[gpu_inds].flatten() + self.values[batch_rows, l] = value.flatten() - self.actions[dst] = action[gpu_inds] - self.logprobs[dst] = logprob[gpu_inds] - self.rewards[dst] = reward[cpu_inds].to(self.rewards.device) # ??? - self.dones[dst] = done[cpu_inds].to(self.dones.device) # ??? - - if isinstance(env_id, slice): - self.sort_keys[dst, 1] = np.arange(env_id.start, env_id.stop, dtype=np.int32) - else: - self.sort_keys[dst, 1] = env_id[cpu_inds] - - self.sort_keys[dst, 2] = self.step - self.ptr = end + self.actions[batch_rows, l] = action + self.logprobs[batch_rows, l] = logprob + self.rewards[batch_rows, l] = reward.to(self.rewards.device) # ??? + self.dones[batch_rows, l] = done.float().to(self.dones.device) # ??? self.step += 1 - return action.cpu().numpy() def sort_training_data(self): @@ -943,16 +971,18 @@ def sort_training_data(self): return idxs def flatten_batch(self, advantages_np, reward_block=None, mask_block=None): + self.free_idx = 0 advantages = torch.as_tensor(advantages_np).to(self.device, non_blocking=True) self.b_advantages = advantages.reshape( - self.minibatch_rows, self.num_minibatches, self.bptt_horizon - ).transpose(0, 1).reshape(self.num_minibatches, self.minibatch_size) - - b_idxs, b_flat = self.b_idxs, self.b_idxs_flat - self.b_actions = self.actions.to(self.device, non_blocking=True)[b_idxs].contiguous() - self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)[b_idxs] - self.b_dones = self.dones.to(self.device, non_blocking=True)[b_idxs] - self.b_obs = self.obs[self.b_idxs_obs] + self.num_minibatches, self.minibatch_rows, self.bptt_horizon) + self.b_actions = self.actions.to(self.device, non_blocking=True).reshape( + self.num_minibatches, self.minibatch_rows, self.bptt_horizon, -1) + self.b_logprobs = self.logprobs.to(self.device, non_blocking=True).reshape( + self.num_minibatches, self.minibatch_rows, self.bptt_horizon) + self.b_dones = self.dones.to(self.device, non_blocking=True).reshape( + self.num_minibatches, self.minibatch_rows, self.bptt_horizon) + self.b_obs = self.obs.to(self.device, non_blocking=True).reshape( + self.num_minibatches, self.minibatch_rows, self.bptt_horizon, *self.obs_shape) if self.use_p3o: self.reward_block = torch.as_tensor(reward_block).to(self.device) @@ -971,10 +1001,9 @@ def flatten_batch(self, advantages_np, reward_block=None, mask_block=None): self.minibatch_rows, self.num_minibatches, self.bptt_horizon, self.p3o_horizon ).transpose(0, 1).reshape(self.num_minibatches, self.minibatch_size, self.p3o_horizon) else: - self.b_values = self.values.to(self.device, non_blocking=True)[b_flat] - self.returns = advantages + self.values # Check sorting of values here - self.b_returns = self.b_advantages + self.b_values # Check sorting of values here - + self.b_values = self.values.to(self.device, non_blocking=True).reshape( + self.num_minibatches, self.minibatch_rows, self.bptt_horizon) + self.b_returns = self.b_advantages + self.b_values if self.use_diayn: self.b_diayn_z_idxs = self.diayn_batch.to(self.device, non_blocking=True)[b_flat] self.b_diayn_z = self.diayn_archive[self.b_diayn_z_idxs] diff --git a/demo.py b/demo.py index 4e39d52a4..6c374f3d7 100644 --- a/demo.py +++ b/demo.py @@ -142,16 +142,10 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10 from torch.nn.parallel import DistributedDataParallel as DDP orig_policy = policy policy = DDP(policy, device_ids=[args['rank']]) + # TODO: Test this? isinstance? if hasattr(orig_policy, 'lstm'): policy.lstm = orig_policy.lstm - ''' - if env_name == 'moba': - import torch - os.makedirs('moba_elo', exist_ok=True) - torch.save(policy, os.path.join('moba_elo', 'model_random.pt')) - ''' - neptune = None wandb = None if args['neptune']: @@ -170,38 +164,9 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10 costs = [] target_key = f'environment/{target_metric}' - ''' - from torch.profiler import profile, record_function, ProfilerActivity - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU] - from torch.profiler import schedule - prof_schedule = schedule( - skip_first=10, - wait=5, - warmup=1, - active=3, - repeat=2 - ) - - sort_by_keyword = "self_" + args['train']['device'] + "_time_total" - - def trace_handler(p): - output = p.key_averages().table(sort_by=sort_by_keyword, row_limit=10) - print(output) - p.export_chrome_trace("trace/trace_" + str(p.step_num) + ".json") - - with profile( - activities=activities, - schedule=torch.profiler.schedule( - wait=1, - warmup=1, - active=2), - on_trace_ready=trace_handler - ) as p: - ''' while data.global_step < train_config.total_timesteps: clean_pufferl.evaluate(data) logs = clean_pufferl.train(data) - #p.step() if logs is not None and target_key in logs: timesteps.append(logs['agent_steps']) scores.append(logs[target_key]) @@ -239,17 +204,6 @@ def downsample_linear(arr, m): elif args['wandb']: wandb.log({'score': score, 'cost': cost}) - ''' - if env_name == 'moba': - exp_n = len(elos) - model_name = f'model_{exp_n}.pt' - torch.save(policy, os.path.join('moba_elo', model_name)) - from evaluate_elos import calc_elo - elos = calc_elo(model_name, 'moba_elo', elos) - stats['elo'] = elos[model_name] - if wandb is not None: - wandb.log({'environment/elo': elos[model_name]}) - ''' clean_pufferl.close(data) return scores, costs, timesteps, elos, vecenv diff --git a/pufferlib/models.py b/pufferlib/models.py index d1c610f96..e95bf19a4 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -125,33 +125,33 @@ def decode_actions(self, hidden): return logits, values -class LSTMWrapper(nn.Module): +class LSTMWrapper(nn.LSTM): def __init__(self, env, policy, input_size=128, hidden_size=128): '''Wraps your policy with an LSTM without letting you shoot yourself in the foot with bad transpose and shape operations. This saves much pain. Requires that your policy define encode_observations and decode_actions. See the Default policy for an example.''' - super().__init__() + super().__init__(input_size, hidden_size) self.obs_shape = env.single_observation_space.shape self.policy = policy self.input_size = input_size self.hidden_size = hidden_size - self.recurrent = nn.LSTM(input_size, hidden_size) - self.recurrent_cell = torch.nn.LSTMCell(input_size, hidden_size) - self.recurrent_cell.weight_ih = self.recurrent.weight_ih_l0 - self.recurrent_cell.weight_hh = self.recurrent.weight_hh_l0 - self.recurrent_cell.bias_ih = self.recurrent.bias_ih_l0 - self.recurrent_cell.bias_hh = self.recurrent.bias_hh_l0 - self.is_continuous = self.policy.is_continuous - for name, param in self.recurrent.named_parameters(): + for name, param in self.named_parameters(): if "bias" in name: nn.init.constant_(param, 0) elif "weight" in name: nn.init.orthogonal_(param, 1.0) + self.cell = torch.nn.LSTMCell(input_size, hidden_size) + self.cell.weight_ih = self.weight_ih_l0 + self.cell.weight_hh = self.weight_hh_l0 + self.cell.bias_ih = self.bias_ih_l0 + self.cell.bias_hh = self.bias_hh_l0 + + def forward(self, observations, state): '''Forward function for inference. 3x faster than using LSTM directly''' hidden = self.policy.encode_observations(observations) @@ -165,7 +165,7 @@ def forward(self, observations, state): else: lstm_state = None - hidden, c = self.recurrent_cell(hidden, lstm_state) + hidden, c = self.cell(hidden, lstm_state) state.hidden = hidden state.lstm_h = hidden state.lstm_c = c @@ -203,7 +203,7 @@ def forward_train(self, observations, state): hidden = hidden.reshape(B, TT, self.input_size) hidden = hidden.transpose(0, 1) - hidden, (lstm_h, lstm_c)= self.recurrent(hidden, lstm_state) + hidden, (lstm_h, lstm_c) = super().forward(hidden, lstm_state) hidden = hidden.transpose(0, 1) hidden = hidden.reshape(B*TT, self.hidden_size) From b569d497779bca3e8c3a2af2e66d5e1132535da9 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 8 Apr 2025 17:22:59 +0000 Subject: [PATCH 02/26] Initial messy prio exp --- c_advantage.pyx | 13 +- clean_pufferl.py | 399 ++++++++++++++++++++------------------ config/default.ini | 1 + config/ocean/breakout.ini | 39 ++-- demo.py | 4 +- 5 files changed, 242 insertions(+), 214 deletions(-) diff --git a/c_advantage.pyx b/c_advantage.pyx index a07886727..7b2b071fb 100644 --- a/c_advantage.pyx +++ b/c_advantage.pyx @@ -91,8 +91,7 @@ def fast_rewards_and_masks(float[:, :] reward_block, float[:, :] reward_mask, memcpy(&reward_block[i, 0], &rewards[i+1], h * sizeof(float)) def compute_gae(cnp.ndarray dones, float[:, :] values, - float[:, :] rewards, int[:] stored_idxs, - float gamma, float gae_lambda): + float[:, :] rewards, float gamma, float gae_lambda): '''Fast Cython implementation of Generalized Advantage Estimation (GAE)''' cdef: float[:, :] c_dones = dones @@ -102,16 +101,10 @@ def compute_gae(cnp.ndarray dones, float[:, :] values, float nextnonterminal, delta int t, t_cur, t_next cnp.ndarray advantages = np.zeros((num_rows, horizon), dtype=np.float32) - cnp.ndarray ep_adv = np.zeros(np.max(stored_idxs)+1, dtype=np.float32) - - cdef: float[:, :] c_advantages = advantages - float[:] c_ep_adv = ep_adv - int agent_id for row in range(num_rows-1, -1, -1): - agent_id = stored_idxs[row] - lastgaelam = ep_adv[agent_id] + lastgaelam = 0 for t in range(horizon-2, -1, -1): t_next = t + 1 nextnonterminal = 1.0 - c_dones[row, t_next] @@ -119,6 +112,4 @@ def compute_gae(cnp.ndarray dones, float[:, :] values, lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam c_advantages[row, t] = lastgaelam - c_ep_adv[agent_id] = lastgaelam - return advantages diff --git a/clean_pufferl.py b/clean_pufferl.py index 8e57e9ed5..bdfc8b6be 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -107,7 +107,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): experience = Experience(config.batch_size, config.bptt_horizon, config.minibatch_size, config.max_minibatch_size, policy.hidden_size, obs_shape, obs_dtype, - atn_shape, atn_dtype, config.cpu_offload, config.device, policy, total_agents, + atn_shape, atn_dtype, config.cpu_offload, config.device, policy, total_agents, config.replay_factor, use_e3b=config.use_e3b, e3b_coef=config.e3b_coef, e3b_lambda=config.e3b_lambda, use_diayn=config.use_diayn, diayn_archive=config.diayn_archive, diayn_coef=config.diayn_coef, use_p3o=config.use_p3o, p3o_horizon=config.p3o_horizon @@ -359,6 +359,10 @@ def train(data): dones = experience.dones rewards = experience.rewards + # TODO: Beter place for this + experience.free_idx = 0 + experience.ep_lengths.zero_() + with profile.train_misc: if config.use_p3o: reward_block = experience.reward_block @@ -406,16 +410,6 @@ def train(data): experience.flatten_batch(advantages, reward_block, mask_block) torch.cuda.synchronize() - else: - #values_np = experience.values[idxs].to('cpu', non_blocking=True).numpy() - values_np = experience.values.to('cpu', non_blocking=True).numpy() - dones_np = dones.to('cpu', non_blocking=True).numpy() - rewards_np = rewards.to('cpu', non_blocking=True).numpy() - stored_idxs = experience.stored_indices.to('cpu', non_blocking=True).numpy() - torch.cuda.synchronize() - advantages_np = compute_gae(dones_np, values_np, - rewards_np, stored_idxs, config.gamma, config.gae_lambda) - experience.flatten_batch(advantages_np) # Optimizing the policy and value network total_minibatches = experience.num_minibatches * config.update_epochs @@ -424,168 +418,191 @@ def train(data): cross_entropy = torch.nn.CrossEntropyLoss() accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) for epoch in range(config.update_epochs): - lstm_h = None - lstm_c = None - for mb in range(experience.num_minibatches): - with profile.train_misc: - state = pufferlib.namespace( - action=experience.b_actions[mb], - lstm_h=lstm_h, - lstm_c=lstm_c, - ) - obs = experience.b_obs[mb] - obs = obs.to(config.device) - atn = experience.b_actions[mb] - log_probs = experience.b_logprobs[mb] - adv = experience.b_advantages[mb] - ret = experience.b_returns[mb] + values_np = experience.values.to('cpu', non_blocking=True).numpy() + dones_np = dones.to('cpu', non_blocking=True).numpy() + rewards_np = rewards.to('cpu', non_blocking=True).numpy() + stored_idxs = experience.stored_indices.to('cpu', non_blocking=True).numpy() + torch.cuda.synchronize() + advantages_np = compute_gae(dones_np, values_np, rewards_np, config.gamma, config.gae_lambda) + advantages = torch.as_tensor(advantages_np).to(config.device, non_blocking=True) + n_samples = config.minibatch_size // config.bptt_horizon + exp = experience.sample(advantages, n_samples) + + obs = exp.obs + atn = exp.actions + log_probs = exp.logprobs + adv = exp.advantages + ret = exp.returns + + with profile.train_misc: + state = pufferlib.namespace( + action=atn, + lstm_h=None, + lstm_c=None, + ) + if config.use_diayn: + z_idxs = experience.b_diayn_z_idxs[mb] + + if config.use_p3o: + val_mean = experience.b_values_mean[mb] + val_std = experience.b_values_std[mb] + rew_block = experience.b_reward_block[mb] + mask_block = experience.b_mask_block[mb] + else: + val = exp.values.flatten() - if config.use_diayn: - z_idxs = experience.b_diayn_z_idxs[mb] + if config.device == 'cuda': + torch.cuda.synchronize() - if config.use_p3o: - val_mean = experience.b_values_mean[mb] - val_std = experience.b_values_std[mb] - rew_block = experience.b_reward_block[mb] - mask_block = experience.b_mask_block[mb] - else: - val = experience.b_values[mb].flatten() + with data.amp_context: + with profile.train_forward: + if not isinstance(data.policy, torch.nn.LSTM): + obs = obs.reshape(-1, *data.vecenv.single_observation_space.shape) - if config.device == 'cuda': - torch.cuda.synchronize() + logits, newvalue = data.policy.forward_train(obs, state) + lstm_h = state.lstm_h + lstm_c = state.lstm_c + if lstm_h is not None: + lstm_h = lstm_h.detach() + if lstm_c is not None: + lstm_c = lstm_c.detach() - with data.amp_context: - with profile.train_forward: - if not isinstance(data.policy, torch.nn.LSTM): - obs = obs.reshape(-1, *data.vecenv.single_observation_space.shape) + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, + action=atn, is_continuous=data.policy.is_continuous) - logits, newvalue = data.policy.forward_train(obs, state) - lstm_h = state.lstm_h - lstm_c = state.lstm_c - if lstm_h is not None: - lstm_h = lstm_h.detach() - if lstm_c is not None: - lstm_c = lstm_c.detach() + if config.device == 'cuda': + torch.cuda.synchronize() - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, - action=atn, is_continuous=data.policy.is_continuous) + with profile.train_misc: + logratio = newlogprob - log_probs.reshape(-1) + ratio = logratio.exp() - if config.device == 'cuda': - torch.cuda.synchronize() + # TODO: Only do this if we are KL clipping? Saves 1-2% compute + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfrac = ((ratio - 1.0).abs() > config.clip_coef).float().mean() + + adv = adv.reshape(-1) + if config.norm_adv: + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Policy loss + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp( + ratio, 1 - config.clip_coef, 1 + config.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() - with profile.train_misc: - logratio = newlogprob - log_probs.reshape(-1) - ratio = logratio.exp() - - # TODO: Only do this if we are KL clipping? Saves 1-2% compute - with torch.no_grad(): - # calculate approx_kl http://joschu.net/blog/kl-approx.html - old_approx_kl = (-logratio).mean() - approx_kl = ((ratio - 1) - logratio).mean() - clipfrac = ((ratio - 1.0).abs() > config.clip_coef).float().mean() - - adv = adv.reshape(-1) - if config.norm_adv: - adv = (adv - adv.mean()) / (adv.std() + 1e-8) - - # Policy loss - pg_loss1 = -adv * ratio - pg_loss2 = -adv * torch.clamp( - ratio, 1 - config.clip_coef, 1 + config.clip_coef + # Value loss + if config.use_p3o: + newvalue_mean = newvalue.mean.view(-1, config.p3o_horizon) + newvalue_std = newvalue.std.view(-1, config.p3o_horizon) + newvalue_var = torch.square(newvalue_std) + criterion = torch.nn.GaussianNLLLoss(reduction='none') + #v_loss = criterion(newvalue_mean[:, :32], rew_block[:, :32], newvalue_var[:, :32]) + v_loss = criterion(newvalue_mean, rew_block, newvalue_var) + v_loss = v_loss[:, :(horizon+3)] + mask_block = mask_block[:, :(horizon+3)] + #v_loss[:, horizon:] = 0 + #v_loss = (v_loss * mask_block).sum(axis=1) + #v_loss = (v_loss - v_loss.mean().item()) / (v_loss.std().item() + 1e-8) + #v_loss = v_loss.mean() + v_loss = v_loss[mask_block.bool()].mean() + #TODO: Count mask and sum + # There is going to have to be some sort of norm here. + # Right now, learning works at different horizons, but you need + # to retune hyperparameters. Ideally, horizon should be a stable + # param that zero-shots the same hypers + + # Faster than masking + #v_loss = (v_loss*mask_block[:, :32]).sum() / mask_block[:, :32].sum() + #v_loss = (v_loss*mask_block).sum() / mask_block.sum() + #v_loss = v_loss[mask_block.bool()].mean() + elif config.clip_vloss: + newvalue = newvalue.flatten() + ret = ret.flatten() + v_loss_unclipped = (newvalue - ret) ** 2 + v_clipped = val + torch.clamp( + newvalue - val, + -config.vf_clip_coef, + config.vf_clip_coef, ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() - - # Value loss - if config.use_p3o: - newvalue_mean = newvalue.mean.view(-1, config.p3o_horizon) - newvalue_std = newvalue.std.view(-1, config.p3o_horizon) - newvalue_var = torch.square(newvalue_std) - criterion = torch.nn.GaussianNLLLoss(reduction='none') - #v_loss = criterion(newvalue_mean[:, :32], rew_block[:, :32], newvalue_var[:, :32]) - v_loss = criterion(newvalue_mean, rew_block, newvalue_var) - v_loss = v_loss[:, :(horizon+3)] - mask_block = mask_block[:, :(horizon+3)] - #v_loss[:, horizon:] = 0 - #v_loss = (v_loss * mask_block).sum(axis=1) - #v_loss = (v_loss - v_loss.mean().item()) / (v_loss.std().item() + 1e-8) - #v_loss = v_loss.mean() - v_loss = v_loss[mask_block.bool()].mean() - #TODO: Count mask and sum - # There is going to have to be some sort of norm here. - # Right now, learning works at different horizons, but you need - # to retune hyperparameters. Ideally, horizon should be a stable - # param that zero-shots the same hypers - - # Faster than masking - #v_loss = (v_loss*mask_block[:, :32]).sum() / mask_block[:, :32].sum() - #v_loss = (v_loss*mask_block).sum() / mask_block.sum() - #v_loss = v_loss[mask_block.bool()].mean() - elif config.clip_vloss: - newvalue = newvalue.flatten() - ret = ret.flatten() - v_loss_unclipped = (newvalue - ret) ** 2 - v_clipped = val + torch.clamp( - newvalue - val, - -config.vf_clip_coef, - config.vf_clip_coef, - ) - v_loss_clipped = (v_clipped - ret) ** 2 - v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) - v_loss = 0.5 * v_loss_max.mean() - else: - newvalue = newvalue.flatten() - v_loss = 0.5 * ((newvalue - ret) ** 2).mean() - - entropy_loss = entropy.mean() - loss = pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef - - with profile.custom: - if config.use_diayn: - diayn_discriminator = data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator - q = diayn_discriminator(state.hidden).squeeze() - diayn_loss = cross_entropy(q, z_idxs) - loss += config.diayn_loss_coef*diayn_loss - torch.cuda.synchronize() - - with profile.learn: - if data.scaler is None: - loss.backward() + v_loss_clipped = (v_clipped - ret) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() else: - data.scaler.scale(loss).backward() + newvalue = newvalue.flatten() + v_loss = 0.5 * ((newvalue - ret) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef + + with profile.custom: + if config.use_diayn: + diayn_discriminator = data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator + q = diayn_discriminator(state.hidden).squeeze() + diayn_loss = cross_entropy(q, z_idxs) + loss += config.diayn_loss_coef*diayn_loss + torch.cuda.synchronize() - if data.scaler is not None: - data.scaler.unscale_(data.optimizer) + with profile.learn: + if data.scaler is None: + loss.backward() + else: + data.scaler.scale(loss).backward() - with torch.no_grad(): - grads = torch.cat([p.grad.flatten() for p in data.policy.parameters()]) - grad_var = grads.var(0).mean() * config.minibatch_size - data.msg = f'Gradient variance: {grad_var.item():.3f}' + if data.scaler is not None: + data.scaler.unscale_(data.optimizer) - if (mb + 1) % accumulate_minibatches == 0: - torch.nn.utils.clip_grad_norm_(data.policy.parameters(), config.max_grad_norm) + with torch.no_grad(): + grads = torch.cat([p.grad.flatten() for p in data.policy.parameters()]) + grad_var = grads.var(0).mean() * config.minibatch_size + data.msg = f'Gradient variance: {grad_var.item():.3f}' - if data.scaler is None: - data.optimizer.step() - else: - data.scaler.step(data.optimizer) - data.scaler.update() + if (epoch + 1) % accumulate_minibatches == 0: + torch.nn.utils.clip_grad_norm_(data.policy.parameters(), config.max_grad_norm) - data.optimizer.zero_grad() + if data.scaler is None: + data.optimizer.step() + else: + data.scaler.step(data.optimizer) + data.scaler.update() - if config.device == 'cuda': - torch.cuda.synchronize() + data.optimizer.zero_grad() - with profile.train_misc: - losses.policy_loss += pg_loss.item() / total_minibatches - losses.value_loss += v_loss.item() / total_minibatches - losses.entropy += entropy_loss.item() / total_minibatches - losses.old_approx_kl += old_approx_kl.item() / total_minibatches - losses.approx_kl += approx_kl.item() / total_minibatches - losses.clipfrac += clipfrac.item() / total_minibatches - losses.grad_var += grad_var.item() / total_minibatches + if config.device == 'cuda': + torch.cuda.synchronize() - if data.use_diayn: - losses.diayn_loss += diayn_loss.item() / total_minibatches + # Reprioritize experience + values_np = experience.values.to('cpu', non_blocking=True).numpy() + dones_np = dones.to('cpu', non_blocking=True).numpy() + rewards_np = rewards.to('cpu', non_blocking=True).numpy() + stored_idxs = experience.stored_indices.to('cpu', non_blocking=True).numpy() + torch.cuda.synchronize() + advantages_np = compute_gae(dones_np, values_np, rewards_np, config.gamma, config.gae_lambda) + advantages = torch.as_tensor(advantages_np).to(config.device, non_blocking=True) + n_samples = experience.off_policy_rows + exp = experience.sample(advantages, n_samples) + experience.obs[experience.on_policy_rows:] = exp.obs + experience.actions[experience.on_policy_rows:] = exp.actions + experience.logprobs[experience.on_policy_rows:] = exp.logprobs + experience.dones[experience.on_policy_rows:] = exp.dones + experience.values[experience.on_policy_rows:] = exp.values + experience.rewards[experience.on_policy_rows:] = exp.rewards + + with profile.train_misc: + losses.policy_loss += pg_loss.item() / total_minibatches + losses.value_loss += v_loss.item() / total_minibatches + losses.entropy += entropy_loss.item() / total_minibatches + losses.old_approx_kl += old_approx_kl.item() / total_minibatches + losses.approx_kl += approx_kl.item() / total_minibatches + losses.clipfrac += clipfrac.item() / total_minibatches + losses.grad_var += grad_var.item() / total_minibatches + + if data.use_diayn: + losses.diayn_loss += diayn_loss.item() / total_minibatches if config.target_kl is not None: if approx_kl > config.target_kl: @@ -600,7 +617,9 @@ def train(data): y_true = experience.reward_block else: y_pred = experience.values.flatten() - y_true = experience.b_returns.flatten() + + # Probably not updated + y_true = advantages.flatten() + experience.values.flatten() var_y = y_true.var() explained_var = torch.nan if var_y == 0 else 1 - (y_true - y_pred).var() / var_y @@ -812,14 +831,16 @@ class Experience: '''Flat tensor storage and array views for faster indexing''' def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, hidden_size, obs_shape, obs_dtype, atn_shape, atn_dtype, cpu_offload=False, - device='cuda', policy=None, lstm_total_agents=0, + device='cuda', policy=None, lstm_total_agents=0, replay_factor=1, use_e3b=False, e3b_coef=0.1, e3b_lambda=10.0, use_diayn=False, diayn_archive=128, diayn_coef=0.1, use_p3o=False, p3o_horizon=32): if minibatch_size is None: minibatch_size = batch_size - num_rows = batch_size // bptt_horizon + self.on_policy_rows = batch_size // bptt_horizon + self.off_policy_rows = replay_factor * batch_size // bptt_horizon + num_rows = self.on_policy_rows + self.off_policy_rows self.num_rows = num_rows obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_dtype] @@ -836,11 +857,11 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, self.rewards=torch.zeros(num_rows, bptt_horizon, device=device) self.dones=torch.zeros(num_rows, bptt_horizon, device=device) self.truncateds=torch.zeros(num_rows, bptt_horizon, device=device) - + self.stored_indices = torch.zeros(num_rows, device=device, dtype=torch.int32) self.ep_lengths = torch.zeros(lstm_total_agents, device=device, dtype=torch.int32) self.ep_indices = torch.arange(lstm_total_agents, device=device, dtype=torch.int32) - self.stored_indices = torch.zeros(num_rows, device=device, dtype=torch.int32) - self.free_idx = 0 + self.free_idx = lstm_total_agents + assert self.free_idx <= num_rows self.use_e3b = use_e3b if use_e3b: @@ -870,9 +891,6 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, else: self.values = torch.zeros(num_rows, bptt_horizon, device=device) - self.sort_keys = np.zeros((batch_size, 3), dtype=np.int32) - self.sort_keys[:, 0] = np.arange(batch_size) - self.lstm_h = self.lstm_c = None if isinstance(policy, torch.nn.LSTM): assert lstm_total_agents > 0 @@ -901,7 +919,7 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, @property def full(self): - return self.free_idx >= self.num_rows + return self.free_idx >= self.on_policy_rows def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, env_id, mask): # Mask learner and Ensure indices do not exceed batch size @@ -929,19 +947,10 @@ def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, e self.obs[batch_rows, l] = cpu_obs if isinstance(env_id, slice): - self.stored_indices[batch_rows] = torch.arange(env_id.stop - env_id.start, device=self.device).int() + env_id.start + self.stored_indices[batch_rows] = torch.arange(env_id.start, env_id.stop, device=self.device).int() else: self.stored_indices[batch_rows] = env_id - l += 1 - self.ep_lengths[env_id] = l - full = l >= self.bptt_horizon - num_full = full.sum() - if num_full > 0: - self.ep_lengths[full] = 0 - self.ep_indices[full] = self.free_idx + torch.arange(num_full, device=self.device).int() - self.free_idx += num_full - if self.use_diayn: self.diayn_batch[dst] = state.diayn_z_idxs[gpu_inds] @@ -956,22 +965,40 @@ def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, e self.logprobs[batch_rows, l] = logprob self.rewards[batch_rows, l] = reward.to(self.rewards.device) # ??? self.dones[batch_rows, l] = done.float().to(self.dones.device) # ??? + + l += 1 + self.ep_lengths[env_id] = l + full = l >= self.bptt_horizon + num_full = full.sum() + if num_full > 0: + if isinstance(env_id, slice): + env_id = torch.arange(env_id.start, env_id.stop, device=self.device).int() + + full_ids = env_id[full] + self.ep_indices[full_ids] = self.free_idx + torch.arange(num_full, device=self.device).int() + self.ep_lengths[full_ids] = 0 + self.free_idx += num_full + self.step += 1 return action.cpu().numpy() - def sort_training_data(self): - idxs = np.lexsort((self.sort_keys[:, 2], self.sort_keys[:, 1])) - self.b_idxs_obs = torch.as_tensor(idxs.reshape( - self.minibatch_rows, self.num_minibatches, self.bptt_horizon - ).transpose(1,0,-1)).to(self.obs.device).long() - self.b_idxs = self.b_idxs_obs.to(self.device) - self.b_idxs_flat = self.b_idxs.reshape( - self.num_minibatches, self.minibatch_size) - self.sort_keys[:, 1:] = 0 - return idxs + def sample(self, advantages, n): + idx = torch.multinomial(advantages.abs().sum(axis=1), n) + advantages=advantages[idx] + values=self.values[idx] + return pufferlib.namespace( + actions=self.actions[idx], + logprobs=self.logprobs[idx], + rewards=self.rewards[idx], + dones=self.dones[idx], + obs=self.obs[idx], + advantages=advantages, + values=values, + returns=advantages + values, + ) + def flatten_batch(self, advantages_np, reward_block=None, mask_block=None): - self.free_idx = 0 advantages = torch.as_tensor(advantages_np).to(self.device, non_blocking=True) self.b_advantages = advantages.reshape( self.num_minibatches, self.minibatch_rows, self.bptt_horizon) diff --git a/config/default.ini b/config/default.ini index 77bbfa12b..ecdd0291f 100644 --- a/config/default.ini +++ b/config/default.ini @@ -48,6 +48,7 @@ data_dir = experiments checkpoint_interval = 200 batch_size = 524288 minibatch_size = 8192 +replay_factor = 1 # Accumulate gradients above this size max_minibatch_size = 16384 bptt_horizon = 64 diff --git a/config/ocean/breakout.ini b/config/ocean/breakout.ini index eea2e6018..f8d5fcd2c 100644 --- a/config/ocean/breakout.ini +++ b/config/ocean/breakout.ini @@ -3,33 +3,40 @@ package = ocean env_name = puffer_breakout policy_name = Policy rnn_name = Recurrent -; vec = multiprocessing +vec = multiprocessing [env] num_envs = 2048 +[policy] +hidden_size = 128 + +[rnn] +input_size = 128 +hidden_size = 128 + [train] -total_timesteps = 100_000_000 +total_timesteps = 80_000_000 checkpoint_interval = 50 -num_envs = 1 -num_workers = 1 +num_envs = 2 +num_workers = 2 env_batch_size = 1 -batch_size = 262144 -update_epochs = 2 -ent_coef = 0.0097648317226976 -gae_lambda = 0.8565811585596427 -gamma = 0.9660548302390047 -learning_rate = 0.006108033634877861 -max_grad_norm = 0.9084696417167661 -vf_coef = 0.699137796315858 +batch_size = 524288 +update_epochs = 64 +ent_coef = 0.004602497836498393 +gae_lambda = 0.8345374031042396 +gamma = 0.9964277976817042 +learning_rate = 0.02716585155000465 +max_grad_norm = 0.3833512851796203 +vf_coef = 2.177014788166991 minibatch_size = 8192 -bptt_horizon = 32 +bptt_horizon = 64 anneal_lr = True device = cuda -; adam_beta1 = 0.8619932484485815 -; adam_beta2 = 0.998659815024087 -; adam_eps = 1e-12 +adam_beta1 = 0.8619932484485815 +adam_beta2 = 0.998659815024087 +adam_eps = 1e-12 [sweep] method = protein diff --git a/demo.py b/demo.py index 6c374f3d7..462f585c1 100644 --- a/demo.py +++ b/demo.py @@ -177,7 +177,9 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10 batch_size = args['train']['batch_size'] while len(data.stats[target_metric]) < min_eval_points: stats, _ = clean_pufferl.evaluate(data) - data.experience.sort_keys[:] = 0 + # TODO: Beter place for this + data.experience.free_idx = 0 + data.experience.ep_lengths.zero_() steps_evaluated += batch_size clean_pufferl.mean_and_log(data) From c6c889300f1de3b0803bdc85b4c6c9d47905fccb Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 8 Apr 2025 19:21:41 +0000 Subject: [PATCH 03/26] Initial refactor, gae cuda kernel --- c_advantage.cu | 27 +- clean_pufferl.py | 906 +++++++++++++++++++++------------------------ demo.py | 4 +- pufferlib.cu | 35 +- pufferlib/utils.py | 14 +- 5 files changed, 497 insertions(+), 489 deletions(-) diff --git a/c_advantage.cu b/c_advantage.cu index dd8683f00..93275d75a 100644 --- a/c_advantage.cu +++ b/c_advantage.cu @@ -1,4 +1,4 @@ -__global__ void advantage_kernel( +__global__ void p3o_kernel( float* reward_block, // [num_steps, horizon] float* reward_mask, // [num_steps, horizon] float* values_mean, // [num_steps, horizon] @@ -57,7 +57,7 @@ __global__ void advantage_kernel( reward_mask[idx] = 1.0f; } - float bootstrap = 0.0f; + //float bootstrap = 0.0f; //if (k == horizon-1) { // bootstrap = buf[i*horizon + horizon - 1]*values_mean[i*horizon + horizon - 1]; //} @@ -85,3 +85,26 @@ __global__ void advantage_kernel( advantages[i] = R; bounds[i] = k; } + + +__global__ void gae_kernel( + float* values, // [num_steps, horizon] + float* rewards, // [num_steps, horizon] + float* dones, // [num_steps, horizon] + float* advantages, // [num_steps, horizon] + float gamma, + float gae_lambda, + int num_steps, + int horizon +) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + float lastgaelam = 0; + for (int t = horizon-2; t >= 0; t--) { + int idx = row*horizon + t; + int idx_next = idx + 1; + float nextnonterminal = 1.0 - dones[idx_next]; + float delta = rewards[idx_next] + gamma*values[idx_next]*nextnonterminal - values[idx]; + lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam; + advantages[idx] = lastgaelam; + } +} diff --git a/clean_pufferl.py b/clean_pufferl.py index bdfc8b6be..c7c10c0ce 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -24,20 +24,53 @@ # Fast Cython advantage functions #from c_advantage import rewards_and_masks, compute_gae -from c_advantage import compute_gae +#from c_advantage import compute_gae import torch from torch.utils.cpp_extension import load # Compile the CUDA kernel -''' cuda_module = load( - name='advantage_kernel', + name='compute_gae', sources=['pufferlib.cu'], verbose=True ) -''' + +def compute_gae( + values: torch.Tensor, # [num_steps, horizon] + rewards: torch.Tensor, # [num_steps, horizon] + dones: torch.Tensor, # [num_steps, horizon] + gamma: float, + gae_lambda: float, + ): + + num_steps = values.shape[0] + horizon = values.shape[1] + advantages = torch.zeros(num_steps, horizon, dtype=torch.float32, device=values.device) + + for t in [values, rewards, dones, advantages]: + assert t.ndim == 2 + assert t.shape[0] == num_steps + assert t.shape[1] == horizon + t.contiguous() + assert t.is_cuda, "All tensors must be on GPU" + + + cuda_module.compute_gae( + values, + rewards, + dones, + advantages, + gamma, + gae_lambda, + num_steps, + horizon, + ) + + torch.cuda.synchronize() + return advantages + def compute_advantages( reward_block: torch.Tensor, # [num_steps, horizon] @@ -100,18 +133,80 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): vecenv.async_reset(config.seed) obs_shape = vecenv.single_observation_space.shape - obs_dtype = vecenv.single_observation_space.dtype + obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[vecenv.single_observation_space.dtype] atn_shape = vecenv.single_action_space.shape - atn_dtype = vecenv.single_action_space.dtype + atn_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[vecenv.single_action_space.dtype] total_agents = vecenv.num_agents - experience = Experience(config.batch_size, config.bptt_horizon, - config.minibatch_size, config.max_minibatch_size, policy.hidden_size, obs_shape, obs_dtype, - atn_shape, atn_dtype, config.cpu_offload, config.device, policy, total_agents, config.replay_factor, - use_e3b=config.use_e3b, e3b_coef=config.e3b_coef, e3b_lambda=config.e3b_lambda, - use_diayn=config.use_diayn, diayn_archive=config.diayn_archive, diayn_coef=config.diayn_coef, - use_p3o=config.use_p3o, p3o_horizon=config.p3o_horizon + on_policy_rows = config.batch_size // config.bptt_horizon + off_policy_rows = config.replay_factor*config.batch_size // config.bptt_horizon + experience_rows = on_policy_rows + off_policy_rows + + pin = config.device == 'cuda' and config.cpu_offload + obs_device = config.device if not pin else 'cpu' + experience = pufferlib.namespace( + obs=torch.zeros(experience_rows, config.bptt_horizon, *obs_shape, + dtype=obs_dtype, pin_memory=pin, device='cpu' if pin else config.device), + actions=torch.zeros(experience_rows, config.bptt_horizon, *atn_shape, + dtype=atn_dtype, device=config.device), + logprobs=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), + rewards=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), + dones=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), + truncateds=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), ) + stored_indices = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) + ep_lengths = torch.zeros(total_agents, device=config.device, dtype=torch.int32) + ep_indices = torch.arange(total_agents, device=config.device, dtype=torch.int32) + free_idx = total_agents + + assert free_idx <= experience_rows + if config.use_e3b: + experience.e3b_inv = torch.eye(policy.hidden_size).repeat(total_agents, 1, 1).to(config.device) / config.e3b_lambda + experience.e3b_orig = experience.e3b_inv.clone() + experience.e3b_mean = None + experience.e3b_std = None + + if config.use_diayn: + # TODO: Check shapes + experience.diayn_archive = torch.nn.functional.one_hot(torch.arange(config.diayn_archive), config.diayn_archive).to(config.device).float() + experience.diayn_skills = torch.randint(0, config.diayn_archive, (total_agents,), dtype=torch.long, device=config.device) + experience.diayn_batch = torch.zeros(experience_rows, dtype=torch.long, device=config.device) + + if config.use_p3o: + batch_size = config.batch_size + p3o_horizon = config.p3o_horizon + device = config.device + experience.values_mean=torch.zeros(batch_size, p3o_horizon, device=device) + experience.values_std=torch.zeros(batch_size, p3o_horizon, device=device) + experience.reward_block = torch.zeros(batch_size, p3o_horizon, dtype=torch.float32, device=device) + experience.mask_block = torch.ones(batch_size, p3o_horizon, dtype=torch.float32, device=device) + experience.buf = torch.zeros(batch_size, p3o_horizon, dtype=torch.float32, device=device) + experience.advantages = torch.zeros(batch_size, dtype=torch.float32, device=device) + experience.bounds = torch.zeros(batch_size, dtype=torch.int32, device=device) + experience.vstd_max = 1.0 + else: + experience.values = torch.zeros(experience_rows, config.bptt_horizon, device=config.device) + + lstm_h = None + lstm_c = None + if isinstance(policy, torch.nn.LSTM): + assert total_agents > 0 + shape = (policy.num_layers, total_agents, policy.hidden_size) + lstm_h = torch.zeros(shape).to(config.device) + lstm_c = torch.zeros(shape).to(config.device) + + minibatch_size = min(config.minibatch_size, config.max_minibatch_size) + num_minibatches = config.batch_size / minibatch_size + if num_minibatches != int(num_minibatches): + raise ValueError('batch_size must be divisible by minibatch_size') + else: + num_minibatches = int(num_minibatches) + + minibatch_rows = minibatch_size / config.bptt_horizon + if minibatch_rows != int(minibatch_rows): + raise ValueError('minibatch_size must be divisible by bptt_horizon') + else: + minibatch_rows = int(minibatch_rows) uncompiled_policy = policy if config.compile: @@ -154,10 +249,12 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) scaler = None if config.precision == 'float32' else torch.amp.GradScaler() - amp_context = (nullcontext() if config.precision == 'float32' - else torch.amp.autocast(device_type='cuda', dtype=getattr(torch, config.precision))) - profile = Profile() + amp_context = nullcontext() + if config.precision != 'float32': + amp_context = torch.amp.autocast(device_type='cuda', dtype=getattr(torch, config.precision)) + + profile = Profile(amp_context) print_dashboard(config.env, utilization, 0, 0, profile, losses, {}, msg, clear=True) return pufferlib.namespace( @@ -168,7 +265,6 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): optimizer=optimizer, scheduler=scheduler, scaler=scaler, - amp_context=amp_context, experience=experience, profile=profile, losses=losses, @@ -189,6 +285,21 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): use_diayn=config.use_diayn, diayn_archive=config.diayn_archive, diayn_coef=config.diayn_coef, + # Do we use these? + ptr=0, + step=0, + lstm_h=lstm_h, + lstm_c=lstm_c, + minibatch_rows=minibatch_rows, + num_minibatches=num_minibatches, + stored_indices=stored_indices, + ep_lengths=ep_lengths, + ep_indices=ep_indices, + free_idx=free_idx, + on_policy_rows=on_policy_rows, + off_policy_rows=off_policy_rows, + experience_rows=experience_rows, + device=config.device, ) @pufferlib.utils.profile @@ -199,128 +310,114 @@ def evaluate(data): experience = data.experience policy = data.policy infos = defaultdict(list) - lstm_h = experience.lstm_h - lstm_c = experience.lstm_c + lstm_h = data.lstm_h + lstm_c = data.lstm_c + + while not full(data): + with profile.env: + o, r, d, t, info, env_id, mask = data.vecenv.recv() + + # Zero-copy indexing for contiguous env_id + if config.env_batch_size == 1: + gpu_env_id = cpu_env_id = slice(env_id[0], env_id[-1] + 1) + else: + cpu_env_id = env_id + gpu_env_id = torch.as_tensor(env_id).to(config.device, non_blocking=True) + + with profile.eval_misc: + done_mask = d + t + data.global_step += mask.sum() + + if data.use_diayn: + idxs = env_id[done_mask] + if len(idxs) > 0: + z_idxs = torch.randint(0, experience.diayn_archive.shape[0], (done_mask.sum(),)).to(config.device) + experience.diayn_skills[idxs] = z_idxs - with data.amp_context: - while not experience.full: - with profile.env: - o, r, d, t, info, env_id, mask = data.vecenv.recv() + with profile.eval_copy: + if data.use_e3b and done_mask.any(): + done_idxs = env_id[done_mask] + experience.e3b_inv[done_idxs] = experience.e3b_orig[done_idxs] - # Zero-copy indexing for contiguous env_id - if config.env_batch_size == 1: - gpu_env_id = cpu_env_id = slice(env_id[0], env_id[-1] + 1) + + o = torch.as_tensor(o) + o_device = o.to(config.device, non_blocking=True) + r = torch.as_tensor(r).to(config.device, non_blocking=True) + d = torch.as_tensor(d).to(config.device, non_blocking=True) + + h = None + c = None + if lstm_h is not None: + h = lstm_h[0, gpu_env_id] + c = lstm_c[0, gpu_env_id] + + with profile.eval_forward, torch.no_grad(): + state = pufferlib.namespace( + reward=r, + done=d, + env_id=gpu_env_id, + mask=mask, + lstm_h=h, + lstm_c=c, + ) + + if data.use_diayn: + z_idxs = experience.diayn_skills[env_id] + z = experience.diayn_archive[z_idxs] + state.diayn_z_idxs = z_idxs + state.diayn_z = z + + logits, value = policy(o_device, state) + action, logprob, _ = pufferlib.pytorch.sample_logits(logits, is_continuous=policy.is_continuous) + + if data.use_diayn: + diayn_policy = policy if lstm_h is None else policy.policy + q = diayn_policy.diayn_discriminator(state.hidden).squeeze() + r_diayn = torch.log_softmax(q, dim=-1).gather(-1, z_idxs.unsqueeze(-1)).squeeze() + r += config.diayn_coef*r_diayn# - np.log(1/data.diayn_archive) + state.diayn_z = z + state.diayn_z_idxs = z_idxs + + if data.use_e3b: + e3b = experience.e3b_inv[env_id] + phi = state.hidden.detach() + u = phi.unsqueeze(1) @ e3b + b = u @ phi.unsqueeze(2) + experience.e3b_inv[env_id] -= (u.mT @ u) / (1 + b) + done_inds = env_id[done_mask] + experience.e3b_inv[done_inds] = experience.e3b_orig[done_inds] + e3b_reward = b.squeeze() + + if experience.e3b_mean is None: + experience.e3b_mean = e3b_reward.mean() + experience.e3b_std = e3b_reward.std() else: - cpu_env_id = env_id - gpu_env_id = torch.as_tensor(env_id).to(config.device, non_blocking=True) - - with profile.eval_misc: - done_mask = d + t - data.global_step += mask.sum() - - if data.use_diayn: - idxs = env_id[done_mask] - if len(idxs) > 0: - z_idxs = torch.randint(0, experience.diayn_archive.shape[0], (done_mask.sum(),)).to(config.device) - experience.diayn_skills[idxs] = z_idxs - - with profile.eval_copy: - if data.use_e3b and done_mask.any(): - done_idxs = env_id[done_mask] - experience.e3b_inv[done_idxs] = experience.e3b_orig[done_idxs] - - - o = torch.as_tensor(o) - o_device = o.to(config.device, non_blocking=True) - r = torch.as_tensor(r).to(config.device, non_blocking=True) - d = torch.as_tensor(d).to(config.device, non_blocking=True) - - h = None - c = None - if lstm_h is not None: - h = lstm_h[0, gpu_env_id] - c = lstm_c[0, gpu_env_id] - - if config.device == 'cuda': - torch.cuda.synchronize() - - with profile.eval_forward, torch.no_grad(): - state = pufferlib.namespace( - reward=r, - done=d, - env_id=gpu_env_id, - mask=mask, - lstm_h=h, - lstm_c=c, - ) + w = data.e3b_norm + experience.e3b_mean = (1-w)*e3b_reward.mean() + w*experience.e3b_mean + experience.e3b_std = (1-w)*e3b_reward.std() + w*experience.e3b_std - if data.use_diayn: - z_idxs = experience.diayn_skills[env_id] - z = experience.diayn_archive[z_idxs] - state.diayn_z_idxs = z_idxs - state.diayn_z = z - - logits, value = policy(o_device, state) - action, logprob, _ = pufferlib.pytorch.sample_logits(logits, is_continuous=policy.is_continuous) - - if data.use_diayn: - diayn_policy = policy if lstm_h is None else policy.policy - q = diayn_policy.diayn_discriminator(state.hidden).squeeze() - r_diayn = torch.log_softmax(q, dim=-1).gather(-1, z_idxs.unsqueeze(-1)).squeeze() - r += config.diayn_coef*r_diayn# - np.log(1/data.diayn_archive) - state.diayn_z = z - state.diayn_z_idxs = z_idxs - - if data.use_e3b: - e3b = experience.e3b_inv[env_id] - phi = state.hidden.detach() - u = phi.unsqueeze(1) @ e3b - b = u @ phi.unsqueeze(2) - experience.e3b_inv[env_id] -= (u.mT @ u) / (1 + b) - done_inds = env_id[done_mask] - experience.e3b_inv[done_inds] = experience.e3b_orig[done_inds] - e3b_reward = b.squeeze() - - if experience.e3b_mean is None: - experience.e3b_mean = e3b_reward.mean() - experience.e3b_std = e3b_reward.std() - else: - w = data.e3b_norm - experience.e3b_mean = (1-w)*e3b_reward.mean() + w*experience.e3b_mean - experience.e3b_std = (1-w)*e3b_reward.std() + w*experience.e3b_std - - e3b_reward = (e3b_reward - experience.e3b_mean) / (experience.e3b_std + 1e-6) - e3b_reward = config.e3b_coef*e3b_reward - r += e3b_reward - - # Clip rewards - r = torch.clamp(r, -1, 1) - - if config.device == 'cuda': - torch.cuda.synchronize() - - with profile.eval_copy, torch.no_grad(): - if lstm_h is not None: - lstm_h[:, gpu_env_id] = state.lstm_h - lstm_c[:, gpu_env_id] = state.lstm_c - - if config.device == 'cuda': - torch.cuda.synchronize() - - with profile.eval_copy: - o = o if config.cpu_offload else o_device - actions = experience.store(state, o, o_device, value, action, logprob, r, d, gpu_env_id, mask) - - if config.device == 'cuda': - torch.cuda.synchronize() - - with profile.eval_misc: - for i in info: - for k, v in pufferlib.utils.unroll_nested_dict(i): - infos[k].append(v) - - with profile.env: - data.vecenv.send(actions) + e3b_reward = (e3b_reward - experience.e3b_mean) / (experience.e3b_std + 1e-6) + e3b_reward = config.e3b_coef*e3b_reward + r += e3b_reward + + # Clip rewards + r = torch.clamp(r, -1, 1) + + with profile.eval_copy, torch.no_grad(): + if lstm_h is not None: + lstm_h[:, gpu_env_id] = state.lstm_h + lstm_c[:, gpu_env_id] = state.lstm_c + + o = o if config.cpu_offload else o_device + actions = store(data, state, o, o_device, value, action, logprob, r, d, gpu_env_id, mask) + + with profile.eval_misc: + for i in info: + for k, v in pufferlib.utils.unroll_nested_dict(i): + infos[k].append(v) + + with profile.env: + data.vecenv.send(actions) with profile.eval_misc: for k, v in infos.items(): @@ -342,8 +439,8 @@ def evaluate(data): data.stats[k] += v # TODO: Better way to enable multiple collects - data.experience.ptr = 0 - data.experience.step = 0 + data.ptr = 0 + data.step = 0 return data.stats, infos @pufferlib.utils.profile @@ -360,8 +457,8 @@ def train(data): rewards = experience.rewards # TODO: Beter place for this - experience.free_idx = 0 - experience.ep_lengths.zero_() + data.free_idx = 0 + data.ep_lengths.zero_() with profile.train_misc: if config.use_p3o: @@ -378,7 +475,6 @@ def train(data): # we store experience to avoid this issue vstd_min = values_std.min().item() vstd_max = values_std.max().item() - torch.cuda.synchronize() mask_block.zero_() experience.buf.zero_() @@ -412,140 +508,123 @@ def train(data): torch.cuda.synchronize() # Optimizing the policy and value network - total_minibatches = experience.num_minibatches * config.update_epochs + total_minibatches = data.num_minibatches * config.update_epochs mean_pg_loss, mean_v_loss, mean_entropy_loss = 0, 0, 0 mean_old_kl, mean_kl, mean_clipfrac = 0, 0, 0 cross_entropy = torch.nn.CrossEntropyLoss() accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) for epoch in range(config.update_epochs): - values_np = experience.values.to('cpu', non_blocking=True).numpy() - dones_np = dones.to('cpu', non_blocking=True).numpy() - rewards_np = rewards.to('cpu', non_blocking=True).numpy() - stored_idxs = experience.stored_indices.to('cpu', non_blocking=True).numpy() - torch.cuda.synchronize() - advantages_np = compute_gae(dones_np, values_np, rewards_np, config.gamma, config.gae_lambda) - advantages = torch.as_tensor(advantages_np).to(config.device, non_blocking=True) + advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) n_samples = config.minibatch_size // config.bptt_horizon - exp = experience.sample(advantages, n_samples) - - obs = exp.obs - atn = exp.actions - log_probs = exp.logprobs - adv = exp.advantages - ret = exp.returns + batch = sample(data, advantages, n_samples) with profile.train_misc: state = pufferlib.namespace( - action=atn, + action=batch.actions, lstm_h=None, lstm_c=None, ) if config.use_diayn: - z_idxs = experience.b_diayn_z_idxs[mb] + z_idxs = batch.diayn_z_idxs if config.use_p3o: - val_mean = experience.b_values_mean[mb] - val_std = experience.b_values_std[mb] - rew_block = experience.b_reward_block[mb] - mask_block = experience.b_mask_block[mb] + val_mean = batch.values_mean + val_std = batch.values_std + rew_block = batch.reward_block + mask_block = batch.mask_block else: - val = exp.values.flatten() + val = batch.values.flatten() - if config.device == 'cuda': - torch.cuda.synchronize() + with profile.train_forward: + if not isinstance(data.policy, torch.nn.LSTM): + batch.obs = batch.obs.reshape(-1, *data.vecenv.single_observation_space.shape) - with data.amp_context: - with profile.train_forward: - if not isinstance(data.policy, torch.nn.LSTM): - obs = obs.reshape(-1, *data.vecenv.single_observation_space.shape) + logits, newvalue = data.policy.forward_train(batch.obs, state) + lstm_h = state.lstm_h + lstm_c = state.lstm_c + if lstm_h is not None: + lstm_h = lstm_h.detach() + if lstm_c is not None: + lstm_c = lstm_c.detach() - logits, newvalue = data.policy.forward_train(obs, state) - lstm_h = state.lstm_h - lstm_c = state.lstm_c - if lstm_h is not None: - lstm_h = lstm_h.detach() - if lstm_c is not None: - lstm_c = lstm_c.detach() + actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, + action=batch.actions, is_continuous=data.policy.is_continuous) - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, - action=atn, is_continuous=data.policy.is_continuous) + if config.device == 'cuda': + torch.cuda.synchronize() - if config.device == 'cuda': - torch.cuda.synchronize() + with profile.train_misc: + logratio = newlogprob - batch.logprobs.reshape(-1) + ratio = logratio.exp() - with profile.train_misc: - logratio = newlogprob - log_probs.reshape(-1) - ratio = logratio.exp() + # TODO: Only do this if we are KL clipping? Saves 1-2% compute + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfrac = ((ratio - 1.0).abs() > config.clip_coef).float().mean() + + adv = batch.advantages.reshape(-1) + if config.norm_adv: + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Policy loss + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp( + ratio, 1 - config.clip_coef, 1 + config.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() - # TODO: Only do this if we are KL clipping? Saves 1-2% compute - with torch.no_grad(): - # calculate approx_kl http://joschu.net/blog/kl-approx.html - old_approx_kl = (-logratio).mean() - approx_kl = ((ratio - 1) - logratio).mean() - clipfrac = ((ratio - 1.0).abs() > config.clip_coef).float().mean() - - adv = adv.reshape(-1) - if config.norm_adv: - adv = (adv - adv.mean()) / (adv.std() + 1e-8) - - # Policy loss - pg_loss1 = -adv * ratio - pg_loss2 = -adv * torch.clamp( - ratio, 1 - config.clip_coef, 1 + config.clip_coef + # Value loss + if config.use_p3o: + newvalue_mean = newvalue.mean.view(-1, config.p3o_horizon) + newvalue_std = newvalue.std.view(-1, config.p3o_horizon) + newvalue_var = torch.square(newvalue_std) + criterion = torch.nn.GaussianNLLLoss(reduction='none') + #v_loss = criterion(newvalue_mean[:, :32], rew_block[:, :32], newvalue_var[:, :32]) + v_loss = criterion(newvalue_mean, rew_block, newvalue_var) + v_loss = v_loss[:, :(horizon+3)] + mask_block = mask_block[:, :(horizon+3)] + #v_loss[:, horizon:] = 0 + #v_loss = (v_loss * mask_block).sum(axis=1) + #v_loss = (v_loss - v_loss.mean().item()) / (v_loss.std().item() + 1e-8) + #v_loss = v_loss.mean() + v_loss = v_loss[mask_block.bool()].mean() + #TODO: Count mask and sum + # There is going to have to be some sort of norm here. + # Right now, learning works at different horizons, but you need + # to retune hyperparameters. Ideally, horizon should be a stable + # param that zero-shots the same hypers + + # Faster than masking + #v_loss = (v_loss*mask_block[:, :32]).sum() / mask_block[:, :32].sum() + #v_loss = (v_loss*mask_block).sum() / mask_block.sum() + #v_loss = v_loss[mask_block.bool()].mean() + elif config.clip_vloss: + newvalue = newvalue.flatten() + ret = batch.returns.flatten() + v_loss_unclipped = (newvalue - ret) ** 2 + v_clipped = val + torch.clamp( + newvalue - val, + -config.vf_clip_coef, + config.vf_clip_coef, ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() - - # Value loss - if config.use_p3o: - newvalue_mean = newvalue.mean.view(-1, config.p3o_horizon) - newvalue_std = newvalue.std.view(-1, config.p3o_horizon) - newvalue_var = torch.square(newvalue_std) - criterion = torch.nn.GaussianNLLLoss(reduction='none') - #v_loss = criterion(newvalue_mean[:, :32], rew_block[:, :32], newvalue_var[:, :32]) - v_loss = criterion(newvalue_mean, rew_block, newvalue_var) - v_loss = v_loss[:, :(horizon+3)] - mask_block = mask_block[:, :(horizon+3)] - #v_loss[:, horizon:] = 0 - #v_loss = (v_loss * mask_block).sum(axis=1) - #v_loss = (v_loss - v_loss.mean().item()) / (v_loss.std().item() + 1e-8) - #v_loss = v_loss.mean() - v_loss = v_loss[mask_block.bool()].mean() - #TODO: Count mask and sum - # There is going to have to be some sort of norm here. - # Right now, learning works at different horizons, but you need - # to retune hyperparameters. Ideally, horizon should be a stable - # param that zero-shots the same hypers - - # Faster than masking - #v_loss = (v_loss*mask_block[:, :32]).sum() / mask_block[:, :32].sum() - #v_loss = (v_loss*mask_block).sum() / mask_block.sum() - #v_loss = v_loss[mask_block.bool()].mean() - elif config.clip_vloss: - newvalue = newvalue.flatten() - ret = ret.flatten() - v_loss_unclipped = (newvalue - ret) ** 2 - v_clipped = val + torch.clamp( - newvalue - val, - -config.vf_clip_coef, - config.vf_clip_coef, - ) - v_loss_clipped = (v_clipped - ret) ** 2 - v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) - v_loss = 0.5 * v_loss_max.mean() - else: - newvalue = newvalue.flatten() - v_loss = 0.5 * ((newvalue - ret) ** 2).mean() + v_loss_clipped = (v_clipped - ret) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + newvalue = newvalue.flatten() + v_loss = 0.5 * ((newvalue - ret) ** 2).mean() - entropy_loss = entropy.mean() - loss = pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef + entropy_loss = entropy.mean() + loss = pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef - with profile.custom: - if config.use_diayn: - diayn_discriminator = data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator - q = diayn_discriminator(state.hidden).squeeze() - diayn_loss = cross_entropy(q, z_idxs) - loss += config.diayn_loss_coef*diayn_loss - torch.cuda.synchronize() + with profile.custom: + if config.use_diayn: + diayn_discriminator = data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator + q = diayn_discriminator(state.hidden).squeeze() + diayn_loss = cross_entropy(q, z_idxs) + loss += config.diayn_loss_coef*diayn_loss with profile.learn: if data.scaler is None: @@ -572,25 +651,13 @@ def train(data): data.optimizer.zero_grad() - if config.device == 'cuda': - torch.cuda.synchronize() - # Reprioritize experience - values_np = experience.values.to('cpu', non_blocking=True).numpy() - dones_np = dones.to('cpu', non_blocking=True).numpy() - rewards_np = rewards.to('cpu', non_blocking=True).numpy() - stored_idxs = experience.stored_indices.to('cpu', non_blocking=True).numpy() - torch.cuda.synchronize() - advantages_np = compute_gae(dones_np, values_np, rewards_np, config.gamma, config.gae_lambda) - advantages = torch.as_tensor(advantages_np).to(config.device, non_blocking=True) - n_samples = experience.off_policy_rows - exp = experience.sample(advantages, n_samples) - experience.obs[experience.on_policy_rows:] = exp.obs - experience.actions[experience.on_policy_rows:] = exp.actions - experience.logprobs[experience.on_policy_rows:] = exp.logprobs - experience.dones[experience.on_policy_rows:] = exp.dones - experience.values[experience.on_policy_rows:] = exp.values - experience.rewards[experience.on_policy_rows:] = exp.rewards + advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) + + n_samples = data.off_policy_rows + exp = sample(data, advantages, n_samples) + for k, v in experience.items(): + v[data.on_policy_rows:] = exp[k] with profile.train_misc: losses.policy_loss += pg_loss.item() / total_minibatches @@ -643,10 +710,96 @@ def train(data): save_checkpoint(data) data.msg = f'Checkpoint saved at update {data.epoch}' - torch.cuda.synchronize() - return logs +def full(data): + return data.free_idx >= data.on_policy_rows + +def store(data, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, env_id, mask): + # Mask learner and Ensure indices do not exceed batch size + exp = data.experience + ptr = data.ptr + indices = np.where(mask)[0] + num_indices = indices.size + end = ptr + num_indices + dst = slice(ptr, end) + + # Zero-copy indexing for contiguous env_id + ''' + if num_indices == mask.size and isinstance(env_id, slice): + gpu_inds = cpu_inds = slice(0, min(self.batch_size - ptr, num_indices)) + else: + cpu_inds = indices[:self.batch_size - ptr] + gpu_inds = torch.as_tensor(cpu_inds).to(self.obs.device, non_blocking=True) + ''' + + batch_rows = data.ep_indices[env_id] + l = data.ep_lengths[env_id] + + if exp.obs.device.type == 'cuda': + exp.obs[batch_rows, l] = gpu_obs + else: + exp.obs[batch_rows, l] = cpu_obs + + if isinstance(env_id, slice): + data.stored_indices[batch_rows] = torch.arange(env_id.start, env_id.stop, device=data.device).int() + else: + data.stored_indices[batch_rows] = env_id + + + if data.use_diayn: + data.diayn_batch[dst] = state.diayn_z_idxs[gpu_inds] + + if data.use_p3o: + exp.values_mean[dst] = value.mean[gpu_inds] + exp.values_std[dst] = value.std[gpu_inds] + else: + exp.values[batch_rows, l] = value.flatten() + + exp.actions[batch_rows, l] = action + exp.logprobs[batch_rows, l] = logprob + exp.rewards[batch_rows, l] = reward.to(exp.rewards.device) # ??? + exp.dones[batch_rows, l] = done.float().to(exp.dones.device) # ??? + + l += 1 + data.ep_lengths[env_id] = l + full = l >= data.config.bptt_horizon + num_full = full.sum() + if num_full > 0: + if isinstance(env_id, slice): + env_id = torch.arange(env_id.start, env_id.stop, device=data.device).int() + + full_ids = env_id[full] + data.ep_indices[full_ids] = data.free_idx + torch.arange(num_full, device=data.device).int() + data.ep_lengths[full_ids] = 0 + data.free_idx += num_full + + data.step += 1 + return action.cpu().numpy() + +def sample(data, advantages, n, reward_block=None, mask_block=None): + exp = data.experience + idx = torch.multinomial(advantages.abs().sum(axis=1), n) + output = {k: v[idx] for k, v in exp.items()} + + if data.use_p3o: + output['reward_block'] = reward_block[idx] + output['mask_block'] = mask_block[idx] + output['values_mean'] = exp.values_mean[idx] + output['values_std'] = exp.values_std[idx] + else: + output['values'] = exp.values[idx] + output['advantages'] = advantages[idx] + output['returns'] = advantages[idx] + exp.values[idx] + + if data.use_diayn: + output['diayn_z_idxs'] = exp.diayn_batch[idx] + output['diayn_z'] = exp.diayn_skills[idx] + + return pufferlib.namespace(**output) + + + def compute_pg_loss(log_probs, newlogprob, adv, clip_coef): logratio = newlogprob - log_probs.reshape(-1) ratio = logratio.exp() @@ -754,15 +907,16 @@ class Profile: train_copy_time: ... = 0 train_misc_time: ... = 0 custom_time: ... = 0 - def __init__(self): + def __init__(self, amp_context): self.start = time.time() self.env = pufferlib.utils.Profiler() - self.eval_forward = pufferlib.utils.Profiler() - self.eval_copy = pufferlib.utils.Profiler() + # TODO: Figure out which of these need amp + self.eval_forward = pufferlib.utils.Profiler(amp_context=amp_context) + self.eval_copy = pufferlib.utils.Profiler(amp_context=amp_context) self.eval_misc = pufferlib.utils.Profiler() - self.train_forward = pufferlib.utils.Profiler() + self.train_forward = pufferlib.utils.Profiler(amp_context=amp_context) self.learn = pufferlib.utils.Profiler() - self.train_copy = pufferlib.utils.Profiler() + self.train_copy = pufferlib.utils.Profiler(amp_context=amp_context) self.train_misc = pufferlib.utils.Profiler() self.custom = pufferlib.utils.Profiler() self.prev_steps = 0 @@ -827,214 +981,6 @@ def make_losses(): grad_var=0, ) -class Experience: - '''Flat tensor storage and array views for faster indexing''' - def __init__(self, batch_size, bptt_horizon, minibatch_size, max_minibatch_size, hidden_size, - obs_shape, obs_dtype, atn_shape, atn_dtype, cpu_offload=False, - device='cuda', policy=None, lstm_total_agents=0, replay_factor=1, - use_e3b=False, e3b_coef=0.1, e3b_lambda=10.0, - use_diayn=False, diayn_archive=128, diayn_coef=0.1, - use_p3o=False, p3o_horizon=32): - if minibatch_size is None: - minibatch_size = batch_size - - self.on_policy_rows = batch_size // bptt_horizon - self.off_policy_rows = replay_factor * batch_size // bptt_horizon - num_rows = self.on_policy_rows + self.off_policy_rows - self.num_rows = num_rows - - obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_dtype] - atn_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[atn_dtype] - pin = device == 'cuda' and cpu_offload - obs_device = device if not pin else 'cpu' - self.obs_shape = obs_shape - self.atn_shape = atn_shape - self.obs=torch.zeros(num_rows, bptt_horizon, *obs_shape, dtype=obs_dtype, - pin_memory=pin, device=device if not pin else 'cpu') - self.actions=torch.zeros(num_rows, bptt_horizon, *atn_shape, - dtype=atn_dtype, device=device) - self.logprobs=torch.zeros(num_rows, bptt_horizon, device=device) - self.rewards=torch.zeros(num_rows, bptt_horizon, device=device) - self.dones=torch.zeros(num_rows, bptt_horizon, device=device) - self.truncateds=torch.zeros(num_rows, bptt_horizon, device=device) - self.stored_indices = torch.zeros(num_rows, device=device, dtype=torch.int32) - self.ep_lengths = torch.zeros(lstm_total_agents, device=device, dtype=torch.int32) - self.ep_indices = torch.arange(lstm_total_agents, device=device, dtype=torch.int32) - self.free_idx = lstm_total_agents - assert self.free_idx <= num_rows - - self.use_e3b = use_e3b - if use_e3b: - self.e3b_inv = torch.eye(hidden_size).repeat(lstm_total_agents, 1, 1).to(device) / e3b_lambda - self.e3b_orig = self.e3b_inv.clone() - self.e3b_mean = None - self.e3b_std = None - - self.use_diayn = use_diayn - if use_diayn: - #self.diayn_archive = torch.randn(diayn_archive, hidden_size, dtype=torch.float32, device=device) - self.diayn_archive = torch.nn.functional.one_hot(torch.arange(diayn_archive), diayn_archive).to(device).float() - self.diayn_skills = torch.randint(0, diayn_archive, (lstm_total_agents,), dtype=torch.long, device=device) - self.diayn_batch = torch.zeros(batch_size, dtype=torch.long, device=device) - - self.use_p3o = use_p3o - self.p3o_horizon = p3o_horizon - if use_p3o: - self.values_mean=torch.zeros(batch_size, p3o_horizon, device=device) - self.values_std=torch.zeros(batch_size, p3o_horizon, device=device) - self.reward_block = torch.zeros(batch_size, p3o_horizon, dtype=torch.float32, device=device) - self.mask_block = torch.ones(batch_size, p3o_horizon, dtype=torch.float32, device=device) - self.buf = torch.zeros(batch_size, p3o_horizon, dtype=torch.float32, device=device) - self.advantages = torch.zeros(batch_size, dtype=torch.float32, device=device) - self.bounds = torch.zeros(batch_size, dtype=torch.int32, device=device) - self.vstd_max = 1.0 - else: - self.values = torch.zeros(num_rows, bptt_horizon, device=device) - - self.lstm_h = self.lstm_c = None - if isinstance(policy, torch.nn.LSTM): - assert lstm_total_agents > 0 - shape = (policy.num_layers, lstm_total_agents, policy.hidden_size) - self.lstm_h = torch.zeros(shape).to(device) - self.lstm_c = torch.zeros(shape).to(device) - - minibatch_size = min(minibatch_size, max_minibatch_size) - num_minibatches = batch_size / minibatch_size - self.num_minibatches = int(num_minibatches) - if self.num_minibatches != num_minibatches: - raise ValueError('batch_size must be divisible by minibatch_size') - - minibatch_rows = minibatch_size / bptt_horizon - self.minibatch_rows = int(minibatch_rows) - if self.minibatch_rows != minibatch_rows: - raise ValueError('minibatch_size must be divisible by bptt_horizon') - - self.batch_size = batch_size - self.bptt_horizon = bptt_horizon - self.p3o_horizon = p3o_horizon - self.minibatch_size = minibatch_size - self.device = device - self.ptr = 0 - self.step = 0 - - @property - def full(self): - return self.free_idx >= self.on_policy_rows - - def store(self, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, env_id, mask): - # Mask learner and Ensure indices do not exceed batch size - ptr = self.ptr - indices = np.where(mask)[0] - num_indices = indices.size - end = ptr + num_indices - dst = slice(ptr, end) - - # Zero-copy indexing for contiguous env_id - ''' - if num_indices == mask.size and isinstance(env_id, slice): - gpu_inds = cpu_inds = slice(0, min(self.batch_size - ptr, num_indices)) - else: - cpu_inds = indices[:self.batch_size - ptr] - gpu_inds = torch.as_tensor(cpu_inds).to(self.obs.device, non_blocking=True) - ''' - - batch_rows = self.ep_indices[env_id] - l = self.ep_lengths[env_id] - - if self.obs.device.type == 'cuda': - self.obs[batch_rows, l] = gpu_obs - else: - self.obs[batch_rows, l] = cpu_obs - - if isinstance(env_id, slice): - self.stored_indices[batch_rows] = torch.arange(env_id.start, env_id.stop, device=self.device).int() - else: - self.stored_indices[batch_rows] = env_id - - - if self.use_diayn: - self.diayn_batch[dst] = state.diayn_z_idxs[gpu_inds] - - if self.use_p3o: - self.values_mean[dst] = value.mean[gpu_inds] - self.values_std[dst] = value.std[gpu_inds] - else: - self.values[batch_rows, l] = value.flatten() - - self.actions[batch_rows, l] = action - self.logprobs[batch_rows, l] = logprob - self.rewards[batch_rows, l] = reward.to(self.rewards.device) # ??? - self.dones[batch_rows, l] = done.float().to(self.dones.device) # ??? - - l += 1 - self.ep_lengths[env_id] = l - full = l >= self.bptt_horizon - num_full = full.sum() - if num_full > 0: - if isinstance(env_id, slice): - env_id = torch.arange(env_id.start, env_id.stop, device=self.device).int() - - full_ids = env_id[full] - self.ep_indices[full_ids] = self.free_idx + torch.arange(num_full, device=self.device).int() - self.ep_lengths[full_ids] = 0 - self.free_idx += num_full - - self.step += 1 - return action.cpu().numpy() - - def sample(self, advantages, n): - idx = torch.multinomial(advantages.abs().sum(axis=1), n) - advantages=advantages[idx] - values=self.values[idx] - return pufferlib.namespace( - actions=self.actions[idx], - logprobs=self.logprobs[idx], - rewards=self.rewards[idx], - dones=self.dones[idx], - obs=self.obs[idx], - advantages=advantages, - values=values, - returns=advantages + values, - ) - - - def flatten_batch(self, advantages_np, reward_block=None, mask_block=None): - advantages = torch.as_tensor(advantages_np).to(self.device, non_blocking=True) - self.b_advantages = advantages.reshape( - self.num_minibatches, self.minibatch_rows, self.bptt_horizon) - self.b_actions = self.actions.to(self.device, non_blocking=True).reshape( - self.num_minibatches, self.minibatch_rows, self.bptt_horizon, -1) - self.b_logprobs = self.logprobs.to(self.device, non_blocking=True).reshape( - self.num_minibatches, self.minibatch_rows, self.bptt_horizon) - self.b_dones = self.dones.to(self.device, non_blocking=True).reshape( - self.num_minibatches, self.minibatch_rows, self.bptt_horizon) - self.b_obs = self.obs.to(self.device, non_blocking=True).reshape( - self.num_minibatches, self.minibatch_rows, self.bptt_horizon, *self.obs_shape) - - if self.use_p3o: - self.reward_block = torch.as_tensor(reward_block).to(self.device) - self.b_reward_block = self.reward_block.reshape( - self.minibatch_rows, self.num_minibatches, self.bptt_horizon, self.p3o_horizon - ).transpose(0, 1).reshape(self.num_minibatches, self.minibatch_size, self.p3o_horizon) - - b_mask_block = torch.as_tensor(mask_block).to(self.device) - self.b_mask_block = b_mask_block.reshape( - self.minibatch_rows, self.num_minibatches, self.bptt_horizon, self.p3o_horizon - ).transpose(0, 1).reshape(self.num_minibatches, self.minibatch_size, self.p3o_horizon) - - self.b_values_mean = self.values_mean.to(self.device, non_blocking=True)[b_flat] - self.b_values_std = self.values_std.to(self.device, non_blocking=True)[b_flat] - self.b_returns = self.buf.to(self.device, non_blocking=True).reshape( - self.minibatch_rows, self.num_minibatches, self.bptt_horizon, self.p3o_horizon - ).transpose(0, 1).reshape(self.num_minibatches, self.minibatch_size, self.p3o_horizon) - else: - self.b_values = self.values.to(self.device, non_blocking=True).reshape( - self.num_minibatches, self.minibatch_rows, self.bptt_horizon) - self.b_returns = self.b_advantages + self.b_values - if self.use_diayn: - self.b_diayn_z_idxs = self.diayn_batch.to(self.device, non_blocking=True)[b_flat] - self.b_diayn_z = self.diayn_archive[self.b_diayn_z_idxs] - class Utilization(Thread): def __init__(self, delay=1, maxlen=20): super().__init__() diff --git a/demo.py b/demo.py index 462f585c1..2adf463f2 100644 --- a/demo.py +++ b/demo.py @@ -178,8 +178,8 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10 while len(data.stats[target_metric]) < min_eval_points: stats, _ = clean_pufferl.evaluate(data) # TODO: Beter place for this - data.experience.free_idx = 0 - data.experience.ep_lengths.zero_() + data.free_idx = 0 + data.ep_lengths.zero_() steps_evaluated += batch_size clean_pufferl.mean_and_log(data) diff --git a/pufferlib.cu b/pufferlib.cu index 60d9bce22..f56ef82ce 100644 --- a/pufferlib.cu +++ b/pufferlib.cu @@ -3,7 +3,7 @@ // Pybind11 module definition PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("advantage_kernel", [](torch::Tensor reward_block, + m.def("compute_p3o", [](torch::Tensor reward_block, torch::Tensor reward_mask, torch::Tensor values_mean, torch::Tensor values_std, @@ -20,7 +20,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { int threads_per_block = 256; int blocks = (num_steps + threads_per_block - 1) / threads_per_block; - advantage_kernel<<>>( + p3o_kernel<<>>( reward_block.data_ptr(), reward_mask.data_ptr(), values_mean.data_ptr(), @@ -41,5 +41,34 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { if (err != cudaSuccess) { throw std::runtime_error(cudaGetErrorString(err)); } - }, "Compute advantages with CUDA"); + }, "Compute p3o advantages with CUDA"); + + m.def("compute_gae", [](torch::Tensor values, + torch::Tensor rewards, + torch::Tensor dones, + torch::Tensor advantages, + float gamma, + float gae_lambda, + int num_steps, + int horizon) { + // Launch the kernel + int threads_per_block = 256; + int blocks = (num_steps + threads_per_block - 1) / threads_per_block; + + gae_kernel<<>>( + values.data_ptr(), + rewards.data_ptr(), + dones.data_ptr(), + advantages.data_ptr(), + gamma, + gae_lambda, + num_steps, + horizon + ); + // Check for CUDA errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error(cudaGetErrorString(err)); + } + }, "Compute GAE with CUDA"); } diff --git a/pufferlib/utils.py b/pufferlib/utils.py index ad50c3054..addb7e6b7 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -1,6 +1,7 @@ from pdb import set_trace as T from collections import OrderedDict +from contextlib import nullcontext import numpy as np @@ -244,7 +245,7 @@ def format_bytes(size): return f'{size} B' class Profiler: - def __init__(self, elapsed=True, calls=True, memory=False, pytorch_memory=False): + def __init__(self, elapsed=True, calls=True, memory=False, pytorch_memory=False, sync_cuda=True, amp_context=nullcontext()): self.elapsed = 0 if elapsed else None self.calls = 0 if calls else None self.memory = None @@ -255,14 +256,17 @@ def __init__(self, elapsed=True, calls=True, memory=False, pytorch_memory=False) self.track_calls = calls self.track_memory = memory self.track_pytorch_memory = pytorch_memory + self.sync_cuda = sync_cuda if memory: self.process = psutil.Process() - if pytorch_memory: + if pytorch_memory or sync_cuda: import torch self.torch = torch + self.amp_context = amp_context + @property def serial(self): return { @@ -280,6 +284,9 @@ def delta(self): return ret def __enter__(self): + if self.sync_cuda: + self.torch.cuda.synchronize() + self.amp_context.__enter__() if self.track_elapsed: self.start_time = time.perf_counter() if self.track_memory: @@ -300,6 +307,9 @@ def __exit__(self, *args): if self.track_pytorch_memory: self.end_torch_mem = self.torch.cuda.memory_allocated() self.pytorch_memory = self.end_torch_mem - self.start_torch_mem + self.amp_context.__exit__(None, None, None) + if self.sync_cuda: + self.torch.cuda.synchronize() def __repr__(self): parts = [] From fd0c47dc9937d0b915e170875a6c1fb8e39c221a Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 8 Apr 2025 23:12:58 +0000 Subject: [PATCH 04/26] Running, decent perf --- clean_pufferl.py | 159 ++++++++++++++++------------------- config/default.ini | 2 +- pufferlib/models.py | 1 + pufferlib/ocean/grid/grid.py | 2 +- 4 files changed, 77 insertions(+), 87 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index c7c10c0ce..2c5c2245f 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -139,7 +139,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): total_agents = vecenv.num_agents on_policy_rows = config.batch_size // config.bptt_horizon - off_policy_rows = config.replay_factor*config.batch_size // config.bptt_horizon + off_policy_rows = int(config.replay_factor*config.batch_size // config.bptt_horizon) experience_rows = on_policy_rows + off_policy_rows pin = config.device == 'cuda' and config.cpu_offload @@ -154,6 +154,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): dones=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), truncateds=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), ) + ep_uses = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) stored_indices = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) ep_lengths = torch.zeros(total_agents, device=config.device, dtype=torch.int32) ep_indices = torch.arange(total_agents, device=config.device, dtype=torch.int32) @@ -293,6 +294,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): minibatch_rows=minibatch_rows, num_minibatches=num_minibatches, stored_indices=stored_indices, + ep_uses=ep_uses, ep_lengths=ep_lengths, ep_indices=ep_indices, free_idx=free_idx, @@ -446,100 +448,77 @@ def evaluate(data): @pufferlib.utils.profile def train(data): config, profile, experience = data.config, data.profile, data.experience - data.losses = make_losses() - losses = data.losses - - with profile.train_copy: - #idxs = experience.sort_training_data() - #dones = experience.dones[idxs] - #rewards = experience.rewards[idxs] - dones = experience.dones - rewards = experience.rewards + losses = make_losses() + data.losses = losses - # TODO: Beter place for this + # TODO: Better place for this data.free_idx = 0 data.ep_lengths.zero_() - - with profile.train_misc: - if config.use_p3o: - reward_block = experience.reward_block - mask_block = experience.mask_block - values_mean = experience.values_mean[idxs] - values_std = experience.values_std[idxs] - advantages = experience.advantages - - # Note: This function gets messed up by computing across - # episode bounds. Because we store experience in a flat buffer, - # bounds can be crossed even after handling dones. This prevent - # our method from scaling to longer horizons. TODO: Redo the way - # we store experience to avoid this issue - vstd_min = values_std.min().item() - vstd_max = values_std.max().item() - - mask_block.zero_() - experience.buf.zero_() - reward_block.zero_() - r_mean = rewards.mean().item() - r_std = rewards.std().item() - advantages.zero_() - experience.bounds.zero_() - - ''' - if data.epoch == 0: - values_std[:] = r_std - with torch.no_grad(): - data.policy.policy.value_logstd[:] = np.log(r_std) - ''' - - # TODO: Rename vstd to r_std - advantages = compute_advantages(reward_block, mask_block, values_mean, values_std, - experience.buf, dones, rewards, advantages, experience.bounds, - r_std, data.puf, config.p3o_horizon) - - horizon = torch.where(values_std[0] > 0.95*r_std)[0] - horizon = horizon[0].item()+1 if len(horizon) else 1 - if horizon < 16: - horizon = 16 - - advantages = advantages.cpu().numpy() - torch.cuda.synchronize() - - experience.flatten_batch(advantages, reward_block, mask_block) - torch.cuda.synchronize() + data.ep_uses.zero_() # Optimizing the policy and value network total_minibatches = data.num_minibatches * config.update_epochs - mean_pg_loss, mean_v_loss, mean_entropy_loss = 0, 0, 0 - mean_old_kl, mean_kl, mean_clipfrac = 0, 0, 0 cross_entropy = torch.nn.CrossEntropyLoss() accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) - for epoch in range(config.update_epochs): - advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) - n_samples = config.minibatch_size // config.bptt_horizon - batch = sample(data, advantages, n_samples) - + for mb in range(total_minibatches): with profile.train_misc: + if config.use_p3o: + # Note: This function gets messed up by computing across + # episode bounds. Because we store experience in a flat buffer, + # bounds can be crossed even after handling dones. This prevent + # our method from scaling to longer horizons. TODO: Redo the way + # we store experience to avoid this issue + vstd_min = experience.values_std.min().item() + vstd_max = experience.values_std.max().item() + + data.mask_block.zero_() + data.buf.zero_() + data.reward_block.zero_() + data.bounds.zero_() + + r_mean = experience.rewards.mean().item() + r_std = experience.rewards.std().item() + + # TODO: Rename vstd to r_std + advantages = compute_advantages( + experience.reward_block, experience.mask_block, + experience.values_mean, experience.values_std, + experience.buf, experience.dones, experience.rewards, + experience.bounds, r_std, data.puf, config.p3o_horizon + ) + + horizon = torch.where(experience.values_std[0] > 0.95*r_std)[0] + horizon = horizon[0].item()+1 if len(horizon) else 1 + if horizon < 16: + horizon = 16 + + advantages = advantages.cpu().numpy() + torch.cuda.synchronize() + else: + advantages = compute_gae(experience.values, experience.rewards, + experience.dones, config.gamma, config.gae_lambda) + + n_samples = config.minibatch_size // config.bptt_horizon + batch = sample(data, advantages, n_samples) + state = pufferlib.namespace( action=batch.actions, lstm_h=None, lstm_c=None, ) + if config.use_diayn: z_idxs = batch.diayn_z_idxs - if config.use_p3o: - val_mean = batch.values_mean - val_std = batch.values_std - rew_block = batch.reward_block - mask_block = batch.mask_block - else: - val = batch.values.flatten() - with profile.train_forward: if not isinstance(data.policy, torch.nn.LSTM): batch.obs = batch.obs.reshape(-1, *data.vecenv.single_observation_space.shape) logits, newvalue = data.policy.forward_train(batch.obs, state) + # TODO: Currently only returning traj shaped value as a hack + with torch.no_grad(): + experience.values[batch.idx] = newvalue + lstm_h = state.lstm_h lstm_c = state.lstm_c if lstm_h is not None: @@ -582,7 +561,7 @@ def train(data): newvalue_var = torch.square(newvalue_std) criterion = torch.nn.GaussianNLLLoss(reduction='none') #v_loss = criterion(newvalue_mean[:, :32], rew_block[:, :32], newvalue_var[:, :32]) - v_loss = criterion(newvalue_mean, rew_block, newvalue_var) + v_loss = criterion(newvalue_mean, batch.reward_block, newvalue_var) v_loss = v_loss[:, :(horizon+3)] mask_block = mask_block[:, :(horizon+3)] #v_loss[:, horizon:] = 0 @@ -604,6 +583,7 @@ def train(data): newvalue = newvalue.flatten() ret = batch.returns.flatten() v_loss_unclipped = (newvalue - ret) ** 2 + val = batch.values.flatten() v_clipped = val + torch.clamp( newvalue - val, -config.vf_clip_coef, @@ -640,7 +620,7 @@ def train(data): grad_var = grads.var(0).mean() * config.minibatch_size data.msg = f'Gradient variance: {grad_var.item():.3f}' - if (epoch + 1) % accumulate_minibatches == 0: + if (mb + 1) % accumulate_minibatches == 0: torch.nn.utils.clip_grad_norm_(data.policy.parameters(), config.max_grad_norm) if data.scaler is None: @@ -651,13 +631,6 @@ def train(data): data.optimizer.zero_grad() - # Reprioritize experience - advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) - - n_samples = data.off_policy_rows - exp = sample(data, advantages, n_samples) - for k, v in experience.items(): - v[data.on_policy_rows:] = exp[k] with profile.train_misc: losses.policy_loss += pg_loss.item() / total_minibatches @@ -675,6 +648,18 @@ def train(data): if approx_kl > config.target_kl: break + # Reprioritize experience + advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) + + ep_uses = data.ep_uses + data.max_uses = ep_uses.max().item() + data.mean_uses = ep_uses.float().mean().item() + + n_samples = data.off_policy_rows + exp = sample(data, advantages, n_samples) + for k, v in experience.items(): + v[data.on_policy_rows:] = exp[k] + with profile.train_misc: if config.anneal_lr: data.scheduler.step() @@ -779,8 +764,12 @@ def store(data, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, e def sample(data, advantages, n, reward_block=None, mask_block=None): exp = data.experience - idx = torch.multinomial(advantages.abs().sum(axis=1), n) + #idx = torch.multinomial(advantages.abs().sum(axis=1), n) + idx = torch.randint(0, advantages.shape[0], (n,), device=data.device) + #_, idx = torch.topk(advantages.abs().sum(axis=1), n) + data.ep_uses[idx] += 1 output = {k: v[idx] for k, v in exp.items()} + output['idx'] = idx if data.use_p3o: output['reward_block'] = reward_block[idx] @@ -798,8 +787,6 @@ def sample(data, advantages, n, reward_block=None, mask_block=None): return pufferlib.namespace(**output) - - def compute_pg_loss(log_probs, newlogprob, adv, clip_coef): logratio = newlogprob - log_probs.reshape(-1) ratio = logratio.exp() @@ -860,6 +847,8 @@ def mean_and_log(data): 'agent_steps': agent_steps, 'epoch': epoch, 'learning_rate': learning_rate, + 'max_uses': data.max_uses, + 'mean_uses': data.mean_uses, **{f'environment/{k}': v for k, v in environment.items()}, **{f'losses/{k}': v for k, v in losses.items()}, **{f'performance/{k}': v for k, v in performance.items()}, diff --git a/config/default.ini b/config/default.ini index ecdd0291f..1290ddf32 100644 --- a/config/default.ini +++ b/config/default.ini @@ -48,7 +48,7 @@ data_dir = experiments checkpoint_interval = 200 batch_size = 524288 minibatch_size = 8192 -replay_factor = 1 +replay_factor = 0.125 # Accumulate gradients above this size max_minibatch_size = 16384 bptt_horizon = 64 diff --git a/pufferlib/models.py b/pufferlib/models.py index e95bf19a4..627cfdefd 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -208,6 +208,7 @@ def forward_train(self, observations, state): hidden = hidden.reshape(B*TT, self.hidden_size) logits, values = self.policy.decode_actions(hidden) + values = values.reshape(B, TT) state.hidden = hidden state.lstm_h = lstm_h state.lstm_c = lstm_c diff --git a/pufferlib/ocean/grid/grid.py b/pufferlib/ocean/grid/grid.py index 528aab6fc..7fc1560ee 100644 --- a/pufferlib/ocean/grid/grid.py +++ b/pufferlib/ocean/grid/grid.py @@ -9,7 +9,7 @@ class Grid(pufferlib.PufferEnv): def __init__(self, render_mode='raylib', vision_range=5, num_envs=4096, num_maps=1000, map_size=-1, max_map_size=9, - report_interval=128, buf=None): + report_interval=128, buf=None, seed=0): self.obs_size = 2*vision_range + 1 self.single_observation_space = gymnasium.spaces.Box(low=0, high=255, shape=(self.obs_size*self.obs_size,), dtype=np.uint8) From 8b8d0130c1b536a420da97f60096ba796930fcd3 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Wed, 9 Apr 2025 00:59:02 +0000 Subject: [PATCH 05/26] cleanups --- clean_pufferl.py | 246 ++++++++++++----------------------------- demo.py | 3 - pufferlib/models.py | 4 +- pufferlib/namespace.py | 5 + 4 files changed, 78 insertions(+), 180 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 2c5c2245f..0383d329c 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -125,23 +125,36 @@ def compute_advantages( return advantages def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): - seed_everything(config.seed, config.torch_deterministic) - losses = make_losses() + random.seed(config.seed) + np.random.seed(config.seed) + torch.backends.cudnn.deterministic = config.torch_deterministic + if config.seed is not None: + torch.manual_seed(config.seed) + + losses = pufferlib.namespace( + policy_loss=0, + value_loss=0, + entropy=0, + old_approx_kl=0, + approx_kl=0, + clipfrac=0, + explained_variance=0, + diayn_loss=0, + grad_var=0, + ) utilization = Utilization() msg = f'Model Size: {abbreviate(count_params(policy))} parameters' vecenv.async_reset(config.seed) + total_agents = vecenv.num_agents obs_shape = vecenv.single_observation_space.shape - obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[vecenv.single_observation_space.dtype] atn_shape = vecenv.single_action_space.shape + obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[vecenv.single_observation_space.dtype] atn_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[vecenv.single_action_space.dtype] - total_agents = vecenv.num_agents - on_policy_rows = config.batch_size // config.bptt_horizon off_policy_rows = int(config.replay_factor*config.batch_size // config.bptt_horizon) experience_rows = on_policy_rows + off_policy_rows - pin = config.device == 'cuda' and config.cpu_offload obs_device = config.device if not pin else 'cpu' experience = pufferlib.namespace( @@ -197,23 +210,10 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): lstm_c = torch.zeros(shape).to(config.device) minibatch_size = min(config.minibatch_size, config.max_minibatch_size) - num_minibatches = config.batch_size / minibatch_size - if num_minibatches != int(num_minibatches): - raise ValueError('batch_size must be divisible by minibatch_size') - else: - num_minibatches = int(num_minibatches) - - minibatch_rows = minibatch_size / config.bptt_horizon - if minibatch_rows != int(minibatch_rows): - raise ValueError('minibatch_size must be divisible by bptt_horizon') - else: - minibatch_rows = int(minibatch_rows) - uncompiled_policy = policy if config.compile: policy = torch.compile(policy, mode=config.compile_mode, fullgraph=config.compile_fullgraph) - assert config.optimizer in ('adam', 'muon', 'kron') if config.optimizer == 'adam': optimizer = torch.optim.Adam( policy.parameters(), @@ -241,6 +241,8 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): precond_lr=config.precond_lr, beta=config.adam_beta1, ) + else: + raise ValueError(f'Unknown optimizer: {config.optimizer}') epochs = config.total_timesteps // config.batch_size assert config.scheduler in ('linear', 'cosine') @@ -249,11 +251,11 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): elif config.scheduler == 'cosine': scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) - scaler = None if config.precision == 'float32' else torch.amp.GradScaler() - amp_context = nullcontext() + scaler = None if config.precision != 'float32': amp_context = torch.amp.autocast(device_type='cuda', dtype=getattr(torch, config.precision)) + scaler = torch.amp.GradScaler() profile = Profile(amp_context) print_dashboard(config.env, utilization, 0, 0, profile, losses, {}, msg, clear=True) @@ -291,8 +293,6 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): step=0, lstm_h=lstm_h, lstm_c=lstm_c, - minibatch_rows=minibatch_rows, - num_minibatches=num_minibatches, stored_indices=stored_indices, ep_uses=ep_uses, ep_lengths=ep_lengths, @@ -302,6 +302,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): off_policy_rows=off_policy_rows, experience_rows=experience_rows, device=config.device, + minibatch_size=minibatch_size, ) @pufferlib.utils.profile @@ -315,7 +316,7 @@ def evaluate(data): lstm_h = data.lstm_h lstm_c = data.lstm_c - while not full(data): + while data.free_idx < data.on_policy_rows: with profile.env: o, r, d, t, info, env_id, mask = data.vecenv.recv() @@ -341,7 +342,6 @@ def evaluate(data): done_idxs = env_id[done_mask] experience.e3b_inv[done_idxs] = experience.e3b_orig[done_idxs] - o = torch.as_tensor(o) o_device = o.to(config.device, non_blocking=True) r = torch.as_tensor(r).to(config.device, non_blocking=True) @@ -411,7 +411,7 @@ def evaluate(data): lstm_c[:, gpu_env_id] = state.lstm_c o = o if config.cpu_offload else o_device - actions = store(data, state, o, o_device, value, action, logprob, r, d, gpu_env_id, mask) + actions = store(data, state, o, value, action, logprob, r, d, gpu_env_id, mask) with profile.eval_misc: for i in info: @@ -440,26 +440,23 @@ def evaluate(data): else: data.stats[k] += v - # TODO: Better way to enable multiple collects - data.ptr = 0 - data.step = 0 + data.free_idx = 0 + data.ep_lengths.zero_() + data.ep_uses.zero_() return data.stats, infos @pufferlib.utils.profile def train(data): config, profile, experience = data.config, data.profile, data.experience - losses = make_losses() - data.losses = losses - # TODO: Better place for this - data.free_idx = 0 - data.ep_lengths.zero_() - data.ep_uses.zero_() + losses = data.losses + for k in data.losses: + losses[k] = 0 - # Optimizing the policy and value network - total_minibatches = data.num_minibatches * config.update_epochs cross_entropy = torch.nn.CrossEntropyLoss() + total_minibatches = int(config.update_epochs*config.batch_size/data.minibatch_size) accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) + n_samples = config.minibatch_size // config.bptt_horizon for mb in range(total_minibatches): with profile.train_misc: if config.use_p3o: @@ -498,9 +495,7 @@ def train(data): advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) - n_samples = config.minibatch_size // config.bptt_horizon batch = sample(data, advantages, n_samples) - state = pufferlib.namespace( action=batch.actions, lstm_h=None, @@ -514,24 +509,17 @@ def train(data): if not isinstance(data.policy, torch.nn.LSTM): batch.obs = batch.obs.reshape(-1, *data.vecenv.single_observation_space.shape) - logits, newvalue = data.policy.forward_train(batch.obs, state) # TODO: Currently only returning traj shaped value as a hack + logits, newvalue = data.policy.forward_train(batch.obs, state) + with torch.no_grad(): experience.values[batch.idx] = newvalue lstm_h = state.lstm_h lstm_c = state.lstm_c - if lstm_h is not None: - lstm_h = lstm_h.detach() - if lstm_c is not None: - lstm_c = lstm_c.detach() - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=batch.actions, is_continuous=data.policy.is_continuous) - if config.device == 'cuda': - torch.cuda.synchronize() - with profile.train_misc: logratio = newlogprob - batch.logprobs.reshape(-1) ratio = logratio.exp() @@ -560,25 +548,10 @@ def train(data): newvalue_std = newvalue.std.view(-1, config.p3o_horizon) newvalue_var = torch.square(newvalue_std) criterion = torch.nn.GaussianNLLLoss(reduction='none') - #v_loss = criterion(newvalue_mean[:, :32], rew_block[:, :32], newvalue_var[:, :32]) v_loss = criterion(newvalue_mean, batch.reward_block, newvalue_var) v_loss = v_loss[:, :(horizon+3)] mask_block = mask_block[:, :(horizon+3)] - #v_loss[:, horizon:] = 0 - #v_loss = (v_loss * mask_block).sum(axis=1) - #v_loss = (v_loss - v_loss.mean().item()) / (v_loss.std().item() + 1e-8) - #v_loss = v_loss.mean() v_loss = v_loss[mask_block.bool()].mean() - #TODO: Count mask and sum - # There is going to have to be some sort of norm here. - # Right now, learning works at different horizons, but you need - # to retune hyperparameters. Ideally, horizon should be a stable - # param that zero-shots the same hypers - - # Faster than masking - #v_loss = (v_loss*mask_block[:, :32]).sum() / mask_block[:, :32].sum() - #v_loss = (v_loss*mask_block).sum() / mask_block.sum() - #v_loss = v_loss[mask_block.bool()].mean() elif config.clip_vloss: newvalue = newvalue.flatten() ret = batch.returns.flatten() @@ -601,20 +574,22 @@ def train(data): with profile.custom: if config.use_diayn: - diayn_discriminator = data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator + diayn_discriminator = (data.policy.diayn_discriminator if + hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator) q = diayn_discriminator(state.hidden).squeeze() diayn_loss = cross_entropy(q, z_idxs) loss += config.diayn_loss_coef*diayn_loss with profile.learn: - if data.scaler is None: - loss.backward() - else: - data.scaler.scale(loss).backward() + if data.scaler is not None: + loss = data.scaler.scale(loss) + + loss.backward() if data.scaler is not None: data.scaler.unscale_(data.optimizer) + # TODO: Delete? with torch.no_grad(): grads = torch.cat([p.grad.flatten() for p in data.policy.parameters()]) grad_var = grads.var(0).mean() * config.minibatch_size @@ -623,6 +598,7 @@ def train(data): if (mb + 1) % accumulate_minibatches == 0: torch.nn.utils.clip_grad_norm_(data.policy.parameters(), config.max_grad_norm) + # TODO: Can remove scaler if only using bf16 if data.scaler is None: data.optimizer.step() else: @@ -650,11 +626,9 @@ def train(data): # Reprioritize experience advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) - ep_uses = data.ep_uses data.max_uses = ep_uses.max().item() data.mean_uses = ep_uses.float().mean().item() - n_samples = data.off_policy_rows exp = sample(data, advantages, n_samples) for k, v in experience.items(): @@ -679,8 +653,6 @@ def train(data): data.epoch += 1 done_training = data.global_step >= config.total_timesteps - # TODO: beter way to get episode return update without clogging dashboard - # TODO: make this appear faster logs = None if done_training or profile.update(data): logs = mean_and_log(data) @@ -688,72 +660,44 @@ def train(data): profile, data.losses, data.stats, data.msg) data.stats = defaultdict(list) - #print('MEAN', experience.b_values_mean.mean(0).mean(0)) - #print('STD', torch.exp(experience.b_values_logstd).mean(0).mean(0)) - if data.epoch % config.checkpoint_interval == 0 or done_training: save_checkpoint(data) data.msg = f'Checkpoint saved at update {data.epoch}' return logs -def full(data): - return data.free_idx >= data.on_policy_rows - -def store(data, state, cpu_obs, gpu_obs, value, action, logprob, reward, done, env_id, mask): - # Mask learner and Ensure indices do not exceed batch size +def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): exp = data.experience - ptr = data.ptr - indices = np.where(mask)[0] - num_indices = indices.size - end = ptr + num_indices - dst = slice(ptr, end) - - # Zero-copy indexing for contiguous env_id - ''' - if num_indices == mask.size and isinstance(env_id, slice): - gpu_inds = cpu_inds = slice(0, min(self.batch_size - ptr, num_indices)) - else: - cpu_inds = indices[:self.batch_size - ptr] - gpu_inds = torch.as_tensor(cpu_inds).to(self.obs.device, non_blocking=True) - ''' - batch_rows = data.ep_indices[env_id] l = data.ep_lengths[env_id] - if exp.obs.device.type == 'cuda': - exp.obs[batch_rows, l] = gpu_obs - else: - exp.obs[batch_rows, l] = cpu_obs - if isinstance(env_id, slice): - data.stored_indices[batch_rows] = torch.arange(env_id.start, env_id.stop, device=data.device).int() - else: - data.stored_indices[batch_rows] = env_id + env_id = torch.arange(env_id.start, env_id.stop, device=data.device).int() + data.stored_indices[batch_rows] = env_id - if data.use_diayn: - data.diayn_batch[dst] = state.diayn_z_idxs[gpu_inds] + exp.obs[batch_rows, l] = obs + exp.actions[batch_rows, l] = action + exp.logprobs[batch_rows, l] = logprob + exp.rewards[batch_rows, l] = reward + exp.dones[batch_rows, l] = done.float() if data.use_p3o: - exp.values_mean[dst] = value.mean[gpu_inds] - exp.values_std[dst] = value.std[gpu_inds] + exp.values_mean[batch_rows, l] = value.mean + exp.values_std[batch_rows, l] = value.std else: exp.values[batch_rows, l] = value.flatten() - exp.actions[batch_rows, l] = action - exp.logprobs[batch_rows, l] = logprob - exp.rewards[batch_rows, l] = reward.to(exp.rewards.device) # ??? - exp.dones[batch_rows, l] = done.float().to(exp.dones.device) # ??? + if data.use_diayn: + data.diayn_batch[batch_rows] = state.diayn_z_idxs + # TODO: Handle masks!! + #indices = np.where(mask)[0] + #data.ep_lengths[env_id[mask]] += 1 l += 1 - data.ep_lengths[env_id] = l full = l >= data.config.bptt_horizon num_full = full.sum() if num_full > 0: - if isinstance(env_id, slice): - env_id = torch.arange(env_id.start, env_id.stop, device=data.device).int() - full_ids = env_id[full] data.ep_indices[full_ids] = data.free_idx + torch.arange(num_full, device=data.device).int() data.ep_lengths[full_ids] = 0 @@ -787,27 +731,6 @@ def sample(data, advantages, n, reward_block=None, mask_block=None): return pufferlib.namespace(**output) -def compute_pg_loss(log_probs, newlogprob, adv, clip_coef): - logratio = newlogprob - log_probs.reshape(-1) - ratio = logratio.exp() - - with torch.no_grad(): - # calculate approx_kl http://joschu.net/blog/kl-approx.html - old_approx_kl = (-logratio).mean() - approx_kl = ((ratio - 1) - logratio).mean() - clipfrac = ((ratio - 1.0).abs() > clip_coef).float().mean() - - adv = adv.view(-1) - adv = (adv - adv.mean()) / (adv.std() + 1e-8) - - # Policy loss - pg_loss1 = -adv * ratio - pg_loss2 = -adv * torch.clamp( - ratio, 1 - clip_coef, 1 + clip_coef - ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() - return pg_loss, approx_kl, old_approx_kl, clipfrac - def dist_sum(value, device): if not dist.is_initialized(): return value @@ -834,24 +757,17 @@ def mean_and_log(data): device = data.config.device - sps = dist_sum(data.profile.SPS, device) agent_steps = int(dist_sum(data.global_step, device)) - epoch = int(dist_sum(data.epoch, device)) - learning_rate = data.optimizer.param_groups[0]["lr"] - environment = {k: dist_mean(v, device) for k, v in data.stats.items()} - losses = {k: dist_mean(v, device) for k, v in data.losses.items()} - performance = {k: dist_sum(v, device) for k, v in data.profile} - logs = { - 'SPS': sps, + 'SPS': dist_sum(data.profile.SPS, device), 'agent_steps': agent_steps, - 'epoch': epoch, - 'learning_rate': learning_rate, + 'epoch': int(dist_sum(data.epoch, device)), + 'learning_rate': data.optimizer.param_groups[0]["lr"], 'max_uses': data.max_uses, 'mean_uses': data.mean_uses, - **{f'environment/{k}': v for k, v in environment.items()}, - **{f'losses/{k}': v for k, v in losses.items()}, - **{f'performance/{k}': v for k, v in performance.items()}, + **{f'environment/{k}': dist_mean(v, device) for k, v in data.stats.items()}, + **{f'losses/{k}': dist_mean(v, device) for k, v in data.losses.items()}, + **{f'performance/{k}': dist_sum(v, device) for k, v in data.profile}, } if dist.is_initialized() and dist.get_rank() != 0: @@ -957,19 +873,6 @@ def update(self, data, interval_s=1): self.custom_time = self.custom.elapsed return True -def make_losses(): - return pufferlib.namespace( - policy_loss=0, - value_loss=0, - entropy=0, - old_approx_kl=0, - approx_kl=0, - clipfrac=0, - explained_variance=0, - diayn_loss=0, - grad_var=0, - ) - class Utilization(Thread): def __init__(self, delay=1, maxlen=20): super().__init__() @@ -1010,7 +913,7 @@ def save_checkpoint(data): if os.path.exists(model_path): return model_path - torch.save(data.uncompiled_policy, model_path) + torch.save(data.uncompiled_policy.state_dict(), model_path) state = { 'optimizer_state_dict': data.optimizer.state_dict(), @@ -1035,7 +938,8 @@ def try_load_checkpoint(data): trainer_path = os.path.join(path, 'trainer_state.pt') resume_state = torch.load(trainer_path, weights_only=False) model_path = os.path.join(path, resume_state['model_name']) - data.policy.uncompiled.load_state_dict(model_path, map_location=config.device) + data.policy.uncompiled.load_state_dict( + torch.load(model_path, weights_only=True), map_location=config.device) data.optimizer.load_state_dict(resume_state['optimizer_state_dict']) print(f'Loaded checkpoint {resume_state["model_name"]}') @@ -1052,10 +956,9 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k # single-agent/multi-agent API for evaluation env = pufferlib.vector.make(env_creator, env_kwargs=env_kwargs, backend=backend) - if model_path is None: - agent = agent_creator(env, policy_cls, rnn_cls, agent_kwargs).to(device) - else: - agent = torch.load(model_path, map_location=device, weights_only=False) + agent = agent_creator(env, policy_cls, rnn_cls, agent_kwargs).to(device) + if model_path is not None: + agent.load_state_dict(torch.load(model_path, map_location=device, weights_only=False)) #e3b_inv = 10*torch.eye(agent.hidden_size).repeat(env_kwargs['num_envs'], 1, 1).to(device) e3b_inv = None @@ -1115,13 +1018,6 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k import imageio os.makedirs('../docker', exist_ok=True) or imageio.mimsave('../docker/eval.gif', frames, fps=15, loop=0) -def seed_everything(seed, torch_deterministic): - random.seed(seed) - np.random.seed(seed) - if seed is not None: - torch.manual_seed(seed) - torch.backends.cudnn.deterministic = torch_deterministic - ROUND_OPEN = rich.box.Box( "╭──╮\n" "│ │\n" diff --git a/demo.py b/demo.py index 2adf463f2..a75289e3c 100644 --- a/demo.py +++ b/demo.py @@ -177,9 +177,6 @@ def train(args, make_env, policy_cls, rnn_cls, target_metric, min_eval_points=10 batch_size = args['train']['batch_size'] while len(data.stats[target_metric]) < min_eval_points: stats, _ = clean_pufferl.evaluate(data) - # TODO: Beter place for this - data.free_idx = 0 - data.ep_lengths.zero_() steps_evaluated += batch_size clean_pufferl.mean_and_log(data) diff --git a/pufferlib/models.py b/pufferlib/models.py index 627cfdefd..b7e9a873e 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -210,8 +210,8 @@ def forward_train(self, observations, state): logits, values = self.policy.decode_actions(hidden) values = values.reshape(B, TT) state.hidden = hidden - state.lstm_h = lstm_h - state.lstm_c = lstm_c + state.lstm_h = lstm_h.detach() + state.lstm_c = lstm_c.detach() return logits, values class Convolutional(nn.Module): diff --git a/pufferlib/namespace.py b/pufferlib/namespace.py index dff138f64..a6ecfe352 100644 --- a/pufferlib/namespace.py +++ b/pufferlib/namespace.py @@ -5,6 +5,9 @@ def __getitem__(self, key): return self.__dict__[key] +def __setitem__(self, key, value): + self.__dict__[key] = value + def keys(self): return self.__dict__.keys() @@ -22,6 +25,7 @@ def __len__(self): class Namespace(SimpleNamespace, Mapping): __getitem__ = __getitem__ + __setitem__ = __setitem__ __iter__ = __iter__ __len__ = __len__ keys = keys @@ -42,6 +46,7 @@ def __init__(self, **kwargs): cls.__init__ = __init__ setattr(cls, "__getitem__", __getitem__) + setattr(cls, "__setitem__", __setitem__) setattr(cls, "__iter__", __iter__) setattr(cls, "__len__", __len__) setattr(cls, "keys", keys) From 7ecb95b7b6f0be9cdf54ff7202b3b87ce2c747ab Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Wed, 9 Apr 2025 16:42:41 +0000 Subject: [PATCH 06/26] Add cuda checks --- clean_pufferl.py | 30 +++++++++++++++++++++--------- pufferlib.cu | 43 +++++++++++++++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 0383d329c..619042803 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -36,7 +36,9 @@ sources=['pufferlib.cu'], verbose=True ) +compute_gae = cuda_module.compute_gae +''' def compute_gae( values: torch.Tensor, # [num_steps, horizon] rewards: torch.Tensor, # [num_steps, horizon] @@ -70,6 +72,7 @@ def compute_gae( torch.cuda.synchronize() return advantages +''' def compute_advantages( @@ -103,6 +106,7 @@ def compute_advantages( # Launch kernel threads_per_block = 256 + assert num_steps % threads_per_block == 0 blocks = (num_steps + threads_per_block - 1) // threads_per_block cuda_module.advantage_kernel( @@ -625,14 +629,16 @@ def train(data): break # Reprioritize experience - advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) ep_uses = data.ep_uses data.max_uses = ep_uses.max().item() data.mean_uses = ep_uses.float().mean().item() - n_samples = data.off_policy_rows - exp = sample(data, advantages, n_samples) - for k, v in experience.items(): - v[data.on_policy_rows:] = exp[k] + if config.replay_factor > 0: + advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) + n_samples = data.off_policy_rows + exp = sample(data, advantages, n_samples, method='topk') + for k, v in experience.items(): + v[data.on_policy_rows:] = exp[k] + with profile.train_misc: if config.anneal_lr: @@ -706,11 +712,17 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): data.step += 1 return action.cpu().numpy() -def sample(data, advantages, n, reward_block=None, mask_block=None): +def sample(data, advantages, n, reward_block=None, mask_block=None, method='multinomial'): exp = data.experience - #idx = torch.multinomial(advantages.abs().sum(axis=1), n) - idx = torch.randint(0, advantages.shape[0], (n,), device=data.device) - #_, idx = torch.topk(advantages.abs().sum(axis=1), n) + if method == 'topk': + _, idx = torch.topk(advantages.abs().sum(axis=1), n) + elif method == 'multinomial': + idx = torch.multinomial(advantages.abs().sum(axis=1) + 1e-6, n) + elif method == 'random': + idx = torch.randint(0, advantages.shape[0], (n,), device=data.device) + else: + raise ValueError(f'Unknown sampling method: {method}') + data.ep_uses[idx] += 1 output = {k: v[idx] for k, v in exp.items()} output['idx'] = idx diff --git a/pufferlib.cu b/pufferlib.cu index f56ef82ce..c721fe44c 100644 --- a/pufferlib.cu +++ b/pufferlib.cu @@ -43,15 +43,36 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { } }, "Compute p3o advantages with CUDA"); - m.def("compute_gae", [](torch::Tensor values, - torch::Tensor rewards, - torch::Tensor dones, - torch::Tensor advantages, - float gamma, - float gae_lambda, - int num_steps, - int horizon) { - // Launch the kernel + m.def("compute_gae", []( + torch::Tensor values, + torch::Tensor rewards, + torch::Tensor dones, + float gamma, + float gae_lambda) { + torch::Device device = values.device(); + int num_steps = values.size(0); + int horizon = values.size(1); + + // Validate input tensors + for (const torch::Tensor& t : {values, rewards, dones}) { + TORCH_CHECK(t.dim() == 2, "Tensor must be 2D"); + TORCH_CHECK(t.device() == device, "All tensors must be on same device"); + TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps"); + TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon"); + TORCH_CHECK(t.is_cuda(), "All tensors must be on GPU"); + TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32"); + if (!t.is_contiguous()) { + t.contiguous(); + } + } + + torch::Tensor advantages = torch::zeros( + {num_steps, horizon}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(device) + ); + int threads_per_block = 256; int blocks = (num_steps + threads_per_block - 1) / threads_per_block; @@ -65,10 +86,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { num_steps, horizon ); - // Check for CUDA errors + cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { throw std::runtime_error(cudaGetErrorString(err)); } + + return advantages; }, "Compute GAE with CUDA"); } From 4e64d92eee44c524795fffefb7ad64f61d57b422 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Wed, 9 Apr 2025 17:25:34 +0000 Subject: [PATCH 07/26] GPU kernel + CPU fallback for GAE --- clean_pufferl.py | 28 +++--- pufferlib.cpp | 26 +++++ pufferlib.cu | 250 ++++++++++++++++++++++++++++++----------------- shared.cpp | 47 +++++++++ 4 files changed, 244 insertions(+), 107 deletions(-) create mode 100644 pufferlib.cpp create mode 100644 shared.cpp diff --git a/clean_pufferl.py b/clean_pufferl.py index 619042803..c9638cc25 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -15,6 +15,7 @@ from rich.table import Table import torch import torch.distributed as dist +from torch.utils.cpp_extension import load import pufferlib import pufferlib.utils @@ -22,21 +23,6 @@ torch.set_float32_matmul_precision('high') -# Fast Cython advantage functions -#from c_advantage import rewards_and_masks, compute_gae -#from c_advantage import compute_gae - -import torch -from torch.utils.cpp_extension import load - - -# Compile the CUDA kernel -cuda_module = load( - name='compute_gae', - sources=['pufferlib.cu'], - verbose=True -) -compute_gae = cuda_module.compute_gae ''' def compute_gae( @@ -135,6 +121,13 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): if config.seed is not None: torch.manual_seed(config.seed) + ext = 'cu' if 'cuda' in config.device else 'cpp' + compute_gae = load( + name='compute_gae', + sources=[f'pufferlib.{ext}'], + verbose=True + ).compute_gae + losses = pufferlib.namespace( policy_loss=0, value_loss=0, @@ -307,6 +300,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): experience_rows=experience_rows, device=config.device, minibatch_size=minibatch_size, + compute_gae=compute_gae, ) @pufferlib.utils.profile @@ -496,7 +490,7 @@ def train(data): advantages = advantages.cpu().numpy() torch.cuda.synchronize() else: - advantages = compute_gae(experience.values, experience.rewards, + advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) batch = sample(data, advantages, n_samples) @@ -633,7 +627,7 @@ def train(data): data.max_uses = ep_uses.max().item() data.mean_uses = ep_uses.float().mean().item() if config.replay_factor > 0: - advantages = compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) + advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) n_samples = data.off_policy_rows exp = sample(data, advantages, n_samples, method='topk') for k, v in experience.items(): diff --git a/pufferlib.cpp b/pufferlib.cpp new file mode 100644 index 000000000..10e39fcbb --- /dev/null +++ b/pufferlib.cpp @@ -0,0 +1,26 @@ +#include "shared.cpp" + +// [num_steps, horizon] +void gae(float* values, float* rewards, float* dones, float* advantages, + float gamma, float gae_lambda, int num_steps, int horizon){ + for (int offset = 0; offset < num_steps*horizon; offset+=horizon) { + gae_row(values + offset, rewards + offset, dones + offset, + advantages + offset, gamma, gae_lambda, horizon); + } +} + +torch::Tensor compute_gae(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, float gamma, float gae_lambda) { + int num_steps = values.size(0); + int horizon = values.size(1); + torch::Tensor advantages = gae_check(values, rewards, dones, num_steps, horizon); + gae(values.data_ptr(), rewards.data_ptr(), + dones.data_ptr(), advantages.data_ptr(), + gamma, gae_lambda, num_steps, horizon + ); + return advantages; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compute_gae", &compute_gae, "Compute GAE with C"); +} diff --git a/pufferlib.cu b/pufferlib.cu index c721fe44c..917dd5eb4 100644 --- a/pufferlib.cu +++ b/pufferlib.cu @@ -1,97 +1,167 @@ -#include -#include "c_advantage.cu" +#include "shared.cpp" -// Pybind11 module definition -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("compute_p3o", [](torch::Tensor reward_block, - torch::Tensor reward_mask, - torch::Tensor values_mean, - torch::Tensor values_std, - torch::Tensor buf, - torch::Tensor dones, - torch::Tensor rewards, - torch::Tensor advantages, - torch::Tensor bounds, - int num_steps, - float vstd_max, - float puf, - int horizon) { - // Launch the kernel - int threads_per_block = 256; - int blocks = (num_steps + threads_per_block - 1) / threads_per_block; - - p3o_kernel<<>>( - reward_block.data_ptr(), - reward_mask.data_ptr(), - values_mean.data_ptr(), - values_std.data_ptr(), - buf.data_ptr(), - dones.data_ptr(), - rewards.data_ptr(), - advantages.data_ptr(), - bounds.data_ptr(), - num_steps, - vstd_max, - puf, - horizon - ); - - // Check for CUDA errors - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - throw std::runtime_error(cudaGetErrorString(err)); +__global__ void p3o_kernel( + float* reward_block, // [num_steps, horizon] + float* reward_mask, // [num_steps, horizon] + float* values_mean, // [num_steps, horizon] + float* values_std, // [num_steps, horizon] + float* buf, // [num_steps, horizon] + float* dones, // [num_steps] + float* rewards, // [num_steps] + float* advantages, // [num_steps] + int* bounds, // [num_steps] + int num_steps, + float r_std, + float puf, + int horizon +) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= num_steps) return; + + int k = 0; + for (int j = 0; j < horizon-1; j++) { + int t = i + j; + if (t >= num_steps - 1) { + break; } - }, "Compute p3o advantages with CUDA"); - - m.def("compute_gae", []( - torch::Tensor values, - torch::Tensor rewards, - torch::Tensor dones, - float gamma, - float gae_lambda) { - torch::Device device = values.device(); - int num_steps = values.size(0); - int horizon = values.size(1); - - // Validate input tensors - for (const torch::Tensor& t : {values, rewards, dones}) { - TORCH_CHECK(t.dim() == 2, "Tensor must be 2D"); - TORCH_CHECK(t.device() == device, "All tensors must be on same device"); - TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps"); - TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon"); - TORCH_CHECK(t.is_cuda(), "All tensors must be on GPU"); - TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32"); - if (!t.is_contiguous()) { - t.contiguous(); - } + if (dones[t+1]) { + k++; + break; } + k++; + } - torch::Tensor advantages = torch::zeros( - {num_steps, horizon}, - torch::TensorOptions() - .dtype(torch::kFloat32) - .device(device) - ); - - int threads_per_block = 256; - int blocks = (num_steps + threads_per_block - 1) / threads_per_block; - - gae_kernel<<>>( - values.data_ptr(), - rewards.data_ptr(), - dones.data_ptr(), - advantages.data_ptr(), - gamma, - gae_lambda, - num_steps, - horizon - ); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - throw std::runtime_error(cudaGetErrorString(err)); + float gamma_max = 0.0f; + float n = 0.0f; + for (int j = k-1; j >= 0; j--) { + int idx = i * horizon + j; + n++; + + float vstd = values_std[idx]; + if (vstd == 0.0f) { + buf[idx] = 0.0f; + continue; } - return advantages; - }, "Compute GAE with CUDA"); + float gamma = 1.0f / (vstd*vstd); + if (r_std != 0.0f) { + gamma -= puf/(r_std*r_std); + } + + if (gamma < 0.0f) { + gamma = 0.0f; + } + + if (gamma > gamma_max) { + gamma_max = gamma; + } + buf[idx] = gamma; + reward_mask[idx] = 1.0f; + } + + //float bootstrap = 0.0f; + //if (k == horizon-1) { + // bootstrap = buf[i*horizon + horizon - 1]*values_mean[i*horizon + horizon - 1]; + //} + + float R = 0.0f; + for (int j = 0; j <= k-1; j++) { + int t = i + j; + int idx = i * horizon + j; + float r = rewards[t+1]; + + float gamma = buf[idx]; + if (gamma_max > 0) { + gamma /= gamma_max; + } + + if (j >= 16 && values_std[idx] > 0.95*r_std) { + break; + } + + R += gamma * (r - values_mean[idx]); + reward_block[idx] = r; + buf[idx] = gamma; + } + + advantages[i] = R; + bounds[i] = k; +} + + +// [num_steps, horizon] +__global__ void gae_kernel(float* values, float* rewards, float* dones, + float* advantages, float gamma, float gae_lambda, int num_steps, int horizon) { + int row = blockIdx.x*blockDim.x + threadIdx.x; + int offset = row*horizon; + gae_row(values + offset, rewards + offset, dones + offset, + advantages + offset, gamma, gae_lambda, horizon); +} + +void compute_p3o(torch::Tensor reward_block, torch::Tensor reward_mask, + torch::Tensor values_mean, torch::Tensor values_std, torch::Tensor buf, + torch::Tensor dones, torch::Tensor rewards, torch::Tensor advantages, + torch::Tensor bounds, int num_steps, float vstd_max, float puf, + int horizon) { + // Launch the kernel + int threads_per_block = 256; + int blocks = (num_steps + threads_per_block - 1) / threads_per_block; + + p3o_kernel<<>>( + reward_block.data_ptr(), + reward_mask.data_ptr(), + values_mean.data_ptr(), + values_std.data_ptr(), + buf.data_ptr(), + dones.data_ptr(), + rewards.data_ptr(), + advantages.data_ptr(), + bounds.data_ptr(), + num_steps, + vstd_max, + puf, + horizon + ); + + // Check for CUDA errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error(cudaGetErrorString(err)); + } + return; +} + +torch::Tensor compute_gae(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, float gamma, float gae_lambda) { + int num_steps = values.size(0); + int horizon = values.size(1); + torch::Tensor advantages = gae_check(values, rewards, dones, num_steps, horizon); + TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU"); + + int threads_per_block = 256; + int blocks = (num_steps + threads_per_block - 1) / threads_per_block; + + gae_kernel<<>>( + values.data_ptr(), + rewards.data_ptr(), + dones.data_ptr(), + advantages.data_ptr(), + gamma, + gae_lambda, + num_steps, + horizon + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error(cudaGetErrorString(err)); + } + + return advantages; +} + +// Pybind11 module definition +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compute_p3o", &compute_p3o, "Compute p3o advantages with CUDA"); + m.def("compute_gae", &compute_gae, "Compute GAE with CUDA"); } diff --git a/shared.cpp b/shared.cpp new file mode 100644 index 000000000..fd3202cb5 --- /dev/null +++ b/shared.cpp @@ -0,0 +1,47 @@ +#include + +// TODO: Find a better way to do conditional compilation +#ifndef __CUDA_ARCH__ +#define __host__ +#define __device__ +#endif + +// [horizon] +__host__ __device__ void gae_row(float* values, float* rewards, float* dones, float* advantages, + float gamma, float gae_lambda, int horizon) { + float lastgaelam = 0; + for (int t = horizon-2; t >= 0; t--) { + int t_next = t + 1; + float nextnonterminal = 1.0 - dones[t_next]; + float delta = rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]; + lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam; + advantages[t] = lastgaelam; + } +} + +torch::Tensor gae_check(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, int num_steps, int horizon) { + + // Validate input tensors + torch::Device device = values.device(); + for (const torch::Tensor& t : {values, rewards, dones}) { + TORCH_CHECK(t.dim() == 2, "Tensor must be 2D"); + TORCH_CHECK(t.device() == device, "All tensors must be on same device"); + TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps"); + TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon"); + TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32"); + if (!t.is_contiguous()) { + t.contiguous(); + } + } + + torch::Tensor advantages = torch::zeros( + {num_steps, horizon}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(device) + ); + return advantages; +} + + From 17b52aa0b4f06636cd0cc2c5dabda347df7742e2 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Wed, 9 Apr 2025 18:46:47 +0000 Subject: [PATCH 08/26] Initial e3b/diayn running on new version --- clean_pufferl.py | 390 +++++++++++++++++--------------------------- pufferlib.cu | 24 +++ pufferlib/models.py | 4 +- pufferlib/utils.py | 10 +- 4 files changed, 186 insertions(+), 242 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index c9638cc25..6289244cc 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -21,103 +21,11 @@ import pufferlib.utils import pufferlib.pytorch -torch.set_float32_matmul_precision('high') - - -''' -def compute_gae( - values: torch.Tensor, # [num_steps, horizon] - rewards: torch.Tensor, # [num_steps, horizon] - dones: torch.Tensor, # [num_steps, horizon] - gamma: float, - gae_lambda: float, - ): - - num_steps = values.shape[0] - horizon = values.shape[1] - advantages = torch.zeros(num_steps, horizon, dtype=torch.float32, device=values.device) - - for t in [values, rewards, dones, advantages]: - assert t.ndim == 2 - assert t.shape[0] == num_steps - assert t.shape[1] == horizon - t.contiguous() - assert t.is_cuda, "All tensors must be on GPU" - - - cuda_module.compute_gae( - values, - rewards, - dones, - advantages, - gamma, - gae_lambda, - num_steps, - horizon, - ) - - torch.cuda.synchronize() - return advantages -''' - - -def compute_advantages( - reward_block: torch.Tensor, # [num_steps, horizon] - reward_mask: torch.Tensor, # [num_steps, horizon] - values_mean: torch.Tensor, # [num_steps, horizon] - values_std: torch.Tensor, # [num_steps, horizon] - buf: torch.Tensor, # [num_steps, horizon] - dones: torch.Tensor, # [num_steps] - rewards: torch.Tensor, # [num_steps] - advantages: torch.Tensor, # [num_steps] - bounds: torch.Tensor, # [num_steps] - vstd_max: float, - puf: float, - horizon: int -): - assert all(t.is_cuda for t in [reward_block, reward_mask, values_mean, values_std, - buf, dones, rewards, advantages, bounds]), "All tensors must be on GPU" - - # Ensure contiguous memory - tensors = [reward_block, reward_mask, values_mean, values_std, buf, dones, rewards, advantages, bounds] - for t in tensors: - t.contiguous() - assert t.is_cuda - - num_steps = rewards.shape[0] - - # Precompute vstd_min and vstd_max - #vstd_max = values_std.max().item() - #vstd_min = values_std.min().item() - - # Launch kernel - threads_per_block = 256 - assert num_steps % threads_per_block == 0 - blocks = (num_steps + threads_per_block - 1) // threads_per_block - - cuda_module.advantage_kernel( - reward_block, - reward_mask, - values_mean, - values_std, - buf, - dones, - rewards, - advantages, - bounds, - num_steps, - vstd_max, - puf, - horizon, - ) - - torch.cuda.synchronize() - return advantages - def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): random.seed(config.seed) np.random.seed(config.seed) torch.backends.cudnn.deterministic = config.torch_deterministic + torch.set_float32_matmul_precision('high') if config.seed is not None: torch.manual_seed(config.seed) @@ -169,18 +77,21 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): ep_lengths = torch.zeros(total_agents, device=config.device, dtype=torch.int32) ep_indices = torch.arange(total_agents, device=config.device, dtype=torch.int32) free_idx = total_agents - assert free_idx <= experience_rows + + e3b_inv = None + e3b_orig = None if config.use_e3b: - experience.e3b_inv = torch.eye(policy.hidden_size).repeat(total_agents, 1, 1).to(config.device) / config.e3b_lambda - experience.e3b_orig = experience.e3b_inv.clone() - experience.e3b_mean = None - experience.e3b_std = None + e3b_inv = torch.eye(policy.hidden_size).repeat(total_agents, 1, 1).to(config.device) / config.e3b_lambda + e3b_orig = e3b_inv.clone() + diayn_archive = None if config.use_diayn: # TODO: Check shapes - experience.diayn_archive = torch.nn.functional.one_hot(torch.arange(config.diayn_archive), config.diayn_archive).to(config.device).float() - experience.diayn_skills = torch.randint(0, config.diayn_archive, (total_agents,), dtype=torch.long, device=config.device) + diayn_archive = torch.nn.functional.one_hot( + torch.arange(config.diayn_archive), config.diayn_archive).to(config.device).float() + experience.diayn_skills = torch.randint( + 0, config.diayn_archive, (experience_rows,), dtype=torch.long, device=config.device) experience.diayn_batch = torch.zeros(experience_rows, dtype=torch.long, device=config.device) if config.use_p3o: @@ -283,7 +194,6 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): e3b_norm=config.e3b_norm, puf=config.puf, use_diayn=config.use_diayn, - diayn_archive=config.diayn_archive, diayn_coef=config.diayn_coef, # Do we use these? ptr=0, @@ -301,6 +211,11 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): device=config.device, minibatch_size=minibatch_size, compute_gae=compute_gae, + e3b_inv=e3b_inv, + e3b_orig=e3b_orig, + e3b_mean=None, + e3b_std=None, + diayn_archive=diayn_archive, ) @pufferlib.utils.profile @@ -332,13 +247,13 @@ def evaluate(data): if data.use_diayn: idxs = env_id[done_mask] if len(idxs) > 0: - z_idxs = torch.randint(0, experience.diayn_archive.shape[0], (done_mask.sum(),)).to(config.device) + z_idxs = torch.randint(0, data.diayn_archive.shape[0], (done_mask.sum(),)).to(config.device) experience.diayn_skills[idxs] = z_idxs with profile.eval_copy: if data.use_e3b and done_mask.any(): done_idxs = env_id[done_mask] - experience.e3b_inv[done_idxs] = experience.e3b_orig[done_idxs] + data.e3b_inv[done_idxs] = data.e3b_orig[done_idxs] o = torch.as_tensor(o) o_device = o.to(config.device, non_blocking=True) @@ -363,7 +278,7 @@ def evaluate(data): if data.use_diayn: z_idxs = experience.diayn_skills[env_id] - z = experience.diayn_archive[z_idxs] + z = data.diayn_archive[z_idxs] state.diayn_z_idxs = z_idxs state.diayn_z = z @@ -379,24 +294,24 @@ def evaluate(data): state.diayn_z_idxs = z_idxs if data.use_e3b: - e3b = experience.e3b_inv[env_id] + e3b = data.e3b_inv[env_id] phi = state.hidden.detach() u = phi.unsqueeze(1) @ e3b b = u @ phi.unsqueeze(2) - experience.e3b_inv[env_id] -= (u.mT @ u) / (1 + b) + data.e3b_inv[env_id] -= (u.mT @ u) / (1 + b) done_inds = env_id[done_mask] - experience.e3b_inv[done_inds] = experience.e3b_orig[done_inds] + data.e3b_inv[done_inds] = data.e3b_orig[done_inds] e3b_reward = b.squeeze() - if experience.e3b_mean is None: - experience.e3b_mean = e3b_reward.mean() - experience.e3b_std = e3b_reward.std() + if data.e3b_mean is None: + data.e3b_mean = e3b_reward.mean() + data.e3b_std = e3b_reward.std() else: w = data.e3b_norm - experience.e3b_mean = (1-w)*e3b_reward.mean() + w*experience.e3b_mean - experience.e3b_std = (1-w)*e3b_reward.std() + w*experience.e3b_std + data.e3b_mean = (1-w)*e3b_reward.mean() + w*data.e3b_mean + data.e3b_std = (1-w)*e3b_reward.std() + w*data.e3b_std - e3b_reward = (e3b_reward - experience.e3b_mean) / (experience.e3b_std + 1e-6) + e3b_reward = (e3b_reward - data.e3b_mean) / (data.e3b_std + 1e-6) e3b_reward = config.e3b_coef*e3b_reward r += e3b_reward @@ -445,16 +360,14 @@ def evaluate(data): @pufferlib.utils.profile def train(data): - config, profile, experience = data.config, data.profile, data.experience - + config = data.config + profile = data.profile + experience = data.experience losses = data.losses - for k in data.losses: - losses[k] = 0 - cross_entropy = torch.nn.CrossEntropyLoss() total_minibatches = int(config.update_epochs*config.batch_size/data.minibatch_size) accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) - n_samples = config.minibatch_size // config.bptt_horizon + n_samples = data.minibatch_size // config.bptt_horizon for mb in range(total_minibatches): with profile.train_misc: if config.use_p3o: @@ -493,7 +406,10 @@ def train(data): advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) + with profile.train_copy: batch = sample(data, advantages, n_samples) + + with profile.train_misc: state = pufferlib.namespace( action=batch.actions, lstm_h=None, @@ -501,7 +417,7 @@ def train(data): ) if config.use_diayn: - z_idxs = batch.diayn_z_idxs + state.z_idxs = batch.diayn_z_idxs with profile.train_forward: if not isinstance(data.policy, torch.nn.LSTM): @@ -513,8 +429,6 @@ def train(data): with torch.no_grad(): experience.values[batch.idx] = newvalue - lstm_h = state.lstm_h - lstm_c = state.lstm_c actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=batch.actions, is_continuous=data.policy.is_continuous) @@ -575,7 +489,9 @@ def train(data): diayn_discriminator = (data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator) q = diayn_discriminator(state.hidden).squeeze() - diayn_loss = cross_entropy(q, z_idxs) + z_idxs = state.z_idxs.unsqueeze(1).expand(q.shape[:2]).reshape(-1) + q = q.view(-1, q.shape[-1]) + diayn_loss = torch.nn.functional.cross_entropy(q, z_idxs) loss += config.diayn_loss_coef*diayn_loss with profile.learn: @@ -623,13 +539,12 @@ def train(data): break # Reprioritize experience - ep_uses = data.ep_uses - data.max_uses = ep_uses.max().item() - data.mean_uses = ep_uses.float().mean().item() + data.max_uses = data.ep_uses.max().item() + data.mean_uses = data.ep_uses.float().mean().item() if config.replay_factor > 0: - advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) - n_samples = data.off_policy_rows - exp = sample(data, advantages, n_samples, method='topk') + advantages = data.compute_gae(experience.values, experience.rewards, + experience.dones, config.gamma, config.gae_lambda) + exp = sample(data, advantages, data.off_policy_rows, method='topk') for k, v in experience.items(): v[data.on_policy_rows:] = exp[k] @@ -652,14 +567,17 @@ def train(data): #losses.explained_variance = explained_var.item() data.epoch += 1 - done_training = data.global_step >= config.total_timesteps logs = None + done_training = data.global_step >= config.total_timesteps if done_training or profile.update(data): logs = mean_and_log(data) print_dashboard(config.env, data.utilization, data.global_step, data.epoch, profile, data.losses, data.stats, data.msg) data.stats = defaultdict(list) + for k in losses: + losses[k] = 0 + if data.epoch % config.checkpoint_interval == 0 or done_training: save_checkpoint(data) data.msg = f'Checkpoint saved at update {data.epoch}' @@ -689,7 +607,7 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): exp.values[batch_rows, l] = value.flatten() if data.use_diayn: - data.diayn_batch[batch_rows] = state.diayn_z_idxs + exp.diayn_batch[batch_rows] = state.diayn_z_idxs # TODO: Handle masks!! #indices = np.where(mask)[0] @@ -803,111 +721,6 @@ def close(data): elif data.neptune is not None: data.neptune.stop() -class Profile: - SPS: ... = 0 - uptime: ... = 0 - remaining: ... = 0 - eval_time: ... = 0 - env_time: ... = 0 - eval_forward_time: ... = 0 - eval_copy_time: ... = 0 - eval_misc_time: ... = 0 - train_time: ... = 0 - train_forward_time: ... = 0 - learn_time: ... = 0 - train_copy_time: ... = 0 - train_misc_time: ... = 0 - custom_time: ... = 0 - def __init__(self, amp_context): - self.start = time.time() - self.env = pufferlib.utils.Profiler() - # TODO: Figure out which of these need amp - self.eval_forward = pufferlib.utils.Profiler(amp_context=amp_context) - self.eval_copy = pufferlib.utils.Profiler(amp_context=amp_context) - self.eval_misc = pufferlib.utils.Profiler() - self.train_forward = pufferlib.utils.Profiler(amp_context=amp_context) - self.learn = pufferlib.utils.Profiler() - self.train_copy = pufferlib.utils.Profiler(amp_context=amp_context) - self.train_misc = pufferlib.utils.Profiler() - self.custom = pufferlib.utils.Profiler() - self.prev_steps = 0 - - def __iter__(self): - yield 'SPS', self.SPS - yield 'uptime', self.uptime - yield 'remaining', self.remaining - yield 'eval_time', self.eval_time - yield 'env_time', self.env_time - yield 'eval_forward_time', self.eval_forward_time - yield 'eval_copy_time', self.eval_copy_time - yield 'eval_misc_time', self.eval_misc_time - yield 'train_time', self.train_time - yield 'train_forward_time', self.train_forward_time - yield 'learn_time', self.learn_time - yield 'train_copy_time', self.train_copy_time - yield 'train_misc_time', self.train_misc_time - yield 'custom_time', self.custom_time - - @property - def epoch_time(self): - return self.train_time + self.eval_time - - def update(self, data, interval_s=1): - global_step = data.global_step - if global_step == 0: - return True - - uptime = time.time() - self.start - if uptime - self.uptime < interval_s: - return False - - self.SPS = (global_step - self.prev_steps) / (uptime - self.uptime) - self.prev_steps = global_step - self.uptime = uptime - - self.remaining = (data.config.total_timesteps - global_step) / self.SPS - self.eval_time = data._timers['evaluate'].elapsed - self.eval_forward_time = self.eval_forward.elapsed - self.env_time = self.env.elapsed - self.eval_copy_time = self.eval_copy.elapsed - self.eval_misc_time = self.eval_misc.elapsed - self.train_time = data._timers['train'].elapsed - self.train_forward_time = self.train_forward.elapsed - self.learn_time = self.learn.elapsed - self.train_copy_time = self.train_copy.elapsed - self.train_misc_time = self.train_misc.elapsed - self.custom_time = self.custom.elapsed - return True - -class Utilization(Thread): - def __init__(self, delay=1, maxlen=20): - super().__init__() - self.cpu_mem = deque(maxlen=maxlen) - self.cpu_util = deque(maxlen=maxlen) - self.gpu_util = deque(maxlen=maxlen) - self.gpu_mem = deque(maxlen=maxlen) - - self.delay = delay - self.stopped = False - self.start() - - def run(self): - while not self.stopped: - self.cpu_util.append(100*psutil.cpu_percent()) - mem = psutil.virtual_memory() - self.cpu_mem.append(100*mem.active/mem.total) - if torch.cuda.is_available(): - self.gpu_util.append(torch.cuda.utilization()) - free, total = torch.cuda.mem_get_info() - self.gpu_mem.append(100*free/total) - else: - self.gpu_util.append(0) - self.gpu_mem.append(0) - time.sleep(self.delay) - - def stop(self): - self.stopped = True - def save_checkpoint(data): config = data.config path = os.path.join(config.data_dir, config.exp_id) @@ -1024,6 +837,111 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k import imageio os.makedirs('../docker', exist_ok=True) or imageio.mimsave('../docker/eval.gif', frames, fps=15, loop=0) +class Profile: + SPS: ... = 0 + uptime: ... = 0 + remaining: ... = 0 + eval_time: ... = 0 + env_time: ... = 0 + eval_forward_time: ... = 0 + eval_copy_time: ... = 0 + eval_misc_time: ... = 0 + train_time: ... = 0 + train_forward_time: ... = 0 + learn_time: ... = 0 + train_copy_time: ... = 0 + train_misc_time: ... = 0 + custom_time: ... = 0 + def __init__(self, amp_context): + self.start = time.time() + self.env = pufferlib.utils.Profiler() + # TODO: Figure out which of these need amp + self.eval_forward = pufferlib.utils.Profiler(amp_context=amp_context) + self.eval_copy = pufferlib.utils.Profiler(amp_context=amp_context) + self.eval_misc = pufferlib.utils.Profiler() + self.train_forward = pufferlib.utils.Profiler(amp_context=amp_context) + self.learn = pufferlib.utils.Profiler() + self.train_copy = pufferlib.utils.Profiler(amp_context=amp_context) + self.train_misc = pufferlib.utils.Profiler() + self.custom = pufferlib.utils.Profiler() + self.prev_steps = 0 + + def __iter__(self): + yield 'SPS', self.SPS + yield 'uptime', self.uptime + yield 'remaining', self.remaining + yield 'eval_time', self.eval_time + yield 'env_time', self.env_time + yield 'eval_forward_time', self.eval_forward_time + yield 'eval_copy_time', self.eval_copy_time + yield 'eval_misc_time', self.eval_misc_time + yield 'train_time', self.train_time + yield 'train_forward_time', self.train_forward_time + yield 'learn_time', self.learn_time + yield 'train_copy_time', self.train_copy_time + yield 'train_misc_time', self.train_misc_time + yield 'custom_time', self.custom_time + + @property + def epoch_time(self): + return self.train_time + self.eval_time + + def update(self, data, interval_s=1): + global_step = data.global_step + if global_step == 0: + return True + + uptime = time.time() - self.start + if uptime - self.uptime < interval_s: + return False + + self.SPS = (global_step - self.prev_steps) / (uptime - self.uptime) + self.prev_steps = global_step + self.uptime = uptime + + self.remaining = (data.config.total_timesteps - global_step) / self.SPS + self.eval_time = data._timers['evaluate'].elapsed + self.eval_forward_time = self.eval_forward.elapsed + self.env_time = self.env.elapsed + self.eval_copy_time = self.eval_copy.elapsed + self.eval_misc_time = self.eval_misc.elapsed + self.train_time = data._timers['train'].elapsed + self.train_forward_time = self.train_forward.elapsed + self.learn_time = self.learn.elapsed + self.train_copy_time = self.train_copy.elapsed + self.train_misc_time = self.train_misc.elapsed + self.custom_time = self.custom.elapsed + return True + +class Utilization(Thread): + def __init__(self, delay=1, maxlen=20): + super().__init__() + self.cpu_mem = deque(maxlen=maxlen) + self.cpu_util = deque(maxlen=maxlen) + self.gpu_util = deque(maxlen=maxlen) + self.gpu_mem = deque(maxlen=maxlen) + + self.delay = delay + self.stopped = False + self.start() + + def run(self): + while not self.stopped: + self.cpu_util.append(100*psutil.cpu_percent()) + mem = psutil.virtual_memory() + self.cpu_mem.append(100*mem.active/mem.total) + if torch.cuda.is_available(): + self.gpu_util.append(torch.cuda.utilization()) + free, total = torch.cuda.mem_get_info() + self.gpu_mem.append(100*free/total) + else: + self.gpu_util.append(0) + self.gpu_mem.append(0) + time.sleep(self.delay) + + def stop(self): + self.stopped = True + ROUND_OPEN = rich.box.Box( "╭──╮\n" "│ │\n" diff --git a/pufferlib.cu b/pufferlib.cu index 917dd5eb4..db62686ec 100644 --- a/pufferlib.cu +++ b/pufferlib.cu @@ -103,6 +103,30 @@ void compute_p3o(torch::Tensor reward_block, torch::Tensor reward_mask, torch::Tensor dones, torch::Tensor rewards, torch::Tensor advantages, torch::Tensor bounds, int num_steps, float vstd_max, float puf, int horizon) { + + // TODO: Port from python + /* + assert all(t.is_cuda for t in [reward_block, reward_mask, values_mean, values_std, + buf, dones, rewards, advantages, bounds]), "All tensors must be on GPU" + + # Ensure contiguous memory + tensors = [reward_block, reward_mask, values_mean, values_std, buf, dones, rewards, advantages, bounds] + for t in tensors: + t.contiguous() + assert t.is_cuda + + num_steps = rewards.shape[0] + + # Precompute vstd_min and vstd_max + #vstd_max = values_std.max().item() + #vstd_min = values_std.min().item() + + # Launch kernel + threads_per_block = 256 + assert num_steps % threads_per_block == 0 + blocks = (num_steps + threads_per_block - 1) // threads_per_block + */ + // Launch the kernel int threads_per_block = 256; int blocks = (num_steps + threads_per_block - 1) / threads_per_block; diff --git a/pufferlib/models.py b/pufferlib/models.py index b7e9a873e..aa22121c7 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -206,8 +206,8 @@ def forward_train(self, observations, state): hidden, (lstm_h, lstm_c) = super().forward(hidden, lstm_state) hidden = hidden.transpose(0, 1) - hidden = hidden.reshape(B*TT, self.hidden_size) - logits, values = self.policy.decode_actions(hidden) + flat_hidden = hidden.reshape(B*TT, self.hidden_size) + logits, values = self.policy.decode_actions(flat_hidden) values = values.reshape(B, TT) state.hidden = hidden state.lstm_h = lstm_h.detach() diff --git a/pufferlib/utils.py b/pufferlib/utils.py index addb7e6b7..931bd98fd 100644 --- a/pufferlib/utils.py +++ b/pufferlib/utils.py @@ -244,8 +244,10 @@ def format_bytes(size): else: return f'{size} B' +# TODO: 5% perf gain by doing cuda sync less frequently class Profiler: - def __init__(self, elapsed=True, calls=True, memory=False, pytorch_memory=False, sync_cuda=True, amp_context=nullcontext()): + def __init__(self, elapsed=True, calls=True, memory=False, + pytorch_memory=False, sync_cuda=True, amp_context=nullcontext()): self.elapsed = 0 if elapsed else None self.calls = 0 if calls else None self.memory = None @@ -296,6 +298,9 @@ def __enter__(self): return self def __exit__(self, *args): + self.amp_context.__exit__(None, None, None) + if self.sync_cuda: + self.torch.cuda.synchronize() if self.track_elapsed: self.end_time = time.perf_counter() self.elapsed += self.end_time - self.start_time @@ -307,9 +312,6 @@ def __exit__(self, *args): if self.track_pytorch_memory: self.end_torch_mem = self.torch.cuda.memory_allocated() self.pytorch_memory = self.end_torch_mem - self.start_torch_mem - self.amp_context.__exit__(None, None, None) - if self.sync_cuda: - self.torch.cuda.synchronize() def __repr__(self): parts = [] From 8dc51a4f5cc049070153aa17332150de4d6ba5fb Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Thu, 10 Apr 2025 01:14:33 +0000 Subject: [PATCH 09/26] Remove e3b, start diayn experiments --- clean_pufferl.py | 94 +++++++++------------------------- pufferlib/models.py | 4 +- pufferlib/ocean/snake/snake.py | 2 +- pufferlib/ocean/torch.py | 28 ++++++++-- 4 files changed, 50 insertions(+), 78 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 6289244cc..84b644ccf 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -79,20 +79,12 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): free_idx = total_agents assert free_idx <= experience_rows - e3b_inv = None - e3b_orig = None - if config.use_e3b: - e3b_inv = torch.eye(policy.hidden_size).repeat(total_agents, 1, 1).to(config.device) / config.e3b_lambda - e3b_orig = e3b_inv.clone() - - diayn_archive = None + diayn_skills = None if config.use_diayn: - # TODO: Check shapes - diayn_archive = torch.nn.functional.one_hot( - torch.arange(config.diayn_archive), config.diayn_archive).to(config.device).float() - experience.diayn_skills = torch.randint( - 0, config.diayn_archive, (experience_rows,), dtype=torch.long, device=config.device) - experience.diayn_batch = torch.zeros(experience_rows, dtype=torch.long, device=config.device) + diayn_skills = torch.randint( + 0, config.diayn_archive, (total_agents,), dtype=torch.long, device=config.device) + experience.diayn_batch = torch.zeros(experience_rows, config.bptt_horizon, + dtype=torch.long, device=config.device) if config.use_p3o: batch_size = config.batch_size @@ -189,9 +181,6 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): utilization=utilization, use_p3o=config.use_p3o, p3o_horizon=config.p3o_horizon, - use_e3b=config.use_e3b, - e3b_coef=config.e3b_coef, - e3b_norm=config.e3b_norm, puf=config.puf, use_diayn=config.use_diayn, diayn_coef=config.diayn_coef, @@ -211,11 +200,8 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): device=config.device, minibatch_size=minibatch_size, compute_gae=compute_gae, - e3b_inv=e3b_inv, - e3b_orig=e3b_orig, - e3b_mean=None, - e3b_std=None, - diayn_archive=diayn_archive, + diayn_skills=diayn_skills, + total_agents=total_agents, ) @pufferlib.utils.profile @@ -229,7 +215,7 @@ def evaluate(data): lstm_h = data.lstm_h lstm_c = data.lstm_c - while data.free_idx < data.on_policy_rows: + while data.free_idx <= data.on_policy_rows: with profile.env: o, r, d, t, info, env_id, mask = data.vecenv.recv() @@ -244,17 +230,7 @@ def evaluate(data): done_mask = d + t data.global_step += mask.sum() - if data.use_diayn: - idxs = env_id[done_mask] - if len(idxs) > 0: - z_idxs = torch.randint(0, data.diayn_archive.shape[0], (done_mask.sum(),)).to(config.device) - experience.diayn_skills[idxs] = z_idxs - with profile.eval_copy: - if data.use_e3b and done_mask.any(): - done_idxs = env_id[done_mask] - data.e3b_inv[done_idxs] = data.e3b_orig[done_idxs] - o = torch.as_tensor(o) o_device = o.to(config.device, non_blocking=True) r = torch.as_tensor(r).to(config.device, non_blocking=True) @@ -277,43 +253,16 @@ def evaluate(data): ) if data.use_diayn: - z_idxs = experience.diayn_skills[env_id] - z = data.diayn_archive[z_idxs] - state.diayn_z_idxs = z_idxs - state.diayn_z = z + state.diayn_z = data.diayn_skills[env_id] logits, value = policy(o_device, state) action, logprob, _ = pufferlib.pytorch.sample_logits(logits, is_continuous=policy.is_continuous) if data.use_diayn: diayn_policy = policy if lstm_h is None else policy.policy - q = diayn_policy.diayn_discriminator(state.hidden).squeeze() - r_diayn = torch.log_softmax(q, dim=-1).gather(-1, z_idxs.unsqueeze(-1)).squeeze() + q = diayn_policy.diayn_discriminator(logits).squeeze() + r_diayn = torch.log_softmax(q, dim=-1).gather(-1, state.diayn_z.unsqueeze(-1)).squeeze() r += config.diayn_coef*r_diayn# - np.log(1/data.diayn_archive) - state.diayn_z = z - state.diayn_z_idxs = z_idxs - - if data.use_e3b: - e3b = data.e3b_inv[env_id] - phi = state.hidden.detach() - u = phi.unsqueeze(1) @ e3b - b = u @ phi.unsqueeze(2) - data.e3b_inv[env_id] -= (u.mT @ u) / (1 + b) - done_inds = env_id[done_mask] - data.e3b_inv[done_inds] = data.e3b_orig[done_inds] - e3b_reward = b.squeeze() - - if data.e3b_mean is None: - data.e3b_mean = e3b_reward.mean() - data.e3b_std = e3b_reward.std() - else: - w = data.e3b_norm - data.e3b_mean = (1-w)*e3b_reward.mean() + w*data.e3b_mean - data.e3b_std = (1-w)*e3b_reward.std() + w*data.e3b_std - - e3b_reward = (e3b_reward - data.e3b_mean) / (data.e3b_std + 1e-6) - e3b_reward = config.e3b_coef*e3b_reward - r += e3b_reward # Clip rewards r = torch.clamp(r, -1, 1) @@ -354,6 +303,7 @@ def evaluate(data): data.stats[k] += v data.free_idx = 0 + data.ep_indices = torch.arange(data.total_agents, device=config.device, dtype=torch.int32) data.ep_lengths.zero_() data.ep_uses.zero_() return data.stats, infos @@ -417,7 +367,7 @@ def train(data): ) if config.use_diayn: - state.z_idxs = batch.diayn_z_idxs + state.diayn_z = batch.diayn_z.reshape(-1) with profile.train_forward: if not isinstance(data.policy, torch.nn.LSTM): @@ -488,8 +438,9 @@ def train(data): if config.use_diayn: diayn_discriminator = (data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator) - q = diayn_discriminator(state.hidden).squeeze() - z_idxs = state.z_idxs.unsqueeze(1).expand(q.shape[:2]).reshape(-1) + #noise = torch.randn_like(logits) + q = diayn_discriminator(logits).squeeze() + z_idxs = state.diayn_z q = q.view(-1, q.shape[-1]) diayn_loss = torch.nn.functional.cross_entropy(q, z_idxs) loss += config.diayn_loss_coef*diayn_loss @@ -607,7 +558,11 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): exp.values[batch_rows, l] = value.flatten() if data.use_diayn: - exp.diayn_batch[batch_rows] = state.diayn_z_idxs + exp.diayn_batch[batch_rows, l] = state.diayn_z + idxs = env_id[done] + if len(idxs) > 0: + z_idxs = torch.randint(0, data.config.diayn_archive, (done.sum(),)).to(data.device) + data.diayn_skills[idxs] = z_idxs # TODO: Handle masks!! #indices = np.where(mask)[0] @@ -650,8 +605,7 @@ def sample(data, advantages, n, reward_block=None, mask_block=None, method='mult output['returns'] = advantages[idx] + exp.values[idx] if data.use_diayn: - output['diayn_z_idxs'] = exp.diayn_batch[idx] - output['diayn_z'] = exp.diayn_skills[idx] + output['diayn_z'] = exp.diayn_batch[idx] return pufferlib.namespace(**output) @@ -779,9 +733,6 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k if model_path is not None: agent.load_state_dict(torch.load(model_path, map_location=device, weights_only=False)) - #e3b_inv = 10*torch.eye(agent.hidden_size).repeat(env_kwargs['num_envs'], 1, 1).to(device) - e3b_inv = None - ob, info = env.reset() driver = env.driver_env os.system('clear') @@ -789,6 +740,7 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k state = pufferlib.namespace( lstm_h=None, lstm_c=None, + diayn_z=torch.ones(env.num_agents, dtype=torch.long, device=device), ) num_agents = env.observation_space.shape[0] diff --git a/pufferlib/models.py b/pufferlib/models.py index aa22121c7..bc9055aac 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -154,7 +154,7 @@ def __init__(self, env, policy, input_size=128, hidden_size=128): def forward(self, observations, state): '''Forward function for inference. 3x faster than using LSTM directly''' - hidden = self.policy.encode_observations(observations) + hidden = self.policy.encode_observations(observations, state) h = state.lstm_h c = state.lstm_c @@ -197,7 +197,7 @@ def forward_train(self, observations, state): lstm_state = None x = x.reshape(B*TT, *space_shape) - hidden = self.policy.encode_observations(x) + hidden = self.policy.encode_observations(x, state) assert hidden.shape == (B*TT, self.input_size) hidden = hidden.reshape(B, TT, self.input_size) diff --git a/pufferlib/ocean/snake/snake.py b/pufferlib/ocean/snake/snake.py index 465a4fe53..00888f986 100644 --- a/pufferlib/ocean/snake/snake.py +++ b/pufferlib/ocean/snake/snake.py @@ -13,7 +13,7 @@ def __init__(self, num_envs=16, width=640, height=360, vision=5, leave_corpse_on_death=True, reward_food=0.1, reward_corpse=0.1, reward_death=-1.0, report_interval=128, max_snake_length=1024, - render_mode='human', buf=None): + render_mode='human', buf=None, seed=0): if num_envs is not None: num_snakes = num_envs * [num_snakes] diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 18f01b2e2..4988e0e03 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -88,10 +88,22 @@ def decode_actions(self, flat_hidden): return action, value class Snake(nn.Module): - def __init__(self, env, cnn_channels=32, hidden_size=128, use_p3o=False, p3o_horizon=32, **kwargs): + def __init__(self, env, cnn_channels=32, hidden_size=128, + use_p3o=False, p3o_horizon=32, use_diayn=False, diayn_skills=8, **kwargs): super().__init__() self.hidden_size = hidden_size self.is_continuous = False + self.use_diayn = use_diayn + + encode_dim = cnn_channels + if use_diayn: + encode_dim += diayn_skills + self.diayn_skills = diayn_skills + self.diayn_discriminator = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(env.single_action_space.n, hidden_size)), + nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(hidden_size, diayn_skills)), + ) self.network= nn.Sequential( pufferlib.pytorch.layer_init( @@ -101,7 +113,9 @@ def __init__(self, env, cnn_channels=32, hidden_size=128, use_p3o=False, p3o_hor nn.Conv2d(cnn_channels, cnn_channels, 3, stride=1)), nn.ReLU(), nn.Flatten(), - pufferlib.pytorch.layer_init(nn.Linear(cnn_channels, hidden_size)), + ) + self.proj = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(encode_dim, hidden_size)), nn.ReLU(), ) self.actor = pufferlib.pytorch.layer_init( @@ -122,9 +136,15 @@ def forward(self, observations): actions, value = self.decode_actions(hidden, lookup) return (actions, value), hidden - def encode_observations(self, observations): + def encode_observations(self, observations, state=None): observations = F.one_hot(observations.long(), 8).permute(0, 3, 1, 2).float() - return self.network(observations) + hidden = self.network(observations) + + if self.use_diayn: + z_one_hot = F.one_hot(state.diayn_z, self.diayn_skills).float() + hidden = torch.cat([hidden, z_one_hot], dim=1) + + return self.proj(hidden) def decode_actions(self, hidden): action = self.actor(hidden) From 7ca9e317fcc567aa67e68901662a3817fd097fd6 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Thu, 10 Apr 2025 01:14:40 +0000 Subject: [PATCH 10/26] default --- config/default.ini | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/default.ini b/config/default.ini index 1290ddf32..8ff14dc9f 100644 --- a/config/default.ini +++ b/config/default.ini @@ -48,7 +48,7 @@ data_dir = experiments checkpoint_interval = 200 batch_size = 524288 minibatch_size = 8192 -replay_factor = 0.125 +replay_factor = 0.0 # Accumulate gradients above this size max_minibatch_size = 16384 bptt_horizon = 64 @@ -62,9 +62,9 @@ e3b_norm = 0.001 e3b_lambda = 10.0 use_diayn = False -diayn_archive = 8 -diayn_loss_coef = 1.0 -diayn_coef = 0.1 +diayn_archive = 4 +diayn_loss_coef = 0.1 +diayn_coef = 0.0 use_p3o = False p3o_horizon = 128 From 53aff6b6b393e8cea9edeef4d97249d5dc2121ba Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Thu, 10 Apr 2025 17:14:57 +0000 Subject: [PATCH 11/26] diayn working over atn segments --- clean_pufferl.py | 20 ++++++++++++++++---- config/default.ini | 2 +- pufferlib/models.py | 1 + pufferlib/ocean/torch.py | 3 ++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 84b644ccf..6e2591d81 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -258,11 +258,13 @@ def evaluate(data): logits, value = policy(o_device, state) action, logprob, _ = pufferlib.pytorch.sample_logits(logits, is_continuous=policy.is_continuous) + ''' if data.use_diayn: diayn_policy = policy if lstm_h is None else policy.policy q = diayn_policy.diayn_discriminator(logits).squeeze() r_diayn = torch.log_softmax(q, dim=-1).gather(-1, state.diayn_z.unsqueeze(-1)).squeeze() r += config.diayn_coef*r_diayn# - np.log(1/data.diayn_archive) + ''' # Clip rewards r = torch.clamp(r, -1, 1) @@ -439,8 +441,18 @@ def train(data): diayn_discriminator = (data.policy.diayn_discriminator if hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator) #noise = torch.randn_like(logits) - q = diayn_discriminator(logits).squeeze() - z_idxs = state.diayn_z + batch_logits = state.batch_logits + mmax = batch_logits.max(dim=-1, keepdim=True)[0] + mmin = batch_logits.min(dim=-1, keepdim=True)[0] + batch_logits = (batch_logits - mmin) / (mmax - mmin + 1e-6) + mask = torch.nn.functional.one_hot(batch.actions, 4) + #inds = torch.randint(0, config.bptt_horizon-4, (mask.shape[0],)).to(mask.device) + #mask[:, :32] = 0 + #mask[:, 40:] = 0 + #batch_logits = batch_logits * mask + batch_logits = batch_logits.view(state.batch_logits.shape[0], -1) + q = diayn_discriminator(batch_logits).squeeze() + z_idxs = batch.diayn_z[:, 0] q = q.view(-1, q.shape[-1]) diayn_loss = torch.nn.functional.cross_entropy(q, z_idxs) loss += config.diayn_loss_coef*diayn_loss @@ -740,7 +752,7 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k state = pufferlib.namespace( lstm_h=None, lstm_c=None, - diayn_z=torch.ones(env.num_agents, dtype=torch.long, device=device), + diayn_z=torch.arange(env.num_agents, dtype=torch.long, device=device) % 4 ) num_agents = env.observation_space.shape[0] @@ -756,7 +768,7 @@ def rollout(env_creator, env_kwargs, policy_cls, rnn_cls, agent_creator, agent_k intrinsic_mean = None intrinsic_std = None while tick <= 200000: - if tick % 1 == 0: + if tick > 1000 and tick % 1 == 0: #render = driver.render(overlay=float(intrinsic[0])) render = driver.render() if driver.render_mode == 'ansi': diff --git a/config/default.ini b/config/default.ini index 8ff14dc9f..b38b1f22a 100644 --- a/config/default.ini +++ b/config/default.ini @@ -63,7 +63,7 @@ e3b_lambda = 10.0 use_diayn = False diayn_archive = 4 -diayn_loss_coef = 0.1 +diayn_loss_coef = 1.0 diayn_coef = 0.0 use_p3o = False diff --git a/pufferlib/models.py b/pufferlib/models.py index bc9055aac..3174bde58 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -209,6 +209,7 @@ def forward_train(self, observations, state): flat_hidden = hidden.reshape(B*TT, self.hidden_size) logits, values = self.policy.decode_actions(flat_hidden) values = values.reshape(B, TT) + state.batch_logits = logits.reshape(B, TT, -1) state.hidden = hidden state.lstm_h = lstm_h.detach() state.lstm_c = lstm_c.detach() diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 4988e0e03..48ff94cb0 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -100,7 +100,8 @@ def __init__(self, env, cnn_channels=32, hidden_size=128, encode_dim += diayn_skills self.diayn_skills = diayn_skills self.diayn_discriminator = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(env.single_action_space.n, hidden_size)), + #nn.Dropout(0.5), + pufferlib.pytorch.layer_init(nn.Linear(64*env.single_action_space.n, hidden_size)), nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(hidden_size, diayn_skills)), ) From eca609c598f9abd5754fbb04093cfeda2ccd2289 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Thu, 10 Apr 2025 18:08:17 +0000 Subject: [PATCH 12/26] traj based diayn --- clean_pufferl.py | 66 ++++++++++++++++++++++++---------------- pufferlib/ocean/torch.py | 17 ++++++++--- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 6e2591d81..8c36053e2 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -317,6 +317,17 @@ def train(data): experience = data.experience losses = data.losses + with profile.custom: + if config.use_diayn: + diayn_policy = data.policy.policy + obs = experience.obs[:, ::8] + q = diayn_policy.discrim_forward(obs) + z_idxs = experience.diayn_batch[:, 0] + q = q.view(-1, q.shape[-1]) + diayn_r = (torch.argmax(q, 1) == z_idxs).float() + experience.rewards[:, -1] += 1.0*diayn_r + print('DIAYN acc: ', diayn_r.mean()) + total_minibatches = int(config.update_epochs*config.batch_size/data.minibatch_size) accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) n_samples = data.minibatch_size // config.bptt_horizon @@ -361,6 +372,30 @@ def train(data): with profile.train_copy: batch = sample(data, advantages, n_samples) + loss = 0 + with profile.custom: + if config.use_diayn: + diayn_policy = data.policy.policy + obs = batch.obs[:, ::8] + q = diayn_policy.discrim_forward(obs) + z_idxs = batch.diayn_z[:, 0] + q = q.view(-1, q.shape[-1]) + diayn_loss = torch.nn.functional.cross_entropy(q, z_idxs) + loss += config.diayn_loss_coef*diayn_loss + ''' + with torch.no_grad(): + batch.advantages *= diayn_r.unsqueeze(1).expand_as(batch.advantages) + ''' + + ''' + rewards = experience.rewards.clone() + rewards[batch.idx, -1] += diayn_r + advantages = data.compute_gae(experience.values, rewards, + experience.dones, config.gamma, config.gae_lambda) + batch.advantages = advantages[batch.idx] + ''' + + with profile.train_misc: state = pufferlib.namespace( action=batch.actions, @@ -434,28 +469,7 @@ def train(data): v_loss = 0.5 * ((newvalue - ret) ** 2).mean() entropy_loss = entropy.mean() - loss = pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef - - with profile.custom: - if config.use_diayn: - diayn_discriminator = (data.policy.diayn_discriminator if - hasattr(data.policy, 'diayn_discriminator') else data.policy.policy.diayn_discriminator) - #noise = torch.randn_like(logits) - batch_logits = state.batch_logits - mmax = batch_logits.max(dim=-1, keepdim=True)[0] - mmin = batch_logits.min(dim=-1, keepdim=True)[0] - batch_logits = (batch_logits - mmin) / (mmax - mmin + 1e-6) - mask = torch.nn.functional.one_hot(batch.actions, 4) - #inds = torch.randint(0, config.bptt_horizon-4, (mask.shape[0],)).to(mask.device) - #mask[:, :32] = 0 - #mask[:, 40:] = 0 - #batch_logits = batch_logits * mask - batch_logits = batch_logits.view(state.batch_logits.shape[0], -1) - q = diayn_discriminator(batch_logits).squeeze() - z_idxs = batch.diayn_z[:, 0] - q = q.view(-1, q.shape[-1]) - diayn_loss = torch.nn.functional.cross_entropy(q, z_idxs) - loss += config.diayn_loss_coef*diayn_loss + loss += pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef with profile.learn: if data.scaler is not None: @@ -571,10 +585,10 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): if data.use_diayn: exp.diayn_batch[batch_rows, l] = state.diayn_z - idxs = env_id[done] - if len(idxs) > 0: - z_idxs = torch.randint(0, data.config.diayn_archive, (done.sum(),)).to(data.device) - data.diayn_skills[idxs] = z_idxs + #idxs = env_id[done] + #if len(idxs) > 0: + # z_idxs = torch.randint(0, data.config.diayn_archive, (done.sum(),)).to(data.device) + # data.diayn_skills[idxs] = z_idxs # TODO: Handle masks!! #indices = np.where(mask)[0] diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 48ff94cb0..16b65b8a3 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -100,12 +100,15 @@ def __init__(self, env, cnn_channels=32, hidden_size=128, encode_dim += diayn_skills self.diayn_skills = diayn_skills self.diayn_discriminator = nn.Sequential( - #nn.Dropout(0.5), - pufferlib.pytorch.layer_init(nn.Linear(64*env.single_action_space.n, hidden_size)), + pufferlib.pytorch.layer_init( + nn.Conv2d(64, cnn_channels, 5, stride=3)), + nn.ReLU(), + pufferlib.pytorch.layer_init( + nn.Conv2d(cnn_channels, cnn_channels, 3, stride=1)), nn.ReLU(), - pufferlib.pytorch.layer_init(nn.Linear(hidden_size, diayn_skills)), + nn.Flatten(), + pufferlib.pytorch.layer_init(nn.Linear(cnn_channels, diayn_skills)), ) - self.network= nn.Sequential( pufferlib.pytorch.layer_init( nn.Conv2d(8, cnn_channels, 5, stride=3)), @@ -132,6 +135,12 @@ def __init__(self, env, cnn_channels=32, hidden_size=128, self.value = pufferlib.pytorch.layer_init( nn.Linear(hidden_size, 1), std=1) + def discrim_forward(self, obs): + obs = F.one_hot(obs.long(), 8).permute(0, 1, 4, 2, 3).float() + B, f, c, h, w = obs.shape + obs = obs.reshape(B, f*c, h, w) + return self.diayn_discriminator(obs) + def forward(self, observations): hidden, lookup = self.encode_observations(observations) actions, value = self.decode_actions(hidden, lookup) From 29fab3c324fc4dbafee85c608337dcbd712ae85c Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Thu, 10 Apr 2025 23:03:19 +0000 Subject: [PATCH 13/26] Grid C bind --- pufferlib/models.py | 6 +- pufferlib/ocean/env_binding.h | 9 +-- pufferlib/ocean/grid/binding.c | 67 +++++++++++++++++++ pufferlib/ocean/grid/grid.c | 25 ++++--- pufferlib/ocean/grid/grid.h | 117 ++++++++++++++------------------- pufferlib/ocean/grid/grid.py | 21 +++--- pufferlib/ocean/torch.py | 9 +-- setup.py | 10 +-- 8 files changed, 159 insertions(+), 105 deletions(-) create mode 100644 pufferlib/ocean/grid/binding.c diff --git a/pufferlib/models.py b/pufferlib/models.py index 3174bde58..085ff7f7c 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -82,7 +82,7 @@ def __init__(self, env, hidden_size=128, use_p3o=False, p3o_horizon=32, use_diay nn.Linear(hidden_size, 1), std=1) def forward(self, observations, state=None): - hidden = self.encode_observations(observations) + hidden = self.encode_observations(observations, state=state) state.hidden = hidden logits, values = self.decode_actions(hidden) return logits, values @@ -90,7 +90,7 @@ def forward(self, observations, state=None): def forward_train(self, observations, state=None): return self.forward(observations, state) - def encode_observations(self, observations): + def encode_observations(self, observations, state=None): '''Encodes a batch of observations into hidden states. Assumes no time dimension (handled by LSTM wrappers).''' batch_size = observations.shape[0] @@ -154,7 +154,7 @@ def __init__(self, env, policy, input_size=128, hidden_size=128): def forward(self, observations, state): '''Forward function for inference. 3x faster than using LSTM directly''' - hidden = self.policy.encode_observations(observations, state) + hidden = self.policy.encode_observations(observations, state=state) h = state.lstm_h c = state.lstm_c diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 0e396d1a0..f57bd1d89 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -2,9 +2,9 @@ #include // Forward declarations for env-specific functions supplied by user -static int my_init(Env* env, PyObject* args, PyObject* kwargs); -//typedef struct Log Log; static int my_log(PyObject* dict, Log* log); +static int my_init(Env* env, PyObject* args, PyObject* kwargs); +static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs); static Env* unpack_env(PyObject* args) { PyObject* handle_obj = PyTuple_GetItem(args, 0); @@ -144,7 +144,7 @@ static PyObject* env_step(PyObject* self, PyObject* args) { if (!env){ return NULL; } - step(env); + c_step(env); Py_RETURN_NONE; } @@ -384,7 +384,7 @@ static PyObject* vec_step(PyObject* self, PyObject* arg) { } for (int i = 0; i < vec->num_envs; i++) { - step(vec->envs[i]); + c_step(vec->envs[i]); } Py_RETURN_NONE; } @@ -513,6 +513,7 @@ static PyMethodDef methods[] = { {"vec_log", vec_log, METH_VARARGS, "Log the vector of environments"}, {"vec_render", vec_render, METH_VARARGS, "Render the vector of environments"}, {"vec_close", vec_close, METH_VARARGS, "Close the vector of environments"}, + {"shared", (PyCFunction)my_shared, METH_VARARGS | METH_KEYWORDS, "Shared state"}, {NULL, NULL, 0, NULL} }; diff --git a/pufferlib/ocean/grid/binding.c b/pufferlib/ocean/grid/binding.c new file mode 100644 index 000000000..27d900dc8 --- /dev/null +++ b/pufferlib/ocean/grid/binding.c @@ -0,0 +1,67 @@ +#include "grid.h" + +#define Env Grid +#include "../env_binding.h" + +static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { + int num_maps = unpack(kwargs, "num_maps"); + int max_size = unpack(kwargs, "max_size"); + int size = unpack(kwargs, "size"); + State* levels = calloc(num_maps, sizeof(State)); + + if (max_size <= 5) { + PyErr_SetString(PyExc_ValueError, "max_size must be >5"); + return NULL; + } + + // Temporary env used to gen maps + Grid env; + env.max_size = max_size; + init_grid(&env); + + for (int i = 0; i < num_maps; i++) { + int sz = size; + if (size == -1) { + sz = 5 + (rand() % (max_size-5)); + } + + if (sz % 2 == 0) { + sz -= 1; + } + + float difficulty = (float)rand()/(float)(RAND_MAX); + create_maze_level(&env, sz, sz, difficulty, i); + init_state(&levels[i], max_size, 1); + get_state(&env, &levels[i]); + } + + return PyLong_FromVoidPtr(levels); +} + +static int my_init(Env* env, PyObject* args, PyObject* kwargs) { + env->max_size = unpack(kwargs, "max_size"); + env->num_maps = unpack(kwargs, "num_maps"); + init_grid(env); + + PyObject* handle_obj = PyDict_GetItemString(kwargs, "state"); + if (!PyObject_TypeCheck(handle_obj, &PyLong_Type)) { + PyErr_SetString(PyExc_TypeError, "state handle must be an integer"); + return 1; + } + + State* levels = (State*)PyLong_AsVoidPtr(handle_obj); + if (!levels) { + PyErr_SetString(PyExc_ValueError, "Invalid state handle"); + return 1; + } + + env->levels = levels; + return 0; +} + +static int my_log(PyObject* dict, Log* log) { + assign_to_dict(dict, "episode_return", log->episode_return); + assign_to_dict(dict, "episode_length", log->episode_length); + assign_to_dict(dict, "score", log->score); + return 0; +} diff --git a/pufferlib/ocean/grid/grid.c b/pufferlib/ocean/grid/grid.c index fd77ff00d..572b06a32 100644 --- a/pufferlib/ocean/grid/grid.c +++ b/pufferlib/ocean/grid/grid.c @@ -22,8 +22,15 @@ int main() { //env->agents[0].color = 6; //reset(env, seed); //load_locked_room_preset(env); + - create_maze_level(env, 5, 5, 0.85, seed); + State* levels = calloc(1, sizeof(State)); + + create_maze_level(env, 11, 11, 0.85, seed); + init_state(levels, max_size, num_agents); + get_state(env, levels); + env->num_maps = 1; + env->levels = levels; //generate_locked_room(env); //State state; //init_state(&state, env->max_size, env->num_agents); @@ -40,9 +47,8 @@ int main() { env->grid[(env->height-2)*env->max_size + (env->width - 2)] = GOAL; */ - Renderer* renderer = init_renderer(render_cell_size, width, height); - int tick = 0; + render(env); while (!WindowShouldClose()) { // User can take control of the first agent env->actions[0] = ATN_FORWARD; @@ -74,20 +80,13 @@ int main() { //env->actions[0] = actions[t]; tick = (tick + 1)%12; bool done = false; - if (tick % 12 == 0) { - done = step(env); + if (tick % 1 == 0) { + c_step(env); printf("direction: %f\n", env->agents[0].direction); } - if (done) { - printf("Done, reward: %f\n", env->rewards[0]); - seed++; - reset(env, seed); - create_maze_level(env, 5, 5, 0.85, seed); - } - render_global(renderer, env, (float)tick/12.0); + render(env); } - close_renderer(renderer); free_allocated_grid(env); return 0; } diff --git a/pufferlib/ocean/grid/grid.h b/pufferlib/ocean/grid/grid.h index bdfbf3624..e31ff8f17 100644 --- a/pufferlib/ocean/grid/grid.h +++ b/pufferlib/ocean/grid/grid.h @@ -38,51 +38,9 @@ struct Log { float episode_return; float episode_length; float score; + float n; }; -typedef struct LogBuffer LogBuffer; -struct LogBuffer { - Log* logs; - int length; - int idx; -}; - -LogBuffer* allocate_logbuffer(int size) { - LogBuffer* logs = (LogBuffer*)calloc(1, sizeof(LogBuffer)); - logs->logs = (Log*)calloc(size, sizeof(Log)); - logs->length = size; - logs->idx = 0; - return logs; -} - -void free_logbuffer(LogBuffer* buffer) { - free(buffer->logs); - free(buffer); -} - -void add_log(LogBuffer* logs, Log* log) { - if (logs->idx == logs->length) { - return; - } - logs->logs[logs->idx] = *log; - logs->idx += 1; - //printf("Log: %f, %f, %f\n", log->episode_return, log->episode_length, log->score); -} - -Log aggregate_and_clear(LogBuffer* logs) { - Log log = {0}; - if (logs->idx == 0) { - return log; - } - for (int i = 0; i < logs->idx; i++) { - log.episode_return += logs->logs[i].episode_return / logs->idx; - log.episode_length += logs->logs[i].episode_length / logs->idx; - log.score += logs->logs[i].score / logs->idx; - } - logs->idx = 0; - return log; -} - // 8 unique agents bool is_agent(int idx) { return idx >= AGENT && idx < AGENT + 8; @@ -118,29 +76,38 @@ struct Agent { int held; }; +typedef struct Renderer Renderer; +typedef struct State State; typedef struct Grid Grid; struct Grid{ + Renderer* renderer; + State* levels; + int num_maps; int width; int height; int num_agents; int horizon; int vision; + int tick; float speed; int obs_size; int max_size; bool discretize; Log log; - LogBuffer* log_buffer; Agent* agents; unsigned char* grid; int* counts; unsigned char* observations; float* actions; float* rewards; - unsigned char* dones; + unsigned char* terminals; }; void init_grid(Grid* env) { + env->num_agents = 1; + env->vision = 5; + env->speed = 1; + env->discretize = true; env->obs_size = 2*env->vision + 1; int env_mem= env->max_size * env->max_size; env->grid = calloc(env_mem, sizeof(unsigned char)); @@ -162,8 +129,7 @@ Grid* allocate_grid(int max_size, int num_agents, int horizon, num_agents*obs_size*obs_size, sizeof(unsigned char)); env->actions = calloc(num_agents, sizeof(float)); env->rewards = calloc(num_agents, sizeof(float)); - env->dones = calloc(num_agents, sizeof(unsigned char)); - env->log_buffer = allocate_logbuffer(LOG_BUFFER_SIZE); + env->terminals = calloc(num_agents, sizeof(unsigned char)); init_grid(env); return env; } @@ -178,8 +144,7 @@ void free_allocated_grid(Grid* env) { free(env->observations); free(env->actions); free(env->rewards); - free(env->dones); - free_logbuffer(env->log_buffer); + free(env->terminals); free_env(env); } @@ -192,6 +157,13 @@ int grid_offset(Grid* env, int y, int x) { return y*env->max_size + x; } +void add_log(Grid* env, int idx) { + env->log.episode_return += env->rewards[idx]; + env->log.score += env->rewards[idx]; + env->log.episode_length += env->tick; + env->log.n += 1.0; +} + void compute_observations(Grid* env) { memset(env->observations, 0, env->obs_size*env->obs_size*env->num_agents); for (int agent_idx = 0; agent_idx < env->num_agents; agent_idx++) { @@ -275,7 +247,6 @@ void spawn_agent(Grid* env, int idx, int x, int y) { agent->color = AGENT; } -typedef struct State State; struct State { int width; int height; @@ -312,10 +283,13 @@ void set_state(Grid* env, State* state) { memcpy(env->grid, state->grid, env->max_size*env->max_size); } -void reset(Grid* env, int seed) { - env->log = (Log){0}; +void reset(Grid* env) { memset(env->grid, 0, env->max_size*env->max_size); memset(env->counts, 0, env->max_size*env->max_size*sizeof(int)); + env->tick = 0; + int idx = rand() % env->num_maps; + set_state(env, &env->levels[idx]); + compute_observations(env); } int move_to(Grid* env, int agent_idx, float y, float x) { @@ -330,9 +304,8 @@ int move_to(Grid* env, int agent_idx, float y, float x) { return 1; } else if (dest == REWARD || dest == GOAL) { env->rewards[agent_idx] = 1.0; - env->dones[agent_idx] = 1; - env->log.episode_return += 1.0; - env->log.score += 1.0; + env->terminals[agent_idx] = 1; + add_log(env, agent_idx); } else if (is_key(dest)) { if (agent->held != -1) { return 1; @@ -428,9 +401,10 @@ bool step_agent(Grid* env, int idx) { return true; } -bool step(Grid* env) { - memset(env->dones, 0, env->num_agents); +void c_step(Grid* env) { + memset(env->terminals, 0, env->num_agents); memset(env->rewards, 0, env->num_agents*sizeof(float)); + env->tick++; for (int i = 0; i < env->num_agents; i++) { step_agent(env, i); @@ -439,21 +413,23 @@ bool step(Grid* env) { bool done = true; for (int i = 0; i < env->num_agents; i++) { - if (!env->dones[i]) { + if (!env->terminals[i]) { done = false; break; } } - env->log.episode_length += 1; - if (env->log.episode_length >= env->horizon) { + if (env->tick >= env->horizon) { done = true; + add_log(env, 0); } if (done) { - add_log(env->log_buffer, &env->log); + reset(env); + int idx = rand() % env->num_maps; + set_state(env, &env->levels[idx]); + compute_observations(env); } - return done; } // Raylib client @@ -480,13 +456,13 @@ Rectangle UV_COORDS[7] = { (Rectangle){384, 0, 128, 128}, }; -typedef struct { +struct Renderer { int cell_size; int width; int height; Texture2D puffer; float* overlay; -} Renderer; +}; Renderer* init_renderer(int cell_size, int width, int height) { Renderer* renderer = (Renderer*)calloc(1, sizeof(Renderer)); @@ -513,7 +489,15 @@ void close_renderer(Renderer* renderer) { free(renderer); } -void render_global(Renderer* renderer, Grid* env, float frac, float overlay) { +void render(Grid* env) { + // TODO: fractional rendering + float frac = 0.0; + float overlay = 0.0; + if (env->renderer == NULL) { + env->renderer = init_renderer(16, env->width, env->height); + } + Renderer* renderer = env->renderer; + if (IsKeyDown(KEY_ESCAPE)) { exit(0); } @@ -524,7 +508,7 @@ void render_global(Renderer* renderer, Grid* env, float frac, float overlay) { int adr = grid_offset(env, r, c); //renderer->overlay[adr] = overlay; //renderer->overlay[adr] -= 0.1; - renderer->overlay[adr] = -1 + 1.0/(float)env->counts[adr]; + //renderer->overlay[adr] = -1 + 1.0/(float)env->counts[adr]; BeginDrawing(); ClearBackground((Color){6, 24, 24, 255}); @@ -535,6 +519,7 @@ void render_global(Renderer* renderer, Grid* env, float frac, float overlay) { adr = grid_offset(env, r, c); int tile = env->grid[adr]; if (tile == EMPTY) { + continue; overlay = renderer->overlay[adr]; if (overlay == 0) { continue; diff --git a/pufferlib/ocean/grid/grid.py b/pufferlib/ocean/grid/grid.py index 7fc1560ee..a4562f78e 100644 --- a/pufferlib/ocean/grid/grid.py +++ b/pufferlib/ocean/grid/grid.py @@ -4,7 +4,7 @@ import gymnasium import pufferlib -from pufferlib.ocean.grid.cy_grid import CGrid +from pufferlib.ocean.grid import binding class Grid(pufferlib.PufferEnv): def __init__(self, render_mode='raylib', vision_range=5, @@ -19,33 +19,34 @@ def __init__(self, render_mode='raylib', vision_range=5, self.report_interval = report_interval super().__init__(buf=buf) self.float_actions = np.zeros_like(self.actions).astype(np.float32) - self.c_envs = CGrid(self.observations, self.float_actions, - self.rewards, self.terminals, num_envs, num_maps, map_size, max_map_size) + self.c_state = binding.shared(num_maps=num_maps, max_size=max_map_size, size=map_size) + self.c_envs = binding.vec_init(self.observations, self.float_actions, + self.rewards, self.terminals, self.truncations, num_envs, seed, + state=self.c_state, max_size=max_map_size, num_maps=num_maps) + pass def reset(self, seed=None): self.tick = 0 - self.c_envs.reset() + binding.vec_reset(self.c_envs, seed) return self.observations, [] def step(self, actions): self.float_actions[:] = actions - self.c_envs.step() + binding.vec_step(self.c_envs) info = [] if self.tick % self.report_interval == 0: - log = self.c_envs.log() - if log['episode_length'] > 0: - info.append(log) + info.append(binding.vec_log(self.c_envs)) self.tick += 1 return (self.observations, self.rewards, self.terminals, self.truncations, info) def render(self, overlay=0): - self.c_envs.render(overlay=overlay) + binding.vec_render(self.c_envs, overlay) def close(self): - self.c_envs.close() + binding.vec_close(self.c_envs) def test_performance(timeout=10, atn_cache=1024): env = CGrid(num_envs=1000) diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 16b65b8a3..346d2d727 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -170,6 +170,7 @@ def decode_actions(self, hidden): class Grid(nn.Module): def __init__(self, env, cnn_channels=32, hidden_size=128, **kwargs): super().__init__() + self.hidden_size = hidden_size self.network = nn.Sequential( pufferlib.pytorch.layer_init( nn.Conv2d(32, cnn_channels, 5, stride=3)), @@ -196,18 +197,18 @@ def __init__(self, env, cnn_channels=32, hidden_size=128, **kwargs): self.value_fn = pufferlib.pytorch.layer_init( nn.Linear(hidden_size, 1), std=1) - def forward(self, observations): + def forward(self, observations, state=None): hidden, lookup = self.encode_observations(observations) actions, value = self.decode_actions(hidden, lookup) return actions, value - def encode_observations(self, observations): + def encode_observations(self, observations, state=None): hidden = observations.view(-1, 11, 11).long() hidden = F.one_hot(hidden, 32).permute(0, 3, 1, 2).float() hidden = self.network(hidden) - return hidden, None + return hidden - def decode_actions(self, flat_hidden): + def decode_actions(self, flat_hidden, state=None): value = self.value_fn(flat_hidden) if self.is_continuous: mean = self.decoder_mean(flat_hidden) diff --git a/setup.py b/setup.py index 2802d2be5..2a1fb78b0 100644 --- a/setup.py +++ b/setup.py @@ -270,7 +270,7 @@ 'pufferlib/ocean/enduro/cy_enduro', 'pufferlib/ocean/blastar/cy_blastar', 'pufferlib/ocean/connect4/cy_connect4', - 'pufferlib/ocean/grid/cy_grid', + #'pufferlib/ocean/grid/cy_grid', 'pufferlib/ocean/tripletriad/cy_tripletriad', 'pufferlib/ocean/go/cy_go', 'pufferlib/ocean/rware/cy_rware', @@ -300,7 +300,7 @@ path.replace('/', '.'), [path + '.pyx'], include_dirs=[numpy.get_include(), 'raylib/include'], - extra_compile_args=extra_compile_args,# + ['-fsanitize=address,undefined,bounds,pointer-overflow,leak', '-g'], + extra_compile_args=extra_compile_args, extra_link_args=extra_link_args, extra_objects=[f'{RAYLIB_NAME}/lib/libraylib.a'], ) for path in extension_paths] @@ -309,14 +309,14 @@ #c_args = ['-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION', '-DPLATFORM_DESKTOP', '-O2'] #c_args += "-Wsign-compare -DNDEBUG -g -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC".split() -pure_c_extensions = ['squared', 'pong', 'breakout', 'nmmo3'] +pure_c_extensions = ['squared', 'grid', 'pong', 'breakout', 'nmmo3'] extensions += [ Extension( f'pufferlib.ocean.{name}.binding', sources=[f'pufferlib/ocean/{name}/binding.c'], include_dirs=[numpy.get_include(), 'raylib/include'], - extra_compile_args=extra_compile_args, - extra_link_args=extra_link_args, + extra_compile_args=extra_compile_args,# + ['-fsanitize=address,undefined,bounds,pointer-overflow,leak'], + extra_link_args=extra_link_args,# + ['-fsanitize=address,undefined,bounds,pointer-overflow,leak', '-g'], extra_objects=[f'{RAYLIB_NAME}/lib/libraylib.a'], ) for name in pure_c_extensions From f4bc97d5e420789c11f7c1b751bece25b0dfdbcd Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Fri, 11 Apr 2025 00:44:29 +0000 Subject: [PATCH 14/26] Found the issue - not mp safe because traj lengths desync --- clean_pufferl.py | 6 +++--- config/ocean/nmmo3.ini | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 8c36053e2..29f39146c 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -73,7 +73,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): truncateds=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), ) ep_uses = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) - stored_indices = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) + #stored_indices = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) ep_lengths = torch.zeros(total_agents, device=config.device, dtype=torch.int32) ep_indices = torch.arange(total_agents, device=config.device, dtype=torch.int32) free_idx = total_agents @@ -189,7 +189,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): step=0, lstm_h=lstm_h, lstm_c=lstm_c, - stored_indices=stored_indices, + #stored_indices=stored_indices, ep_uses=ep_uses, ep_lengths=ep_lengths, ep_indices=ep_indices, @@ -569,7 +569,7 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): if isinstance(env_id, slice): env_id = torch.arange(env_id.start, env_id.stop, device=data.device).int() - data.stored_indices[batch_rows] = env_id + #data.stored_indices[batch_rows] = env_id exp.obs[batch_rows, l] = obs exp.actions[batch_rows, l] = action diff --git a/config/ocean/nmmo3.ini b/config/ocean/nmmo3.ini index cad39fe36..34be2a02f 100644 --- a/config/ocean/nmmo3.ini +++ b/config/ocean/nmmo3.ini @@ -1,7 +1,7 @@ [base] package = ocean env_name = puffer_nmmo3 -vec = multiprocessing +vec = serial policy_name = NMMO3 rnn_name = NMMO3LSTM @@ -9,7 +9,7 @@ rnn_name = NMMO3LSTM reward_combat_level = 1.0 reward_prof_level = 1.0 reward_item_level = 1.0 -reward_market = 0 +reward_market = 0.0 reward_death = -1.0 num_envs = 4 From 9984a732e9731c9c4b474a7947ba90659c388118 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Fri, 11 Apr 2025 00:46:28 +0000 Subject: [PATCH 15/26] forgot a file --- pufferlib/ocean/torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 346d2d727..26a876f12 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -60,7 +60,7 @@ def forward(self, x, state=None): def forward_train(self, x, state=None): return self.forward(x, state) - def encode_observations(self, observations, unflatten=False): + def encode_observations(self, observations, state=None): batch = observations.shape[0] try: ob_map = observations[:, :11*15*10].view(batch, 11, 15, 10) From fe4046b0aa79e8e044949f931e5e4d2e9a991f1b Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Sat, 12 Apr 2025 00:15:23 +0000 Subject: [PATCH 16/26] sync sampling --- pufferlib/vector.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 8a6726eff..b463b99e0 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -231,7 +231,7 @@ def num_envs(self): def __init__(self, env_creators, env_args, env_kwargs, num_envs, num_workers=None, batch_size=None, - zero_copy=True, overwork=False, seed=0, **kwargs): + zero_copy=True, sync_traj=True, overwork=False, seed=0, **kwargs): if batch_size is None: batch_size = num_envs if num_workers is None: @@ -340,16 +340,25 @@ def __init__(self, env_creators, env_args, env_kwargs, self.flag = RESET self.initialized = False self.zero_copy = zero_copy + self.sync_traj = sync_traj def recv(self): recv_precheck(self) while True: - worker = self.waiting_workers.pop(0) - sem = self.buf.semaphores[worker] - if sem >= MAIN: - self.ready_workers.append(worker) + # Bandaid patch for new experience buffer desync + if self.sync_traj: + worker = self.waiting_workers[0] + sem = self.buf.semaphores[worker] + if sem >= MAIN: + self.waiting_workers.pop(0) + self.ready_workers.append(worker) else: - self.waiting_workers.append(worker) + worker = self.waiting_workers.pop(0) + sem = self.buf.semaphores[worker] + if sem >= MAIN: + self.ready_workers.append(worker) + else: + self.waiting_workers.append(worker) if sem == INFO: self.infos[worker] = self.recv_pipes[worker].recv() @@ -438,6 +447,7 @@ def async_reset(self, seed=0): self.flag = RECV self.ready_workers = [] + self.ready_next_workers = [] # Used to evenly sample workers self.waiting_workers = list(range(self.num_workers)) self.infos = [[] for _ in range(self.num_workers)] From 208b987e53f4ce4dcd2a112f1fc6dfb73ceb4178 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Sat, 12 Apr 2025 00:16:00 +0000 Subject: [PATCH 17/26] Fix breakout config --- config/ocean/breakout.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/ocean/breakout.ini b/config/ocean/breakout.ini index f8d5fcd2c..bde999ee2 100644 --- a/config/ocean/breakout.ini +++ b/config/ocean/breakout.ini @@ -22,7 +22,7 @@ num_envs = 2 num_workers = 2 env_batch_size = 1 batch_size = 524288 -update_epochs = 64 +update_epochs = 1 ent_coef = 0.004602497836498393 gae_lambda = 0.8345374031042396 gamma = 0.9964277976817042 From 8441567ac4f037a8939c7f39ca762c783e461064 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Sat, 12 Apr 2025 00:36:21 +0000 Subject: [PATCH 18/26] Fix 1 char bug of doom --- clean_pufferl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 29f39146c..efe780b26 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -215,7 +215,7 @@ def evaluate(data): lstm_h = data.lstm_h lstm_c = data.lstm_c - while data.free_idx <= data.on_policy_rows: + while data.free_idx < data.on_policy_rows: with profile.env: o, r, d, t, info, env_id, mask = data.vecenv.recv() From 86c750b88b3826a5a4b917e96e9f325547b6b2fe Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Sat, 12 Apr 2025 00:38:58 +0000 Subject: [PATCH 19/26] nmmo3 conf --- config/ocean/nmmo3.ini | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/ocean/nmmo3.ini b/config/ocean/nmmo3.ini index 34be2a02f..08681c349 100644 --- a/config/ocean/nmmo3.ini +++ b/config/ocean/nmmo3.ini @@ -1,7 +1,7 @@ [base] package = ocean env_name = puffer_nmmo3 -vec = serial +vec = multiprocessing policy_name = NMMO3 rnn_name = NMMO3LSTM @@ -26,7 +26,7 @@ gae_lambda = 0.996005622445478 ent_coef = 0.01210084358004069 max_grad_norm = 0.6075578331947327 vf_coef = 0.3979089612467003 -bptt_horizon = 16 +bptt_horizon = 32 batch_size = 262144 minibatch_size = 32768 compile = False From 6bfa9879e5ce989be19a447c47fcc84baaa6e529 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Sat, 12 Apr 2025 21:00:49 +0000 Subject: [PATCH 20/26] vtrace --- clean_pufferl.py | 78 +++++++++++++++++++++++++++++++++++----------- config/default.ini | 5 +++ pufferlib.cpp | 23 ++++++++++++++ pufferlib.cu | 64 +++++++++++++++++++++++++++++++------ shared.cpp | 40 ++++++++++++++++++++++++ 5 files changed, 181 insertions(+), 29 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index efe780b26..ef73209fa 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -30,11 +30,13 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): torch.manual_seed(config.seed) ext = 'cu' if 'cuda' in config.device else 'cpp' - compute_gae = load( - name='compute_gae', + puffer_cuda = load( + name='puffer_cuda', sources=[f'pufferlib.{ext}'], verbose=True - ).compute_gae + ) + compute_gae = puffer_cuda.compute_gae + compute_vtrace = puffer_cuda.compute_vtrace losses = pufferlib.namespace( policy_loss=0, @@ -200,6 +202,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): device=config.device, minibatch_size=minibatch_size, compute_gae=compute_gae, + compute_vtrace=compute_vtrace, diayn_skills=diayn_skills, total_agents=total_agents, ) @@ -365,6 +368,8 @@ def train(data): advantages = advantages.cpu().numpy() torch.cuda.synchronize() + elif config.use_vtrace: + advantages = torch.ones(experience.values.shape, device=config.device) else: advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) @@ -413,14 +418,12 @@ def train(data): # TODO: Currently only returning traj shaped value as a hack logits, newvalue = data.policy.forward_train(batch.obs, state) - with torch.no_grad(): - experience.values[batch.idx] = newvalue - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=batch.actions, is_continuous=data.policy.is_continuous) with profile.train_misc: - logratio = newlogprob - batch.logprobs.reshape(-1) + newlogprob = newlogprob.reshape(batch.logprobs.shape) + logratio = newlogprob - batch.logprobs ratio = logratio.exp() # TODO: Only do this if we are KL clipping? Saves 1-2% compute @@ -430,16 +433,47 @@ def train(data): approx_kl = ((ratio - 1) - logratio).mean() clipfrac = ((ratio - 1.0).abs() > config.clip_coef).float().mean() - adv = batch.advantages.reshape(-1) - if config.norm_adv: - adv = (adv - adv.mean()) / (adv.std() + 1e-8) + if config.use_vtrace: + with torch.no_grad(): + vs = torch.zeros(batch.values.shape, device=config.device) + adv = torch.zeros(batch.values.shape, device=config.device) + data.compute_vtrace(batch.values, batch.rewards, batch.dones, + ratio, vs, adv, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) + batch.returns = vs + + # Might need returns at next step + #pg_loss = (newlogprob*(batch.rewards + config.gamma*batch.returns - batch.values)).mean() + #clipped_rho = torch.clamp(ratio, max=config.vtrace_rho_clip)[:, :-1] + #adv = clipped_rho * (batch.rewards[:, :-1] + config.gamma*batch.returns[:, 1:] - batch.values[:, :-1]) + + #lgt = logits.reshape(*newlogprob.shape, logits.shape[-1]) + #lgt = lgt[:, :-1].reshape(-1, lgt.shape[-1]) + #atns = batch.actions[:, :-1].reshape(-1) + #adv = adv.reshape(-1) + #pg_loss = (adv * torch.nn.functional.cross_entropy(lgt, atns, reduction='none')).mean() + #print(torch.mean(batch.values), torch.mean(batch.returns), torch.mean(adv)) + lgt = logits.reshape(-1, logits.shape[-1]) + atns = batch.actions.reshape(-1) + adv = adv.reshape(-1) + #pg_loss = (adv * torch.nn.functional.cross_entropy(lgt, atns, reduction='none')).mean() + + #if config.norm_adv: + # adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + pg_loss = torch.mean(adv * torch.nn.functional.nll_loss( + torch.nn.functional.log_softmax(lgt, dim=-1), target=atns, reduction='none')) - # Policy loss - pg_loss1 = -adv * ratio - pg_loss2 = -adv * torch.clamp( - ratio, 1 - config.clip_coef, 1 + config.clip_coef - ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() + else: + adv = batch.advantages + if config.norm_adv: + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Policy loss + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp( + ratio, 1 - config.clip_coef, 1 + config.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss if config.use_p3o: @@ -452,10 +486,10 @@ def train(data): mask_block = mask_block[:, :(horizon+3)] v_loss = v_loss[mask_block.bool()].mean() elif config.clip_vloss: - newvalue = newvalue.flatten() - ret = batch.returns.flatten() + newvalue = newvalue#.flatten() + ret = batch.returns#.flatten() v_loss_unclipped = (newvalue - ret) ** 2 - val = batch.values.flatten() + val = batch.values#.flatten() v_clipped = val + torch.clamp( newvalue - val, -config.vf_clip_coef, @@ -471,6 +505,11 @@ def train(data): entropy_loss = entropy.mean() loss += pg_loss - config.ent_coef*entropy_loss + v_loss*config.vf_coef + # This breaks vloss clipping? + with torch.no_grad(): + experience.values[batch.idx] = newvalue + + with profile.learn: if data.scaler is not None: loss = data.scaler.scale(loss) @@ -607,6 +646,7 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): def sample(data, advantages, n, reward_block=None, mask_block=None, method='multinomial'): exp = data.experience + method = 'random' if method == 'topk': _, idx = torch.topk(advantages.abs().sum(axis=1), n) elif method == 'multinomial': diff --git a/config/default.ini b/config/default.ini index b38b1f22a..4f06094ad 100644 --- a/config/default.ini +++ b/config/default.ini @@ -29,6 +29,7 @@ gamma = 0.995 gae_lambda = 0.85 update_epochs = 1 norm_adv = True +# Consider raising clip coef to 0.2 clip_coef = 0.1 clip_vloss = True vf_coef = 2.0 @@ -70,6 +71,10 @@ use_p3o = False p3o_horizon = 128 puf = 0.0 +use_vtrace = False +vtrace_rho_clip = 1.0 +vtrace_c_clip = 1.0 + [sweep] method = protein name = sweep diff --git a/pufferlib.cpp b/pufferlib.cpp index 10e39fcbb..4889030eb 100644 --- a/pufferlib.cpp +++ b/pufferlib.cpp @@ -21,6 +21,29 @@ torch::Tensor compute_gae(torch::Tensor values, torch::Tensor rewards, return advantages; } +// [num_steps, horizon] +void vtrace(float* values, float* rewards, float* dones, float* importance, + float* trace, float gamma, float rho_clip, float c_clip, int num_steps, const int horizon){ + for (int offset = 0; offset < num_steps*horizon; offset+=horizon) { + vtrace_row(values + offset, rewards + offset, dones + offset, + importance + offset, trace + offset, gamma, rho_clip, c_clip, horizon); + } +} + +torch::Tensor compute_vtrace(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, torch::Tensor importance, float gamma, + float rho_clip, float c_clip) { + int num_steps = values.size(0); + int horizon = values.size(1); + torch::Tensor trace = vtrace_check(values, rewards, dones, importance, num_steps, horizon); + vtrace(values.data_ptr(), rewards.data_ptr(), + dones.data_ptr(), importance.data_ptr(), + trace.data_ptr(), gamma, rho_clip, c_clip, num_steps, horizon + ); + return trace; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_gae", &compute_gae, "Compute GAE with C"); + m.def("compute_vtrace", &compute_vtrace, "Compute VTrace with C"); } diff --git a/pufferlib.cu b/pufferlib.cu index db62686ec..fa4c2f919 100644 --- a/pufferlib.cu +++ b/pufferlib.cu @@ -89,15 +89,6 @@ __global__ void p3o_kernel( } -// [num_steps, horizon] -__global__ void gae_kernel(float* values, float* rewards, float* dones, - float* advantages, float gamma, float gae_lambda, int num_steps, int horizon) { - int row = blockIdx.x*blockDim.x + threadIdx.x; - int offset = row*horizon; - gae_row(values + offset, rewards + offset, dones + offset, - advantages + offset, gamma, gae_lambda, horizon); -} - void compute_p3o(torch::Tensor reward_block, torch::Tensor reward_mask, torch::Tensor values_mean, torch::Tensor values_std, torch::Tensor buf, torch::Tensor dones, torch::Tensor rewards, torch::Tensor advantages, @@ -155,6 +146,15 @@ void compute_p3o(torch::Tensor reward_block, torch::Tensor reward_mask, return; } +// [num_steps, horizon] +__global__ void gae_kernel(float* values, float* rewards, float* dones, + float* advantages, float gamma, float gae_lambda, int num_steps, int horizon) { + int row = blockIdx.x*blockDim.x + threadIdx.x; + int offset = row*horizon; + gae_row(values + offset, rewards + offset, dones + offset, + advantages + offset, gamma, gae_lambda, horizon); +} + torch::Tensor compute_gae(torch::Tensor values, torch::Tensor rewards, torch::Tensor dones, float gamma, float gae_lambda) { int num_steps = values.size(0); @@ -164,6 +164,7 @@ torch::Tensor compute_gae(torch::Tensor values, torch::Tensor rewards, int threads_per_block = 256; int blocks = (num_steps + threads_per_block - 1) / threads_per_block; + assert(num_steps % threads_per_block == 0); gae_kernel<<>>( values.data_ptr(), @@ -183,9 +184,52 @@ torch::Tensor compute_gae(torch::Tensor values, torch::Tensor rewards, return advantages; } - + + // [num_steps, horizon] +__global__ void vtrace_kernel(float* values, float* rewards, float* dones, float* importance, + float* vs, float* advantages, float gamma, float rho_clip, float c_clip, int num_steps, int horizon) { + int row = blockIdx.x*blockDim.x + threadIdx.x; + int offset = row*horizon; + vtrace_row(values + offset, rewards + offset, dones + offset, + importance + offset, vs + offset, advantages + offset, gamma, rho_clip, c_clip, horizon); +} + +void compute_vtrace(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, torch::Tensor importance, torch::Tensor vs, torch::Tensor advantages, + float gamma, float rho_clip, float c_clip) { + int num_steps = values.size(0); + int horizon = values.size(1); + vtrace_check(values, rewards, dones, importance, vs, advantages, num_steps, horizon); + TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU"); + assert(horizon <= max_horizon); + + int threads_per_block = 128; + int blocks = (num_steps + threads_per_block - 1) / threads_per_block; + assert(num_steps % threads_per_block == 0); + + vtrace_kernel<<>>( + values.data_ptr(), + rewards.data_ptr(), + dones.data_ptr(), + importance.data_ptr(), + vs.data_ptr(), + advantages.data_ptr(), + gamma, + rho_clip, + c_clip, + num_steps, + horizon + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error(cudaGetErrorString(err)); + } +} + // Pybind11 module definition PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_p3o", &compute_p3o, "Compute p3o advantages with CUDA"); m.def("compute_gae", &compute_gae, "Compute GAE with CUDA"); + m.def("compute_vtrace", &compute_vtrace, "Compute VTrace with CUDA"); } diff --git a/shared.cpp b/shared.cpp index fd3202cb5..2f8107ea5 100644 --- a/shared.cpp +++ b/shared.cpp @@ -13,6 +13,7 @@ __host__ __device__ void gae_row(float* values, float* rewards, float* dones, fl for (int t = horizon-2; t >= 0; t--) { int t_next = t + 1; float nextnonterminal = 1.0 - dones[t_next]; + // Should this rewards[t_next] be rewards[t]? float delta = rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]; lastgaelam = delta + gamma*gae_lambda*nextnonterminal * lastgaelam; advantages[t] = lastgaelam; @@ -44,4 +45,43 @@ torch::Tensor gae_check(torch::Tensor values, torch::Tensor rewards, return advantages; } +// [horizon] +const int max_horizon = 256; +__host__ __device__ void vtrace_row(float* values, float* rewards, float* dones, + float* importance, float* vs, float* advantages, float gamma, float rho_clip, float c_clip, int horizon) { + float accum = 0.0;//values[horizon-1]; // Is this correct? + vs[horizon-1] = values[horizon-1]; + for (int t = horizon-2; t >= 0; t--) { + int t_next = t + 1; + float nextnonterminal = 1.0 - dones[t_next]; + float rho_t = fminf(importance[t], rho_clip); + float c_t = fminf(importance[t], c_clip); + float delta = rho_t*(rewards[t] + gamma*values[t_next]*nextnonterminal - values[t]); + accum = delta + gamma*c_t*accum*nextnonterminal; + advantages[t] = rho_t*(rewards[t] + gamma*vs[t_next]*nextnonterminal - values[t]); + vs[t] = accum + values[t]; + } +} + +void vtrace_check(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, torch::Tensor importance, torch::Tensor vs, torch::Tensor advantages, + int num_steps, int horizon) { + + // Validate input tensors + torch::Device device = values.device(); + for (const torch::Tensor& t : {values, rewards, dones, importance, vs, advantages}) { + TORCH_CHECK(t.dim() == 2, "Tensor must be 2D"); + TORCH_CHECK(t.device() == device, "All tensors must be on same device"); + TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps"); + TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon"); + TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32"); + assert(horizon <= max_horizon); + if (!t.is_contiguous()) { + t.contiguous(); + } + } +} + + + From 541f983459a00db56adcf6e0b3c7897e1e5afbc0 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Sun, 13 Apr 2025 00:18:09 +0000 Subject: [PATCH 21/26] prioritized replay --- clean_pufferl.py | 42 +++++++++++++++++++++++++----------------- config/default.ini | 3 +++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index ef73209fa..150c39827 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -103,6 +103,9 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): else: experience.values = torch.zeros(experience_rows, config.bptt_horizon, device=config.device) + if config.use_vtrace: + experience.importance = torch.ones(experience_rows, config.bptt_horizon, device=config.device) + lstm_h = None lstm_c = None if isinstance(policy, torch.nn.LSTM): @@ -205,6 +208,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): compute_vtrace=compute_vtrace, diayn_skills=diayn_skills, total_agents=total_agents, + total_epochs=epochs, ) @pufferlib.utils.profile @@ -369,13 +373,14 @@ def train(data): advantages = advantages.cpu().numpy() torch.cuda.synchronize() elif config.use_vtrace: - advantages = torch.ones(experience.values.shape, device=config.device) + advantages = torch.ones(experience.values.shape, device=config.device).to(config.device) + importance = experience.importance else: - advantages = data.compute_gae(experience.values, experience.rewards, + importance = advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) with profile.train_copy: - batch = sample(data, advantages, n_samples) + batch = sample(data, importance, n_samples) loss = 0 with profile.custom: @@ -441,21 +446,11 @@ def train(data): ratio, vs, adv, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) batch.returns = vs + importance[batch.idx] = adv # Might need returns at next step - #pg_loss = (newlogprob*(batch.rewards + config.gamma*batch.returns - batch.values)).mean() - #clipped_rho = torch.clamp(ratio, max=config.vtrace_rho_clip)[:, :-1] - #adv = clipped_rho * (batch.rewards[:, :-1] + config.gamma*batch.returns[:, 1:] - batch.values[:, :-1]) - - #lgt = logits.reshape(*newlogprob.shape, logits.shape[-1]) - #lgt = lgt[:, :-1].reshape(-1, lgt.shape[-1]) - #atns = batch.actions[:, :-1].reshape(-1) - #adv = adv.reshape(-1) - #pg_loss = (adv * torch.nn.functional.cross_entropy(lgt, atns, reduction='none')).mean() - #print(torch.mean(batch.values), torch.mean(batch.returns), torch.mean(adv)) lgt = logits.reshape(-1, logits.shape[-1]) atns = batch.actions.reshape(-1) - adv = adv.reshape(-1) - #pg_loss = (adv * torch.nn.functional.cross_entropy(lgt, atns, reduction='none')).mean() + adv = (batch.prio*adv).reshape(-1) #if config.norm_adv: # adv = (adv - adv.mean()) / (adv.std() + 1e-8) @@ -468,6 +463,8 @@ def train(data): if config.norm_adv: adv = (adv - adv.mean()) / (adv.std() + 1e-8) + adv = adv * batch.prio + # Policy loss pg_loss1 = -adv * ratio pg_loss2 = -adv * torch.clamp( @@ -561,6 +558,7 @@ def train(data): advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) exp = sample(data, advantages, data.off_policy_rows, method='topk') + experience.importance[:data.on_policy_rows] = 1 for k, v in experience.items(): v[data.on_policy_rows:] = exp[k] @@ -644,11 +642,15 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): data.step += 1 return action.cpu().numpy() -def sample(data, advantages, n, reward_block=None, mask_block=None, method='multinomial'): +def sample(data, advantages, n, reward_block=None, mask_block=None, method='prio'): exp = data.experience - method = 'random' if method == 'topk': _, idx = torch.topk(advantages.abs().sum(axis=1), n) + elif method == 'prio': + adv = advantages.abs().sum(axis=1) + probs = adv**data.config.prio_alpha + probs = (probs + 1e-6)/(probs.sum() + 1e-6) + idx = torch.multinomial(probs, n) elif method == 'multinomial': idx = torch.multinomial(advantages.abs().sum(axis=1) + 1e-6, n) elif method == 'random': @@ -656,6 +658,7 @@ def sample(data, advantages, n, reward_block=None, mask_block=None, method='mult else: raise ValueError(f'Unknown sampling method: {method}') + data.ep_uses[idx] += 1 output = {k: v[idx] for k, v in exp.items()} output['idx'] = idx @@ -673,6 +676,11 @@ def sample(data, advantages, n, reward_block=None, mask_block=None, method='mult if data.use_diayn: output['diayn_z'] = exp.diayn_batch[idx] + output['prio'] = 1 + if method == 'prio': + beta = data.config.prio_beta0 + (1 - data.config.prio_beta0)*data.config.prio_alpha*data.epoch/data.total_epochs + output['prio'] = (((1/len(probs)) * (1/probs[idx]))**beta).unsqueeze(1).expand_as(output['advantages']) + return pufferlib.namespace(**output) def dist_sum(value, device): diff --git a/config/default.ini b/config/default.ini index 4f06094ad..e005368c1 100644 --- a/config/default.ini +++ b/config/default.ini @@ -75,6 +75,9 @@ use_vtrace = False vtrace_rho_clip = 1.0 vtrace_c_clip = 1.0 +prio_alpha = 0.6 +prio_beta0 = 0.4 + [sweep] method = protein name = sweep From a94a7f3b9a1e6b6af31a913f10c9aba32ad7cb1e Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 15 Apr 2025 17:16:32 +0000 Subject: [PATCH 22/26] Initial puffer advantage --- clean_pufferl.py | 53 +++++++++++++++++++++++++++++++++++++++------- config/default.ini | 2 ++ pufferlib.cpp | 25 ++++++++++++++++++++++ pufferlib.cu | 46 ++++++++++++++++++++++++++++++++++++++++ shared.cpp | 17 +++++++++++++++ 5 files changed, 135 insertions(+), 8 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 150c39827..5e78f0c85 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -37,6 +37,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): ) compute_gae = puffer_cuda.compute_gae compute_vtrace = puffer_cuda.compute_vtrace + compute_puff_advantage = puffer_cuda.compute_puff_advantage losses = pufferlib.namespace( policy_loss=0, @@ -103,7 +104,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): else: experience.values = torch.zeros(experience_rows, config.bptt_horizon, device=config.device) - if config.use_vtrace: + if config.use_vtrace or config.use_puff_advantage: experience.importance = torch.ones(experience_rows, config.bptt_horizon, device=config.device) lstm_h = None @@ -206,6 +207,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): minibatch_size=minibatch_size, compute_gae=compute_gae, compute_vtrace=compute_vtrace, + compute_puff_advantage=compute_puff_advantage, diayn_skills=diayn_skills, total_agents=total_agents, total_epochs=epochs, @@ -372,7 +374,7 @@ def train(data): advantages = advantages.cpu().numpy() torch.cuda.synchronize() - elif config.use_vtrace: + elif config.use_vtrace or config.use_puff_advantage: advantages = torch.ones(experience.values.shape, device=config.device).to(config.device) importance = experience.importance else: @@ -438,25 +440,60 @@ def train(data): approx_kl = ((ratio - 1) - logratio).mean() clipfrac = ((ratio - 1.0).abs() > config.clip_coef).float().mean() - if config.use_vtrace: + if config.use_vtrace or config.use_puff_advantage: with torch.no_grad(): vs = torch.zeros(batch.values.shape, device=config.device) adv = torch.zeros(batch.values.shape, device=config.device) - data.compute_vtrace(batch.values, batch.rewards, batch.dones, - ratio, vs, adv, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) + if config.use_vtrace: + data.compute_vtrace(batch.values, batch.rewards, batch.dones, + ratio, vs, adv, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) + elif config.use_puff_advantage: + data.compute_puff_advantage(batch.values, batch.rewards, batch.dones, + ratio, vs, adv, config.gamma, config.gae_lambda, config.vtrace_rho_clip, config.vtrace_c_clip) batch.returns = vs importance[batch.idx] = adv # Might need returns at next step lgt = logits.reshape(-1, logits.shape[-1]) atns = batch.actions.reshape(-1) + + if config.norm_adv: + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + adv = (batch.prio*adv).reshape(-1) + ratio = ratio.view(-1) + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp( + ratio, 1 - config.clip_coef, 1 + config.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + #if config.norm_adv: # adv = (adv - adv.mean()) / (adv.std() + 1e-8) - - pg_loss = torch.mean(adv * torch.nn.functional.nll_loss( - torch.nn.functional.log_softmax(lgt, dim=-1), target=atns, reduction='none')) + #nll_loss = torch.nn.functional.nll_loss( + # torch.nn.functional.log_softmax(lgt, dim=-1), target=atns, reduction='none') + + # Worse than nll_loss + #ratio = ratio.view(-1) + #pg_loss1 = -adv * ratio + #pg_loss2 = -adv * torch.clamp( + # ratio, 1 - config.clip_coef, 1 + config.clip_coef + #) + #pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + + #adv = torch.clamp(adv, 1-config.clip_coef, 1+config.clip_coef) + #pg_loss = (adv * nll_loss).mean() + #pg_loss2 = adv * torch.clamp( + # nll_loss, 1 - config.clip_coef, 1 + config.clip_coef + #) + #pg_loss = torch.max(pg_loss1, pg_loss2).mean() + #pg_loss = pg_loss1.mean() + + #pg_loss = torch.mean(adv * torch.nn.functional.nll_loss( + # torch.nn.functional.log_softmax(lgt, dim=-1), target=atns, reduction='none')) else: adv = batch.advantages diff --git a/config/default.ini b/config/default.ini index e005368c1..6ffa7a627 100644 --- a/config/default.ini +++ b/config/default.ini @@ -75,6 +75,8 @@ use_vtrace = False vtrace_rho_clip = 1.0 vtrace_c_clip = 1.0 +use_puff_advantage = False + prio_alpha = 0.6 prio_beta0 = 0.4 diff --git a/pufferlib.cpp b/pufferlib.cpp index 4889030eb..03fac669b 100644 --- a/pufferlib.cpp +++ b/pufferlib.cpp @@ -30,6 +30,16 @@ void vtrace(float* values, float* rewards, float* dones, float* importance, } } +// [num_steps, horizon] +void puff_advantage(float* values, float* rewards, float* dones, float* importance, + float* trace, float gamma, float rho_clip, float c_clip, int num_steps, const int horizon){ + for (int offset = 0; offset < num_steps*horizon; offset+=horizon) { + puff_row(values + offset, rewards + offset, dones + offset, + importance + offset, trace + offset, gamma, rho_clip, c_clip, horizon); + } +} + + torch::Tensor compute_vtrace(torch::Tensor values, torch::Tensor rewards, torch::Tensor dones, torch::Tensor importance, float gamma, float rho_clip, float c_clip) { @@ -43,7 +53,22 @@ torch::Tensor compute_vtrace(torch::Tensor values, torch::Tensor rewards, return trace; } +torch::Tensor compute_puff_advantage(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, torch::Tensor importance, float gamma, + float rho_clip, float c_clip) { + int num_steps = values.size(0); + int horizon = values.size(1); + torch::Tensor trace = vtrace_check(values, rewards, dones, importance, num_steps, horizon); + vtrace(values.data_ptr(), rewards.data_ptr(), + dones.data_ptr(), importance.data_ptr(), + trace.data_ptr(), gamma, rho_clip, c_clip, num_steps, horizon + ); + return trace; +} + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_gae", &compute_gae, "Compute GAE with C"); m.def("compute_vtrace", &compute_vtrace, "Compute VTrace with C"); + m.def("compute_puff_advantage", &compute_puff_advantage, "Compute PuffAdvantage with C"); } diff --git a/pufferlib.cu b/pufferlib.cu index fa4c2f919..ce65fe66a 100644 --- a/pufferlib.cu +++ b/pufferlib.cu @@ -227,9 +227,55 @@ void compute_vtrace(torch::Tensor values, torch::Tensor rewards, } } + // [num_steps, horizon] +__global__ void puff_advantage_kernel(float* values, float* rewards, float* dones, float* importance, + float* vs, float* advantages, float gamma, float lambda, + float rho_clip, float c_clip, int num_steps, int horizon) { + int row = blockIdx.x*blockDim.x + threadIdx.x; + int offset = row*horizon; + puff_advantage_row(values + offset, rewards + offset, dones + offset, + importance + offset, vs + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon); +} + +void compute_puff_advantage(torch::Tensor values, torch::Tensor rewards, + torch::Tensor dones, torch::Tensor importance, torch::Tensor vs, torch::Tensor advantages, + float gamma, float lambda, float rho_clip, float c_clip) { + int num_steps = values.size(0); + int horizon = values.size(1); + vtrace_check(values, rewards, dones, importance, vs, advantages, num_steps, horizon); + TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU"); + assert(horizon <= max_horizon); + + int threads_per_block = 128; + int blocks = (num_steps + threads_per_block - 1) / threads_per_block; + assert(num_steps % threads_per_block == 0); + + puff_advantage_kernel<<>>( + values.data_ptr(), + rewards.data_ptr(), + dones.data_ptr(), + importance.data_ptr(), + vs.data_ptr(), + advantages.data_ptr(), + gamma, + lambda, + rho_clip, + c_clip, + num_steps, + horizon + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error(cudaGetErrorString(err)); + } +} + + // Pybind11 module definition PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_p3o", &compute_p3o, "Compute p3o advantages with CUDA"); m.def("compute_gae", &compute_gae, "Compute GAE with CUDA"); m.def("compute_vtrace", &compute_vtrace, "Compute VTrace with CUDA"); + m.def("compute_puff_advantage", &compute_puff_advantage, "Compute PuffAdvantage with CUDA"); } diff --git a/shared.cpp b/shared.cpp index 2f8107ea5..c34d2c1be 100644 --- a/shared.cpp +++ b/shared.cpp @@ -63,6 +63,23 @@ __host__ __device__ void vtrace_row(float* values, float* rewards, float* dones, } } +__host__ __device__ void puff_advantage_row(float* values, float* rewards, float* dones, + float* importance, float* vs, float* advantages, float gamma, float lambda, + float rho_clip, float c_clip, int horizon) { + vs[horizon-1] = values[horizon-1]; + float lastpufferlam = 0; + for (int t = horizon-2; t >= 0; t--) { + int t_next = t + 1; + float nextnonterminal = 1.0 - dones[t_next]; + float rho_t = fminf(importance[t], rho_clip); + float c_t = fminf(importance[t], c_clip); + float delta = rho_t*(rewards[t] + gamma*values[t_next]*nextnonterminal - values[t]); + lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal; + advantages[t] = rho_t*(rewards[t] + gamma*vs[t_next]*nextnonterminal - values[t]); + vs[t] = lastpufferlam + values[t]; + } +} + void vtrace_check(torch::Tensor values, torch::Tensor rewards, torch::Tensor dones, torch::Tensor importance, torch::Tensor vs, torch::Tensor advantages, int num_steps, int horizon) { From b87d796d7b65d284f1b924fce40330e5253b8c48 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 15 Apr 2025 18:01:59 +0000 Subject: [PATCH 23/26] Working puff advantage --- clean_pufferl.py | 33 ++++++++++++++++++++++----------- pufferlib/ocean/moba/moba.py | 2 +- pufferlib/ocean/rware/rware.py | 2 +- shared.cpp | 14 +++++++++++--- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 5e78f0c85..87e77cc82 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -340,6 +340,7 @@ def train(data): total_minibatches = int(config.update_epochs*config.batch_size/data.minibatch_size) accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) n_samples = data.minibatch_size // config.bptt_horizon + experience.ratio = torch.ones(experience.values.shape, device=config.device).to(config.device) for mb in range(total_minibatches): with profile.train_misc: if config.use_p3o: @@ -374,9 +375,16 @@ def train(data): advantages = advantages.cpu().numpy() torch.cuda.synchronize() - elif config.use_vtrace or config.use_puff_advantage: - advantages = torch.ones(experience.values.shape, device=config.device).to(config.device) - importance = experience.importance + elif config.use_vtrace: + importance = advantages = torch.zeros(experience.values.shape, device=config.device).to(config.device) + vs = torch.zeros(experience.values.shape, device=config.device) + data.compute_vtrace(batch.values, batch.rewards, batch.dones, + experience.ratio, vs, advantages, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) + elif config.use_puff_advantage: + importance = advantages = torch.zeros(experience.values.shape, device=config.device).to(config.device) + vs = torch.zeros(experience.values.shape, device=config.device) + data.compute_puff_advantage(experience.values, experience.rewards, experience.dones, + experience.ratio, vs, advantages, config.gamma, config.gae_lambda, config.vtrace_rho_clip, config.vtrace_c_clip) else: importance = advantages = data.compute_gae(experience.values, experience.rewards, experience.dones, config.gamma, config.gae_lambda) @@ -432,6 +440,7 @@ def train(data): newlogprob = newlogprob.reshape(batch.logprobs.shape) logratio = newlogprob - batch.logprobs ratio = logratio.exp() + experience.ratio[batch.idx] = ratio # TODO: Only do this if we are KL clipping? Saves 1-2% compute with torch.no_grad(): @@ -442,27 +451,29 @@ def train(data): if config.use_vtrace or config.use_puff_advantage: with torch.no_grad(): - vs = torch.zeros(batch.values.shape, device=config.device) - adv = torch.zeros(batch.values.shape, device=config.device) + adv = advantages[batch.idx] + vs = vs[batch.idx] if config.use_vtrace: data.compute_vtrace(batch.values, batch.rewards, batch.dones, ratio, vs, adv, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) elif config.use_puff_advantage: data.compute_puff_advantage(batch.values, batch.rewards, batch.dones, ratio, vs, adv, config.gamma, config.gae_lambda, config.vtrace_rho_clip, config.vtrace_c_clip) - batch.returns = vs - importance[batch.idx] = adv + #advantages[batch.idx] = adv + #importance[batch.idx] = adv + # Might need returns at next step - lgt = logits.reshape(-1, logits.shape[-1]) - atns = batch.actions.reshape(-1) + #lgt = logits.reshape(-1, logits.shape[-1]) + #atns = batch.actions.reshape(-1) + + adv = advantages[batch.idx] if config.norm_adv: adv = (adv - adv.mean()) / (adv.std() + 1e-8) - adv = (batch.prio*adv).reshape(-1) + adv = adv * batch.prio - ratio = ratio.view(-1) pg_loss1 = -adv * ratio pg_loss2 = -adv * torch.clamp( ratio, 1 - config.clip_coef, 1 + config.clip_coef diff --git a/pufferlib/ocean/moba/moba.py b/pufferlib/ocean/moba/moba.py index 51ff827cf..00acfffac 100644 --- a/pufferlib/ocean/moba/moba.py +++ b/pufferlib/ocean/moba/moba.py @@ -16,7 +16,7 @@ class Moba(pufferlib.PufferEnv): def __init__(self, num_envs=4, vision_range=5, agent_speed=1.0, discretize=True, reward_death=-1.0, reward_xp=0.006, reward_distance=0.05, reward_tower=3.0, report_interval=32, - script_opponents=True, render_mode='human', buf=None): + script_opponents=True, render_mode='human', buf=None, seed=0): self.report_interval = report_interval self.render_mode = render_mode diff --git a/pufferlib/ocean/rware/rware.py b/pufferlib/ocean/rware/rware.py index f0c6cfb96..6caba351f 100644 --- a/pufferlib/ocean/rware/rware.py +++ b/pufferlib/ocean/rware/rware.py @@ -22,7 +22,7 @@ def __init__(self, num_envs=1, render_mode=None, report_interval=1, grid_square_size=64, human_agent_idx=0, reward_type=1, - buf = None): + buf = None, seed=0): # env self.num_agents = num_envs*num_agents diff --git a/shared.cpp b/shared.cpp index c34d2c1be..121656477 100644 --- a/shared.cpp +++ b/shared.cpp @@ -73,10 +73,18 @@ __host__ __device__ void puff_advantage_row(float* values, float* rewards, float float nextnonterminal = 1.0 - dones[t_next]; float rho_t = fminf(importance[t], rho_clip); float c_t = fminf(importance[t], c_clip); - float delta = rho_t*(rewards[t] + gamma*values[t_next]*nextnonterminal - values[t]); + // TODO: t_next works and t doesn't. Check original formula + float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]); lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal; - advantages[t] = rho_t*(rewards[t] + gamma*vs[t_next]*nextnonterminal - values[t]); - vs[t] = lastpufferlam + values[t]; + + //float delta = rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]; + //lastpufferlam = delta + gamma*lambda*lastpufferlam*nextnonterminal; + + + advantages[t] = lastpufferlam; + vs[t] = advantages[t] + values[t]; + //advantages[t] = rho_t*(rewards[t] + gamma*vs[t_next]*nextnonterminal - values[t]); + //vs[t] = lastpufferlam + values[t]; } } From 3d489df32d9510fe908adc5aecde42c8a46db0a9 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Tue, 15 Apr 2025 20:37:50 +0000 Subject: [PATCH 24/26] small exp tweaks --- clean_pufferl.py | 77 +++++++++++------------------------------------- 1 file changed, 18 insertions(+), 59 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 87e77cc82..f858470aa 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -74,6 +74,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): rewards=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), dones=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), truncateds=torch.zeros(experience_rows, config.bptt_horizon, device=config.device), + ratio = torch.ones(experience_rows, config.bptt_horizon, device=config.device), ) ep_uses = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) #stored_indices = torch.zeros(experience_rows, device=config.device, dtype=torch.int32) @@ -340,7 +341,6 @@ def train(data): total_minibatches = int(config.update_epochs*config.batch_size/data.minibatch_size) accumulate_minibatches = max(1, config.minibatch_size // config.max_minibatch_size) n_samples = data.minibatch_size // config.bptt_horizon - experience.ratio = torch.ones(experience.values.shape, device=config.device).to(config.device) for mb in range(total_minibatches): with profile.train_misc: if config.use_p3o: @@ -463,62 +463,18 @@ def train(data): #advantages[batch.idx] = adv #importance[batch.idx] = adv - # Might need returns at next step - #lgt = logits.reshape(-1, logits.shape[-1]) - #atns = batch.actions.reshape(-1) + adv = batch.advantages + if config.norm_adv: + adv = (adv - adv.mean()) / (adv.std() + 1e-8) - adv = advantages[batch.idx] + adv = adv * batch.prio - if config.norm_adv: - adv = (adv - adv.mean()) / (adv.std() + 1e-8) - - adv = adv * batch.prio - - pg_loss1 = -adv * ratio - pg_loss2 = -adv * torch.clamp( - ratio, 1 - config.clip_coef, 1 + config.clip_coef - ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() - - - #if config.norm_adv: - # adv = (adv - adv.mean()) / (adv.std() + 1e-8) - #nll_loss = torch.nn.functional.nll_loss( - # torch.nn.functional.log_softmax(lgt, dim=-1), target=atns, reduction='none') - - # Worse than nll_loss - #ratio = ratio.view(-1) - #pg_loss1 = -adv * ratio - #pg_loss2 = -adv * torch.clamp( - # ratio, 1 - config.clip_coef, 1 + config.clip_coef - #) - #pg_loss = torch.max(pg_loss1, pg_loss2).mean() - - - #adv = torch.clamp(adv, 1-config.clip_coef, 1+config.clip_coef) - #pg_loss = (adv * nll_loss).mean() - #pg_loss2 = adv * torch.clamp( - # nll_loss, 1 - config.clip_coef, 1 + config.clip_coef - #) - #pg_loss = torch.max(pg_loss1, pg_loss2).mean() - #pg_loss = pg_loss1.mean() - - #pg_loss = torch.mean(adv * torch.nn.functional.nll_loss( - # torch.nn.functional.log_softmax(lgt, dim=-1), target=atns, reduction='none')) - - else: - adv = batch.advantages - if config.norm_adv: - adv = (adv - adv.mean()) / (adv.std() + 1e-8) - - adv = adv * batch.prio - - # Policy loss - pg_loss1 = -adv * ratio - pg_loss2 = -adv * torch.clamp( - ratio, 1 - config.clip_coef, 1 + config.clip_coef - ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() + # Policy loss + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp( + ratio, 1 - config.clip_coef, 1 + config.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() # Value loss if config.use_p3o: @@ -603,13 +559,16 @@ def train(data): data.max_uses = data.ep_uses.max().item() data.mean_uses = data.ep_uses.float().mean().item() if config.replay_factor > 0: - advantages = data.compute_gae(experience.values, experience.rewards, - experience.dones, config.gamma, config.gae_lambda) - exp = sample(data, advantages, data.off_policy_rows, method='topk') - experience.importance[:data.on_policy_rows] = 1 + advantages = torch.zeros(experience.values.shape, device=config.device).to(config.device) + vs = torch.zeros(experience.values.shape, device=config.device) + data.compute_puff_advantage(experience.values, experience.rewards, experience.dones, + experience.ratio, vs, advantages, config.gamma, config.gae_lambda, config.vtrace_rho_clip, config.vtrace_c_clip) + + exp = sample(data, advantages, data.off_policy_rows) for k, v in experience.items(): v[data.on_policy_rows:] = exp[k] + experience.ratio[:data.on_policy_rows] = 1 with profile.train_misc: if config.anneal_lr: From 3986902ccfae06cfc614bcc48be95678cc29df99 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Wed, 16 Apr 2025 00:30:14 +0000 Subject: [PATCH 25/26] clean puffer advantage --- clean_pufferl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index f858470aa..cc0578556 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -49,6 +49,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): explained_variance=0, diayn_loss=0, grad_var=0, + importance=0, ) utilization = Utilization() @@ -126,7 +127,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): policy.parameters(), lr=config.learning_rate, betas=(config.adam_beta1, config.adam_beta2), - eps=config.adam_eps + eps=config.adam_eps, ) elif config.optimizer == 'muon': from heavyball import ForeachMuon @@ -136,7 +137,8 @@ def create(config, vecenv, policy, optimizer=None, wandb=None, neptune=None): policy.parameters(), lr=config.learning_rate, betas=(config.adam_beta1, config.adam_beta2), - eps=config.adam_eps + eps=config.adam_eps, + ) elif config.optimizer == 'kron': from heavyball import ForeachPSGDKron @@ -378,7 +380,7 @@ def train(data): elif config.use_vtrace: importance = advantages = torch.zeros(experience.values.shape, device=config.device).to(config.device) vs = torch.zeros(experience.values.shape, device=config.device) - data.compute_vtrace(batch.values, batch.rewards, batch.dones, + data.compute_vtrace(experience.values, experience.rewards, experience.dones, experience.ratio, vs, advantages, config.gamma, config.vtrace_rho_clip, config.vtrace_c_clip) elif config.use_puff_advantage: importance = advantages = torch.zeros(experience.values.shape, device=config.device).to(config.device) @@ -510,7 +512,6 @@ def train(data): with torch.no_grad(): experience.values[batch.idx] = newvalue - with profile.learn: if data.scaler is not None: loss = data.scaler.scale(loss) @@ -547,6 +548,7 @@ def train(data): losses.approx_kl += approx_kl.item() / total_minibatches losses.clipfrac += clipfrac.item() / total_minibatches losses.grad_var += grad_var.item() / total_minibatches + losses.importance += ratio.mean().item() / total_minibatches if data.use_diayn: losses.diayn_loss += diayn_loss.item() / total_minibatches @@ -564,10 +566,11 @@ def train(data): data.compute_puff_advantage(experience.values, experience.rewards, experience.dones, experience.ratio, vs, advantages, config.gamma, config.gae_lambda, config.vtrace_rho_clip, config.vtrace_c_clip) - exp = sample(data, advantages, data.off_policy_rows) + exp = sample(data, advantages, data.off_policy_rows, method='random') for k, v in experience.items(): v[data.on_policy_rows:] = exp[k] + #print(advantages[:data.on_policy_rows].mean(), advantages[data.on_policy_rows:].mean()) experience.ratio[:data.on_policy_rows] = 1 with profile.train_misc: From d5f53a172784add623aac10a32cf1a1d55bd124a Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Wed, 16 Apr 2025 19:45:13 +0000 Subject: [PATCH 26/26] Fix store ep len bug --- clean_pufferl.py | 4 ++-- pufferlib/models.py | 2 +- pufferlib/ocean/torch.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index cc0578556..35f767a82 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -640,8 +640,8 @@ def store(data, state, obs, value, action, logprob, reward, done, env_id, mask): # TODO: Handle masks!! #indices = np.where(mask)[0] #data.ep_lengths[env_id[mask]] += 1 - l += 1 - full = l >= data.config.bptt_horizon + data.ep_lengths[env_id] += 1 + full = data.ep_lengths[env_id] >= data.config.bptt_horizon num_full = full.sum() if num_full > 0: full_ids = env_id[full] diff --git a/pufferlib/models.py b/pufferlib/models.py index 085ff7f7c..387163611 100644 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -209,7 +209,7 @@ def forward_train(self, observations, state): flat_hidden = hidden.reshape(B*TT, self.hidden_size) logits, values = self.policy.decode_actions(flat_hidden) values = values.reshape(B, TT) - state.batch_logits = logits.reshape(B, TT, -1) + #state.batch_logits = logits.reshape(B, TT, -1) state.hidden = hidden state.lstm_h = lstm_h.detach() state.lstm_c = lstm_c.detach() diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 26a876f12..ebb90f606 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -311,12 +311,12 @@ def __init__(self, env, cnn_channels=128, hidden_size=128, **kwargs): self.value_fn = pufferlib.pytorch.layer_init( nn.Linear(hidden_size, 1), std=1) - def forward(self, observations): + def forward(self, observations, state=None): hidden, lookup = self.encode_observations(observations) actions, value = self.decode_actions(hidden, lookup) return actions, value - def encode_observations(self, observations): + def encode_observations(self, observations, state=None): cnn_features = observations[:, :-26].view(-1, 11, 11, 4).long() if cnn_features[:, :, :, 0].max() > 15: print('Invalid map value:', cnn_features[:, :, :, 0].max())