added tdmpc2 to policy factory; shape fixes in tdmpc2
michel-aractingi committed Nov 26, 2024
1 parent 16edbbd commit 1449014
Showing 4 changed files with 34 additions and 19 deletions.
7 changes: 7 additions & 0 deletions lerobot/common/policies/factory.py
@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
        from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy

        return TDMPCPolicy, TDMPCConfig

    elif name == "tdmpc2":
        from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
        from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy

        return TDMPC2Policy, TDMPC2Config

    elif name == "diffusion":
        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
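
A minimal usage sketch of the new factory branch (plain Python, not part of the diff). It assumes only what the hunk above shows: get_policy_and_config_classes("tdmpc2") returns the TDMPC2 policy and config classes. The instantiation step is hypothetical and assumes the config class has usable defaults.

from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("tdmpc2")
# policy_cls is TDMPC2Policy and config_cls is TDMPC2Config, per the branch above.
policy = policy_cls(config_cls())  # hypothetical: assumes the usual lerobot Policy(config) constructor
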
17 changes: 6 additions & 11 deletions lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
@@ -389,18 +389,19 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        reward_loss = (
            (
                temporal_loss_coeffs
                * soft_cross_entropy(reward_preds, reward, self.config)
                * soft_cross_entropy(reward_preds, reward, self.config).mean(1)
                * ~batch["next.reward_is_pad"]
                * ~batch["observation.state_is_pad"][0]
                * ~batch["action_is_pad"]
            )
            .sum(0)
            .mean()
        )

        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        ce_value_loss = 0.0
        for i in range(self.config.q_ensemble_size):
            ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config)
            ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)

        q_value_loss = (
            (
@@ -420,7 +421,6 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
        # We won't need these gradients again so detach.
        z_preds = z_preds.detach()
        self.model.change_q_grad(mode=False)
        action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])

        with torch.no_grad():
@@ -430,14 +430,9 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
            self.scale.update(qs[0])
            qs = self.scale(qs)

        rho = torch.pow(self.config.temporal_decay_coeff, torch.arange(len(qs), device=qs.device)).unsqueeze(
            -1
        )

        pi_loss = (
            (self.config.entropy_coef * log_pis - qs).mean(dim=(1, 2))
            * rho
            # * temporal_loss_coeffs
            (self.config.entropy_coef * log_pis - qs).mean(dim=2)
            * temporal_loss_coeffs
            # `action_preds` depends on the first observation and the actions.
            * ~batch["observation.state_is_pad"][0]
            * ~batch["action_is_pad"]
@@ -447,7 +442,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
            self.config.consistency_coeff * consistency_loss
            + self.config.reward_coeff * reward_loss
            + self.config.value_coeff * q_value_loss
            + self.config.pi_coeff * pi_loss
            + pi_loss
        )

        info.update(
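
For context on the pi_loss change above, a small self-contained sketch of the temporal weighting it now reuses. It mirrors the removed local rho computation (temporal_decay_coeff raised to the timestep index); the names and values are stand-ins, not the module's actual definitions.

import torch

horizon, batch_size = 5, 8
temporal_decay_coeff = 0.5  # stand-in for self.config.temporal_decay_coeff
# rho**t for t = 0..horizon-1, broadcastable over the batch dimension
temporal_loss_coeffs = torch.pow(temporal_decay_coeff, torch.arange(horizon)).unsqueeze(-1)  # [horizon, 1]
per_step_loss = torch.rand(horizon, batch_size)  # e.g. (entropy_coef * log_pis - qs).mean(dim=2)
weighted_loss = (temporal_loss_coeffs * per_step_loss).mean()  # down-weights later timesteps
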
17 changes: 9 additions & 8 deletions lerobot/common/policies/tdmpc2/tdmpc2_utils.py
@@ -75,9 +75,6 @@ def soft_cross_entropy(pred, target, cfg):
    """Computes the cross entropy loss between predictions and soft targets."""
    pred = F.log_softmax(pred, dim=-1)
    target = two_hot(target, cfg)
    import pudb

    pudb.set_trace()
    return -(target * pred).sum(-1, keepdim=True)
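
A hedged sketch of what soft_cross_entropy computes once the debugger calls are dropped: pred is treated as raw logits over num_bins and the soft target stands in for two_hot(target, cfg). Shapes and values are illustrative assumptions.

import torch
import torch.nn.functional as F

horizon, num_features, num_bins = 3, 4, 101
pred = torch.randn(horizon, num_features, num_bins)  # raw logits
soft_target = torch.softmax(torch.randn(horizon, num_features, num_bins), dim=-1)  # stand-in for two_hot(...)
log_probs = F.log_softmax(pred, dim=-1)
loss = -(soft_target * log_probs).sum(-1, keepdim=True)  # shape [horizon, num_features, 1]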


@@ -137,16 +134,20 @@ def symexp(x):

def two_hot(x, cfg):
    """Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""

    # x shape [horizon, num_features]
    if cfg.num_bins == 0:
        return x
    elif cfg.num_bins == 1:
        return symlog(x)
    x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
    bin_offset = (x - cfg.vmin) / cfg.bin_size - bin_idx.float()
    soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
    soft_two_hot.scatter_(1, bin_idx, 1 - bin_offset)
    soft_two_hot.scatter_(1, (bin_idx + 1) % cfg.num_bins, bin_offset)
    bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()  # shape [num_features]
    bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)  # shape [num_features, 1]
    soft_two_hot = torch.zeros(
        *x.shape, cfg.num_bins, device=x.device
    )  # shape [horizon, num_features, num_bins]
    soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
    soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
    return soft_two_hot


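
A runnable sketch of the reshaped two_hot above, with a dummy config. Defining bin_size as (vmax - vmin) / (num_bins - 1) and providing the symlog helper locally are assumptions; the scatter logic follows the new lines of the diff.

import torch


class DummyCfg:
    num_bins = 101
    vmin, vmax = -10.0, 10.0
    bin_size = (vmax - vmin) / (num_bins - 1)  # assumed definition


def symlog(x):
    return torch.sign(x) * torch.log(1 + torch.abs(x))


cfg = DummyCfg()
x = torch.randn(3, 4)  # [horizon, num_features]
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
soft_two_hot = torch.zeros(*x.shape, cfg.num_bins)
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
print(soft_two_hot.shape)  # torch.Size([3, 4, 101])
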
12 changes: 12 additions & 0 deletions lerobot/scripts/train.py
@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None

    elif policy.name == "tdmpc2":
        params_group = [
            {"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
            {"params": policy.model._dynamics.parameters()},
            {"params": policy.model._reward.parameters()},
            {"params": policy.model._Qs.parameters()},
            {"params": policy.model._pi.parameters(), "eps": 1e-5},
        ]
        optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
        lr_scheduler = None

    elif cfg.policy.name == "vqbet":
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

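
A small sketch of the param-group mechanics the tdmpc2 branch above relies on, using placeholder modules and hyperparameter values: torch.optim.Adam applies group-level "lr"/"eps" overrides and falls back to the constructor defaults for everything else.

import torch
from torch import nn

encoder, pi_head = nn.Linear(4, 4), nn.Linear(4, 2)  # placeholders for policy.model._encoder / ._pi
base_lr, enc_lr_scale = 3e-4, 0.3  # stand-ins for cfg.training.lr / cfg.training.enc_lr_scale
optimizer = torch.optim.Adam(
    [
        {"params": encoder.parameters(), "lr": base_lr * enc_lr_scale},  # scaled encoder lr
        {"params": pi_head.parameters(), "eps": 1e-5},  # per-group eps override
    ],
    lr=base_lr,  # default lr for groups that do not set their own
)
print([group["lr"] for group in optimizer.param_groups])  # encoder group uses the scaled lr, the other the default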
