diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index bd3692ace1..b2622c2536 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -134,7 +134,6 @@ class DiffusionConfig:
     down_dims: tuple[int, ...] = (512, 1024, 2048)
     kernel_size: int = 5
     n_groups: int = 8
-    diffusion_step_embed_dim: int = 128
     use_film_scale_modulation: bool = True
     # Noise scheduler.
     noise_scheduler_type: str = "DDPM"
@@ -145,6 +144,16 @@ class DiffusionConfig:
     prediction_type: str = "epsilon"
     clip_sample: bool = True
     clip_sample_range: float = 1.0
+    # Transformer
+    use_transformer: bool = False
+    n_layer: int = 8
+    n_head: int = 4
+    p_drop_emb: float = 0.0
+    p_drop_attn: float = 0.3
+    causal_attn: bool = True
+    n_cond_layers: int = 0
+    # Architecture shared params
+    diffusion_step_embed_dim: int = 128
 
     # Inference
     num_inference_steps: int | None = None
@@ -200,7 +209,7 @@ def __post_init__(self):
         # Check that the horizon size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
-        if self.horizon % downsampling_factor != 0:
+        if not self.use_transformer and self.horizon % downsampling_factor != 0:
             raise ValueError(
                 "The horizon should be an integer multiple of the downsampling factor (which is determined "
                 f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 308a8be3c7..4c41a542a2 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -22,7 +22,7 @@
 
 import math
 from collections import deque
-from typing import Callable
+from typing import Callable, Tuple
 
 import einops
 import numpy as np
@@ -187,8 +187,12 @@ def __init__(self, config: DiffusionConfig):
         if "observation.environment_state" in config.input_shapes:
             self._use_env_state = True
             global_cond_dim += config.input_shapes["observation.environment_state"][0]
-
-        self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
+        if config.use_transformer:
+            self.net = TransformerForDiffusion(config, cond_dim=global_cond_dim)
+        else:
+            self.net = DiffusionConditionalUnet1d(
+                config, global_cond_dim=global_cond_dim * config.n_obs_steps
+            )
 
         self.noise_scheduler = _make_noise_scheduler(
             config.noise_scheduler_type,
@@ -206,6 +210,20 @@ def __init__(self, config: DiffusionConfig):
         else:
             self.num_inference_steps = config.num_inference_steps
 
+    def get_optimizer(
+        self,
+        transformer_weight_decay: float = 1e-3,
+        rgb_encoder_weight_decay: float = 1e-6,
+        learning_rate: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.95),
+    ) -> torch.optim.Optimizer:
+        optim_groups = self.net.get_optim_groups(weight_decay=transformer_weight_decay)
+        optim_groups.append(
+            {"params": self.rgb_encoder.parameters(), "weight_decay": rgb_encoder_weight_decay}
+        )
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        return optimizer
+
     # ========= inference ============
     def conditional_sample(
         self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
@@ -225,7 +243,7 @@ def conditional_sample(
 
         for t in self.noise_scheduler.timesteps:
             # Predict model output.
-            model_output = self.unet(
+            model_output = self.net(
                 sample,
                 torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                 global_cond=global_cond,
@@ -324,7 +342,7 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
 
         # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
-        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
+        pred = self.net(noisy_trajectory, timesteps, global_cond=global_cond)
 
         # Compute the loss.
         # The target is either the original trajectory, or the noise.
@@ -749,3 +767,256 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         out = self.conv2(out)
         out = out + self.residual_conv(x)
         return out
+
+
+class TransformerForDiffusion(nn.Module):
+    def __init__(self, config: DiffusionConfig, cond_dim: int):
+        super().__init__()
+        self.config = config
+
+        # conditioning dimension used for positional embeddings
+        # conditioning over input observation steps (n_obs_steps) + time (1)
+        t_cond = 1 + config.n_obs_steps
+
+        input_dim = config.output_shapes["action"][0]
+        # input embedding stem
+        self.input_emb = nn.Linear(input_dim, config.diffusion_step_embed_dim)
+        self.pos_emb = nn.Parameter(torch.zeros(1, config.horizon, config.diffusion_step_embed_dim))
+        self.drop = nn.Dropout(config.p_drop_emb)
+
+        # cond encoder
+        self.time_emb = DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim)
+
+        self.cond_obs_emb = nn.Linear(cond_dim, config.diffusion_step_embed_dim)
+        self.encoder = None
+
+        self.cond_pos_emb = nn.Parameter(torch.zeros(1, t_cond, config.diffusion_step_embed_dim))
+        if config.n_cond_layers > 0:
+            encoder_layer = nn.TransformerEncoderLayer(
+                d_model=config.diffusion_step_embed_dim,
+                nhead=config.n_head,
+                dim_feedforward=4 * config.diffusion_step_embed_dim,
+                dropout=config.p_drop_attn,
+                activation="gelu",
+                batch_first=True,
+                norm_first=True,
+            )
+            self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=config.n_cond_layers)
+        else:
+            self.encoder = nn.Sequential(
+                nn.Linear(config.diffusion_step_embed_dim, 4 * config.diffusion_step_embed_dim),
+                nn.Mish(),
+                nn.Linear(4 * config.diffusion_step_embed_dim, config.diffusion_step_embed_dim),
+            )
+        # decoder
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=config.diffusion_step_embed_dim,
+            nhead=config.n_head,
+            dim_feedforward=4 * config.diffusion_step_embed_dim,
+            dropout=config.p_drop_attn,
+            activation="gelu",
+            batch_first=True,
+            norm_first=True,  # important for stability
+        )
+        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=config.n_layer)
+
+        # attention mask
+        if config.causal_attn:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            # torch.nn.Transformer uses additive mask as opposed to multiplicative mask in minGPT
+            # therefore, the upper triangle should be -inf and others (including diag) should be 0.
+            sz = config.horizon
+            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+            self.register_buffer("mask", mask)
+
+            # assume conditioning over time and observation both
+            p, q = torch.meshgrid(torch.arange(config.horizon), torch.arange(t_cond), indexing="ij")
+            mask = p >= (q - 1)  # add one dimension since time is the first token in cond
+            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+            self.register_buffer("memory_mask", mask)
+        else:
+            self.mask = None
+            self.memory_mask = None
+
+        # decoder head
+        self.ln_f = nn.LayerNorm(config.diffusion_step_embed_dim)
+        self.head = nn.Linear(config.diffusion_step_embed_dim, input_dim)
+
+        # constants
+        self.t_cond = t_cond
+        self.horizon = config.horizon
+        self.n_obs_steps = config.n_obs_steps
+
+        # init
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        """
+        Initializes weights for different network layers in the module.
+        - nn.Linear and nn.Embedding: Normal(0, 0.02) for weights, zero for bias.
+        - nn.MultiheadAttention: Normal(0, 0.02) for projection weights, zero for biases.
+        - nn.LayerNorm: Ones for weights, zeros for biases.
+        - Normal(0, 0.02) for positional embeddings module.pos_emb.
+        - Predefined layers are ignored.
+        Args:
+            module (torch.nn.Module): The module to initialize.
+        """
+        ignore_types = (
+            nn.Dropout,
+            DiffusionSinusoidalPosEmb,
+            nn.TransformerEncoderLayer,
+            nn.TransformerDecoderLayer,
+            nn.TransformerEncoder,
+            nn.TransformerDecoder,
+            nn.ModuleList,
+            nn.Mish,
+            nn.Sequential,
+        )
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.MultiheadAttention):
+            weight_names = ["in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight"]
+            for name in weight_names:
+                weight = getattr(module, name)
+                if weight is not None:
+                    torch.nn.init.normal_(weight, mean=0.0, std=0.02)
+
+            bias_names = ["in_proj_bias", "bias_k", "bias_v"]
+            for name in bias_names:
+                bias = getattr(module, name)
+                if bias is not None:
+                    torch.nn.init.zeros_(bias)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+        elif isinstance(module, TransformerForDiffusion):
+            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
+            if module.cond_obs_emb is not None:
+                torch.nn.init.normal_(module.cond_pos_emb, mean=0.0, std=0.02)
+        elif isinstance(module, ignore_types):
+            # no param
+            pass
+        else:
+            raise RuntimeError("Unaccounted module {}".format(module))
+
+    def get_optim_groups(self, weight_decay: float = 1e-3):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We then return the parameter groups to be passed to the PyTorch optimizer.
+        """
+
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, torch.nn.MultiheadAttention)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, _ in m.named_parameters():
+                fpn = "{}.{}".format(mn, pn) if mn else pn  # full param name
+
+                if pn.endswith("bias"):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.startswith("bias"):
+                    # MultiheadAttention bias starts with "bias"
+                    no_decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add("pos_emb")
+        # no_decay.add("_dummy_variable")
+        if self.cond_pos_emb is not None:
+            no_decay.add("cond_pos_emb")
+
+        # validate that we considered every parameter
+        # param_dict = {pn: p for pn, p in self.named_parameters()}
+        param_dict = dict(self.named_parameters())
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
+            str(inter_params)
+        )
+        assert (
+            len(param_dict.keys() - union_params) == 0
+        ), "parameters {} were not separated into either decay/no_decay set!".format(
+            str(param_dict.keys() - union_params),
+        )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {
+                "params": [param_dict[pn] for pn in sorted(decay)],
+                "weight_decay": weight_decay,
+            },
+            {
+                "params": [param_dict[pn] for pn in sorted(no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
+        return optim_groups
+
+    def configure_optimizers(
+        self,
+        learning_rate: float = 1e-4,
+        weight_decay: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.95),
+    ):
+        optim_groups = self.get_optim_groups(weight_decay=weight_decay)
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        return optimizer
+
+    def forward(self, sample: torch.Tensor, timestep: torch.Tensor, global_cond: torch.Tensor, **kwargs):
+        """
+        Args:
+            sample: (B, T, input_dim) tensor for input to the decoder after embedding.
+            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
+            global_cond: (B, global_cond_dim)
+            output: (B, T, input_dim)
+        Returns:
+            (B, T, input_dim) diffusion model prediction.
+        """
+        # 1. time
+        batch_size = sample.shape[0]
+        time_emb = self.time_emb(timestep).unsqueeze(1)  # (B,1,n_emb)
+
+        cond = einops.rearrange(
+            global_cond, "b (s n) ... -> b s (n ...)", b=batch_size, s=self.n_obs_steps
+        )  # (B,To,n_cond)
+
+        # process input
+        input_emb = self.input_emb(sample)
+
+        # encoder
+        cond_obs_emb = self.cond_obs_emb(cond)  # (B,To,n_emb)
+        cond_embeddings = torch.cat([time_emb, cond_obs_emb], dim=1)  # (B,To + 1,n_emb)
+
+        position_embeddings = self.cond_pos_emb[
+            :, : cond_embeddings.shape[1], :
+        ]  # each position maps to a (learnable) vector
+        memory = self.drop(cond_embeddings + position_embeddings)
+        memory = self.encoder(memory)  # (B,T_cond,n_emb)
+
+        # decoder
+        position_embeddings = self.pos_emb[
+            :, : input_emb.shape[1], :
+        ]  # each position maps to a (learnable) vector
+        x = self.drop(input_emb + position_embeddings)  # (B,T,n_emb)
+        x = self.decoder(
+            tgt=x, memory=memory, tgt_mask=self.mask, memory_mask=self.memory_mask
+        )  # (B,T,n_emb)
+
+        # head
+        x = self.ln_f(x)
+        x = self.head(x)  # (B,T,n_inp)
+
+        return x
diff --git a/lerobot/configs/policy/diffusion_transformer.yaml b/lerobot/configs/policy/diffusion_transformer.yaml
new file mode 100644
index 0000000000..e56ced8ce0
--- /dev/null
+++ b/lerobot/configs/policy/diffusion_transformer.yaml
@@ -0,0 +1,109 @@
+# @package _global_
+
+# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
+# Note: We do not track EMA model weights as we discovered it does not improve the results. See
+# https://github.com/huggingface/lerobot/pull/134 for more details.
+
+seed: 100000
+dataset_repo_id: lerobot/pusht
+
+override_dataset_stats:
+  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
+  observation.image:
+    mean: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
+    std: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
+  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
+  # from the original codebase, but we should remove these and train our own pretrained model
+  observation.state:
+    min: [13.456424, 32.938293]
+    max: [496.14618, 510.9579]
+  action:
+    min: [12.0, 25.0]
+    max: [511.0, 511.0]
+
+training:
+  offline_steps: 200000
+  online_steps: 0
+  eval_freq: 25000
+  save_freq: 25000
+  save_checkpoint: true
+
+  batch_size: 64
+  grad_clip_norm: 10
+  lr: 1.0e-4
+  lr_scheduler: cosine
+  lr_warmup_steps: 1000
+  adam_betas: [0.95, 0.999]
+  adam_eps: 1.0e-8
+  adam_weight_decay: 1.0e-6
+  online_steps_between_rollouts: 1
+
+  transformer_weight_decay: 1e-3
+
+  delta_timestamps:
+    observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
+    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
+    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
+
+  # The original implementation doesn't sample frames for the last 7 steps,
+  # which avoids excessive padding and leads to improved training results.
+  drop_n_last_frames: 7  # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+policy:
+  name: diffusion
+
+  # Input / output structure.
+  n_obs_steps: 2
+  horizon: 10
+  n_action_steps: 8
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.image: [3, 96, 96]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.image: mean_std
+    observation.state: min_max
+  output_normalization_modes:
+    action: min_max
+
+  # Architecture / modeling.
+  # Vision backbone.
+  vision_backbone: resnet18
+  crop_shape: [84, 84]
+  crop_is_random: True
+  pretrained_backbone_weights: null
+  use_group_norm: True
+  spatial_softmax_num_keypoints: 32
+  # Transformer
+  use_transformer: True
+  n_layer: 8
+  n_head: 4
+  p_drop_emb: 0.0
+  p_drop_attn: 0.3
+  causal_attn: True
+  n_cond_layers: 0
+  diffusion_step_embed_dim: 256
+  # Noise scheduler.
+  noise_scheduler_type: DDPM
+  num_train_timesteps: 100
+  beta_schedule: squaredcos_cap_v2
+  beta_start: 0.0001
+  beta_end: 0.02
+  prediction_type: epsilon # epsilon / sample
+  clip_sample: True
+  clip_sample_range: 1.0
+
+  # Inference
+  num_inference_steps: null  # if not provided, defaults to `num_train_timesteps`
+
+  # Loss computation
+  do_mask_loss_for_padding: false
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f60f904eb2..1a286199c2 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -75,13 +75,21 @@ def make_optimizer_and_scheduler(cfg, policy):
         )
         lr_scheduler = None
     elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
+        if cfg.policy.get("use_transformer", False):
+            optimizer = policy.diffusion.get_optimizer(
+                transformer_weight_decay=cfg.training.transformer_weight_decay,
+                rgb_encoder_weight_decay=cfg.training.adam_weight_decay,
+                learning_rate=cfg.training.lr,
+                betas=cfg.training.adam_betas,
+            )
+        else:
+            optimizer = torch.optim.Adam(
+                policy.diffusion.parameters(),
+                cfg.training.lr,
+                cfg.training.adam_betas,
+                cfg.training.adam_eps,
+                cfg.training.adam_weight_decay,
+            )
         from diffusers.optimization import get_scheduler
 
         lr_scheduler = get_scheduler(