diff --git a/test/test_cost.py b/test/test_cost.py
index c8e45624580..5d08c3447f0 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -8177,18 +8177,19 @@ def _create_seq_mock_data_ppo(
         obs = total_obs[:, :T]
         next_obs = total_obs[:, 1:]
         if atoms:
-            action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-                -1, 1
-            )
+            action_shape = (batch, T, atoms, action_dim)
         else:
-            action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+            action_shape = (batch, T, action_dim)
+        params_mean = torch.randn(action_shape, device=device) / 10
+        params_scale = torch.rand(action_shape, device=device) / 10
+        action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
+            -1, 1
+        )
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
-        params_mean = torch.randn_like(action) / 10
-        params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
         if sample_log_prob_key is None:
@@ -8215,9 +8216,6 @@ def _create_seq_mock_data_ppo(
                 },
                 "collector": {"mask": mask},
                 action_key: action,
-                sample_log_prob_key: (
-                    torch.randn_like(action[..., 1]) / 10
-                ).masked_fill_(~mask, 0.0),
             },
             device=device,
             names=[None, "time"],
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index b8425e085b1..0876bf0c35c 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -525,10 +525,7 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
-        if isinstance(dist, CompositeDistribution):
-            is_composite = True
-        else:
-            is_composite = False
+        is_composite = isinstance(dist, CompositeDistribution)
 
         # current log_prob of actions
         if is_composite:
@@ -545,6 +542,32 @@ def _log_weight(
         prev_log_prob = _maybe_get_or_select(
             tensordict, self.tensor_keys.sample_log_prob
         )
+        # TODO:
+        # # current log_prob of actions
+        # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
+        #
+        # is_composite = None
+        # if all(key in tensordict for key in self.actor_network.dist_params_keys):
+        #     prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
+        #     kwargs, is_composite = _get_composite_kwargs(prev_dist)
+        #     if is_composite:
+        #         prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
+        #     else:
+        #         prev_log_prob = prev_dist.log_prob(action, **kwargs)
+        #     print('prev_log_prob', prev_log_prob)
+        # else:
+        #     try:
+        #         prev_log_prob = _maybe_get_or_select(
+        #             tensordict, self.tensor_keys.sample_log_prob
+        #         )
+        #     except KeyError as err:
+        #         raise _make_lp_get_error(self.tensor_keys, tensordict, err)
+
+        with self.actor_network_params.to_module(
+            self.actor_network
+        ) if self.functional else contextlib.nullcontext():
+            current_dist = self.actor_network.get_dist(tensordict)
+
         if prev_log_prob.requires_grad:
             raise RuntimeError(
@@ -566,20 +589,27 @@ def _log_weight(
                 "the beginning of your script to get a proper composite log-prob.",
                 category=UserWarning,
             )
-            if (
-                is_composite
-                and not is_tensor_collection(prev_log_prob)
-                and is_tensor_collection(log_prob)
-            ):
-                log_prob = _sum_td_features(log_prob)
-                log_prob.view_as(prev_log_prob)
+        # TODO:
+        # if isinstance(action, torch.Tensor):
+        #     log_prob = current_dist.log_prob(action)
+        # else:
+        #     if is_composite is None:
+        #         kwargs, is_composite = _get_composite_kwargs(current_dist)
+        #     log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
+        if (
+            is_composite
+            and not is_tensor_collection(prev_log_prob)
+            and is_tensor_collection(log_prob)
+        ):
+            log_prob = _sum_td_features(log_prob)
+            log_prob.view_as(prev_log_prob)
 
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)
 
-        return log_weight, dist, kl_approx
+        return log_weight, current_dist, kl_approx
 
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -655,6 +685,9 @@ def _cached_critic_network_params_detached(self):
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
+
+        log_weight, dist, kl_approx = self._log_weight(tensordict)
+
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
@@ -675,7 +708,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
-        log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
         log_weight = log_weight.view(advantage.shape)
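
For reference, a minimal standalone sketch of the computation this patch reorganizes in `_log_weight`: the distribution is rebuilt from the current actor weights at loss time (`current_dist`), while the stored `sample_log_prob` supplies the behavior-policy term. `Normal` here is only a stand-in for the actor's distribution; shapes and variable names besides `current_dist`, `prev_log_prob`, `log_weight`, and `kl_approx` are illustrative, not torchrl API.

    import torch
    from torch.distributions import Normal

    # Stand-ins for what the loss reads from the tensordict: the action sampled
    # at collection time and the log-prob recorded under sample_log_prob.
    action = torch.randn(8, 4).clamp(-1, 1)
    prev_log_prob = Normal(0.0, 1.0).log_prob(action).sum(-1).detach()

    # The patch rebuilds the distribution from the current actor weights
    # (the diff's `current_dist = self.actor_network.get_dist(tensordict)`).
    current_dist = Normal(torch.zeros(8, 4), torch.ones(8, 4))
    log_prob = current_dist.log_prob(action).sum(-1)

    # Importance weight and the cheap KL estimate, as returned by _log_weight.
    log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
    kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)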