Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transformer model to diffusion policy #481

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions lerobot/common/policies/diffusion/configuration_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class DiffusionConfig:

# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
horizon: int = 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change should be reverted right? To keep in line with the default PushT policy.

n_action_steps: int = 8

input_shapes: dict[str, list[int]] = field(
Expand Down Expand Up @@ -134,7 +134,7 @@ class DiffusionConfig:
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
n_groups: int = 8
diffusion_step_embed_dim: int = 128
diffusion_step_embed_dim: int = 256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change should be reverted right? To keep in line with the default PushT policy.

use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = "DDPM"
Expand All @@ -145,6 +145,14 @@ class DiffusionConfig:
prediction_type: str = "epsilon"
clip_sample: bool = True
clip_sample_range: float = 1.0
# Transformer
use_transformer: bool = True
n_layer: int = 8
n_head: int = 4
p_drop_emb: float = 0.0
p_drop_attn: float = 0.3
causal_attn: bool = True
n_cond_layers: int = 0

# Inference
num_inference_steps: int | None = None
Expand Down Expand Up @@ -200,7 +208,7 @@ def __post_init__(self):
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
if self.horizon % downsampling_factor != 0:
if not self.use_transformer and self.horizon % downsampling_factor != 0:
raise ValueError(
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
Expand Down
288 changes: 284 additions & 4 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import math
from collections import deque
from typing import Callable
from typing import Callable, Tuple

import einops
import numpy as np
Expand Down Expand Up @@ -188,7 +188,12 @@ def __init__(self, config: DiffusionConfig):
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]

self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.use_transformer:
self.net = TransformerForDiffusion(config, cond_dim=global_cond_dim)
else:
self.net = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)

self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
Expand All @@ -206,6 +211,20 @@ def __init__(self, config: DiffusionConfig):
else:
self.num_inference_steps = config.num_inference_steps

def get_optimizer(
self,
transformer_weight_decay: float = 1e-3,
rgb_encoder_weight_decay: float = 1e-6,
learning_rate: float = 1e-4,
betas: Tuple[float, float] = [0.9, 0.95],
) -> torch.optim.Optimizer:
optim_groups = self.net.get_optim_groups(weight_decay=transformer_weight_decay)
optim_groups.append(
{"params": self.rgb_encoder.parameters(), "weight_decay": rgb_encoder_weight_decay}
)
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
return optimizer

# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
Expand All @@ -225,7 +244,7 @@ def conditional_sample(

for t in self.noise_scheduler.timesteps:
# Predict model output.
model_output = self.unet(
model_output = self.net(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
Expand Down Expand Up @@ -324,7 +343,7 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
pred = self.net(noisy_trajectory, timesteps, global_cond=global_cond)

# Compute the loss.
# The target is either the original trajectory, or the noise.
Expand Down Expand Up @@ -749,3 +768,264 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
out = self.conv2(out)
out = out + self.residual_conv(x)
return out


class TransformerForDiffusion(nn.Module):
def __init__(self, config: DiffusionConfig, cond_dim: int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I please ask for either more descriptive variable names or comments describing what the variables mean here? I've highlighted at least 1 or 2 specific asks below, but I realized it might be better to make this general request. Please check ACT for inspiration

super().__init__()
self.config = config

# compute number of tokens for main trunk and condition encoder
if config.n_obs_steps is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config.n_obs_steps is None does not seem to be allowed according to the type hinting and documentation. So perhaps it doesn't make sense to handle it here, right?

config.n_obs_steps = config.horizon

t = config.horizon
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, is it okay if we just leave this as config.horizon rather than binding it to another much less descriptive variable name?

t_cond = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what is t_cond? Could you please use a more descriptive variable name or if that's not appropriate, just leave a comment here?

t_cond += config.n_obs_steps

input_dim = config.output_shapes["action"][0]
# input embedding stem
self.input_emb = nn.Linear(input_dim, config.diffusion_step_embed_dim)
self.pos_emb = nn.Parameter(torch.zeros(1, t, config.diffusion_step_embed_dim))
self.drop = nn.Dropout(config.p_drop_emb)

# cond encoder
self.time_emb = DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

diffusion_step_embed_dim is found under the # Unet part of the config. Could you please consider making a # Architecture shared params section as well as # Unet and # Transformer? (or whatever else you think is suitable to make the configuration params clearer.

self.cond_obs_emb = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line appears to be redundant.


self.cond_obs_emb = nn.Linear(cond_dim, config.diffusion_step_embed_dim)

self.cond_pos_emb = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line appears to be redundant.

self.encoder = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two lines also.

self.decoder = None

self.cond_pos_emb = nn.Parameter(torch.zeros(1, t_cond, config.diffusion_step_embed_dim))
if config.n_cond_layers > 0:
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.diffusion_step_embed_dim,
nhead=config.n_head,
dim_feedforward=4 * config.diffusion_step_embed_dim,
dropout=config.p_drop_attn,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=config.n_cond_layers)
else:
self.encoder = nn.Sequential(
nn.Linear(config.diffusion_step_embed_dim, 4 * config.diffusion_step_embed_dim),
nn.Mish(),
nn.Linear(4 * config.diffusion_step_embed_dim, config.diffusion_step_embed_dim),
)
# decoder
decoder_layer = nn.TransformerDecoderLayer(
d_model=config.diffusion_step_embed_dim,
nhead=config.n_head,
dim_feedforward=4 * config.diffusion_step_embed_dim,
dropout=config.p_drop_attn,
activation="gelu",
batch_first=True,
norm_first=True, # important for stability
)
self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=config.n_layer)

# attention mask
if config.causal_attn:
# causal mask to ensure that attention is only applied to the left in the input sequence
# torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
# therefore, the upper triangle should be -inf and others (including diag) should be 0.
sz = t
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
self.register_buffer("mask", mask)

# assume conditioning over time and observation both
p, q = torch.meshgrid(torch.arange(t), torch.arange(t_cond), indexing="ij")
mask = p >= (q - 1) # add one dimension since time is the first token in cond
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
self.register_buffer("memory_mask", mask)
else:
self.mask = None
self.memory_mask = None

# decoder head
self.ln_f = nn.LayerNorm(config.diffusion_step_embed_dim)
self.head = nn.Linear(config.diffusion_step_embed_dim, input_dim)

# constants
self.t = t
self.t_cond = t_cond
self.horizon = config.horizon
self.n_obs_steps = config.n_obs_steps

# init
self.apply(self._init_weights)
# logger.info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please remove this commented code?

# "number of parameters: %e", sum(p.numel() for p in self.parameters())
# )

def _init_weights(self, module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything we can say in a docstring to summarize the weight initialization strategy used here?

ignore_types = (
nn.Dropout,
DiffusionSinusoidalPosEmb,
nn.TransformerEncoderLayer,
nn.TransformerDecoderLayer,
nn.TransformerEncoder,
nn.TransformerDecoder,
nn.ModuleList,
nn.Mish,
nn.Sequential,
)
if isinstance(module, (nn.Linear, nn.Embedding)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
weight_names = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"]
for name in weight_names:
weight = getattr(module, name)
if weight is not None:
torch.nn.init.normal_(weight, mean=0.0, std=0.02)

bias_names = ["in_proj_bias", "bias_k", "bias_v"]
for name in bias_names:
bias = getattr(module, name)
if bias is not None:
torch.nn.init.zeros_(bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, TransformerForDiffusion):
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
if module.cond_obs_emb is not None:
torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
elif isinstance(module, ignore_types):
# no param
pass
else:
raise RuntimeError("Unaccounted module {}".format(module))

def get_optim_groups(self, weight_decay: float = 1e-3):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""

# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, _ in m.named_parameters():
fpn = "{}.{}".format(mn, pn) if mn else pn # full param name

if pn.endswith("bias"):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.startswith("bias"):
# MultiheadAttention bias starts with "bias"
no_decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)

# special case the position embedding parameter in the root GPT module as not decayed
no_decay.add("pos_emb")
# no_decay.add("_dummy_variable")
if self.cond_pos_emb is not None:
no_decay.add("cond_pos_emb")

# validate that we considered every parameter
# param_dict = {pn: p for pn, p in self.named_parameters()}
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert (
len(param_dict.keys() - union_params) == 0
), "parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
)

# create the pytorch optimizer object
optim_groups = [
{
"params": [param_dict[pn] for pn in sorted(decay)],
"weight_decay": weight_decay,
},
{
"params": [param_dict[pn] for pn in sorted(no_decay)],
"weight_decay": 0.0,
},
]
return optim_groups

def configure_optimizers(
self,
learning_rate: float = 1e-4,
weight_decay: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.95),
):
optim_groups = self.get_optim_groups(weight_decay=weight_decay)
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
return optimizer

def forward(self, sample: torch.Tensor, timestep: torch.Tensor, global_cond: torch.Tensor, **kwargs):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we please tidy up this docstring?

  • make sure documented variable names match the function signature
  • Explain what the the variables are where not obvious, expecially timestamp (see DiffusionConditionalUnet1d.forward.

x: (B,T,input_dim)
timestep: (B,)
global_cond: (B, global_cond_dim)
output: (B,T,input_dim)
"""
# 1. time
timesteps = timestep
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason you assign another variable to the same object, and with such a similar name?

batch_size = sample.shape[0]
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this at the moment.

timesteps = timesteps.expand(batch_size)
time_emb = self.time_emb(timesteps).unsqueeze(1)
# (B,1,n_emb)

cond = einops.rearrange(global_cond, "b (s n) ... -> b s (n ...)", b=batch_size, s=self.n_obs_steps)
# (B,To,n_cond)

# process input
input_emb = self.input_emb(sample)

# encoder
cond_embeddings = time_emb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: mind dropping this line and just putting time_emb into torch.cat instead?

# (B,1,n_emb)

cond_obs_emb = self.cond_obs_emb(cond)
# (B,To,n_emb)
cond_embeddings = torch.cat([cond_embeddings, cond_obs_emb], dim=1)
# (B,To + 1,n_emb)

tc = cond_embeddings.shape[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tc is used only one and on the next line. In this instance I think it would make more sense to just use the RHS of this assignment in the line below, rather than adding an obscure variable name into the namespace.

position_embeddings = self.cond_pos_emb[:, :tc, :] # each position maps to a (learnable) vector
x = self.drop(cond_embeddings + position_embeddings)
x = self.encoder(x)
memory = x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really have to add another variable into the namespace here?

# (B,T_cond,n_emb)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make comments like this either in-line, or on the line preceding the line of code of concern? I think putting code then comment on the next line is rather unconventional.


# decoder
token_embeddings = input_emb
t = token_embeddings.shape[1]
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
# (B,T,n_emb)
x = self.decoder(tgt=x, memory=memory, tgt_mask=self.mask, memory_mask=self.memory_mask)
# (B,T,n_emb)

# head
x = self.ln_f(x)
x = self.head(x)
# (B,T,n_inp)
return x
Loading