Commit

remove self.model_target and add only a target Q ensemble, without the need to copy the entire policy
michel-aractingi committed Nov 21, 2024
1 parent a146544 commit c41ec08
Showing 1 changed file with 27 additions and 33 deletions.
60 changes: 27 additions & 33 deletions lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
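For orientation before the diff: the pattern this commit adopts can be sketched as below. The class and dimensions are made up for illustration (the real `TDMPC2WorldModel` also holds encoder, dynamics, reward, and policy heads); the point is that only the Q-heads are deep-copied and frozen, instead of duplicating the whole model as `self.model_target` did.

```python
# Illustrative sketch only: duplicate and freeze just the Q-ensemble,
# not the entire world model. Names and sizes are placeholders.
from copy import deepcopy

from torch import nn


class WorldModelSketch(nn.Module):
    def __init__(self, latent_dim: int = 32, action_dim: int = 4, num_qs: int = 5):
        super().__init__()
        self._Qs = nn.ModuleList(
            [nn.Linear(latent_dim + action_dim, 1) for _ in range(num_qs)]
        )
        # Target ensemble: frozen copy of the Q-heads only.
        self._target_Qs = deepcopy(self._Qs).requires_grad_(False)


model = WorldModelSketch()
assert all(not p.requires_grad for p in model._target_Qs.parameters())
```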
@@ -39,7 +39,7 @@
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
-from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv
+from lerobot.common.policies.tdmpc2.tdmpc2_utils import NormedLinear, SimNorm, two_hot_inv, gaussian_logprob, squash
 
 
 class TDMPC2Policy(
@@ -84,9 +84,6 @@ def __init__(
         config = TDMPC2Config()
         self.config = config
         self.model = TDMPC2WorldModel(config)
-        self.model_target = deepcopy(self.model)
-        for param in self.model_target.parameters():
-            param.requires_grad = False
 
         if config.input_normalization_modes is not None:
             self.normalize_inputs = Normalize(
@@ -384,12 +381,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         # Compute various targets with stopgrad.
         with torch.no_grad():
             # Latent state consistency targets.
-            z_targets = self.model_target.encode(next_observations)
+            z_targets = self.model.encode(next_observations)
             # Compute the TD-target from a reward and the next observation
             pi = self.model.pi(z_targets)[0]
             td_targets = (
                 reward
-                + self.config.discount * self.model_target.Qs(z_targets, pi, return_type="min").squeeze()
+                + self.config.discount * self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
             )
 
         # Compute losses.
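The TD target above now reads roughly `y = r + γ * min_i Q̄_i(z', π(z'))`, where `Q̄` is the frozen target ensemble reached through `target=True`. A minimal numeric sketch with stand-in shapes (not the real forward pass):

```python
# Stand-in tensors only: (num_target_qs, batch) Q-values from the frozen
# target ensemble; "min" over the ensemble curbs overestimation bias.
import torch

discount = 0.99
reward = torch.randn(8)              # (batch,)
target_q_values = torch.randn(5, 8)  # (num_target_qs, batch)
td_target = reward + discount * target_q_values.min(dim=0).values  # (batch,)
```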
@@ -450,10 +447,15 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
         # We won't need these gradients again so detach.
         z_preds = z_preds.detach()
-        self.model.change_q_grad(mode=False)
         action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
-        qs = self.model_target.Qs(z_preds[:-1], action_preds, return_type="avg")
-        self.scale.update(qs[0])
-        qs = self.scale(qs)
+
+        with torch.no_grad():
+            # Avoid unnecessary computation of the gradients during policy optimization.
+            # TODO (michel-aractingi): the same logic should be extended when adding task embeddings
+            qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
+            self.scale.update(qs[0])
+            qs = self.scale(qs)
 
         rho = torch.pow(self.config.rho, torch.arange(len(qs), device=qs.device)).unsqueeze(-1)
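The `torch.no_grad()` context replaces the old `change_q_grad(mode=False)` toggle: the Q-values only weight the advantage-weighted regression loss, so no autograd graph needs to be built through the Q-heads. A small sketch of the effect, with a toy module standing in for the ensemble:

```python
# Toy stand-in for the Q-ensemble: values computed under no_grad carry no
# graph, so the policy update cannot backpropagate through them.
import torch
from torch import nn

q_head = nn.Linear(6, 1)
z_and_a = torch.randn(8, 6, requires_grad=True)
with torch.no_grad():
    weights = q_head(z_and_a)
assert not weights.requires_grad
```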

@@ -498,12 +500,8 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         return info
 
     def update(self):
-        """Update the target model's parameters with an EMA step."""
-        # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
-        # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
-        # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
-        update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
-
+        """Update the target Q-ensemble using Polyak averaging."""
+        self.model.update_target_Q()
 
 class TDMPC2WorldModel(nn.Module):
     """Latent dynamics model used in TD-MPC2."""
@@ -586,6 +584,11 @@ def to(self, *args, **kwargs):
         self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
         self.bins = self.bins.to(*args, **kwargs)
         return self
 
+    def train(self, mode):
+        super().train(mode)
+        self._target_Qs.train(False)
+        return self
+
     def encode(self, obs: dict[str, Tensor]) -> Tensor:
         """Encodes an observation into its latent representation."""
@@ -622,7 +625,7 @@ def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
         x = torch.cat([z, a], dim=-1)
         return self._dynamics(x)
 
-    def pi(self, z: Tensor, std: float = 0.0) -> Tensor:
+    def pi(self, z: Tensor) -> Tensor:
         """Samples an action from the learned policy.
 
         The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
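The `gaussian_logprob`/`squash` imports added at the top, together with the four-tuple unpacked from `self.model.pi(...)` earlier in the diff, suggest a tanh-squashed Gaussian policy head as in the reference TD-MPC2 implementation. A hedged, self-contained sketch of that sampling pattern (stand-in math, not the actual lerobot utilities):

```python
# Sketch of squashed-Gaussian sampling with a tanh log-prob correction.
# This is an assumption-based stand-in for tdmpc2_utils.gaussian_logprob/squash.
import math

import torch


def sample_squashed_gaussian(mean: torch.Tensor, log_std: torch.Tensor):
    """Return a tanh-squashed Gaussian action and its log-probability."""
    eps = torch.randn_like(mean)
    pre_tanh = mean + eps * log_std.exp()
    # Diagonal Gaussian log-prob of the pre-tanh sample.
    log_pi = (-0.5 * eps.pow(2) - log_std - 0.5 * math.log(2 * math.pi)).sum(-1)
    action = torch.tanh(pre_tanh)
    # Change-of-variables correction for the tanh squashing.
    log_pi = log_pi - torch.log(1 - action.pow(2) + 1e-6).sum(-1)
    return action, log_pi
```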
@@ -668,6 +671,14 @@ def Qs(self, z: Tensor, a: Tensor, return_type: str = "min", target=False) -> Tensor:
         Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
         return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
 
+    def update_target_Q(self):
+        """Soft-update target Q-networks using Polyak averaging."""
+        with torch.no_grad():
+            for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
+                p_target.data.lerp_(p.data, self.config.target_model_momentum)
+

 class TDMPC2ObservationEncoder(nn.Module):
     """Encode image and/or state vector observations."""
Expand Down Expand Up @@ -777,23 +788,6 @@ def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)


def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
p_ema.mul_(alpha)
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)


def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Expand Down
