Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transformer model to diffusion policy #481

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

amandip7
Copy link
Contributor

@amandip7 amandip7 commented Oct 21, 2024

What this does

Implements transformer model for diffusion policy : https://github.com/orgs/huggingface/projects/46?pane=issue&itemId=64857791
This adds a transformer model class to modelling_diffusion, adds configurations required for this model and sets the default model in diffusion to be transformer instead of U-net.

How it was tested

  • Trained a transformer diffusion policy with the default configurations provided.
  • Ported a pre-trained transformer diffusion policy from original DP repo to lerobot and ran evaluation here.

How to checkout & try? (for the reviewer)

Run the training script:
python lerobot/scripts/train.py
This would train the transformer based diffusion policy with default configuration over pushT task.

Copy link
Contributor

@alexander-soare alexander-soare left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @amandip7 for this PR! People have been asking about the transformer version of diffusion policy for a while.

I couple of high-level comments to go with my first review:

  • What do you say that we add a diffusion_transformer.yaml instead of updating the current diffusion.yaml to use the transformer variant?
  • Can you push the weights to the hub so I can test it out with an eval run? Can you please update the PR notes to include instructions on how to run the eval, and what results I should expect? And can you confirm that the success rate is the same as that reported on the paper and/or the same as you get when you run eval on the original DP repo?
  • Have you confirmed that you can train a model on LeRobot and achieve the same success rate as the ported weights? If so, can you pleas push a model to the hub and add instructions on how to evaluate that one?
  • Most of my in-line comments are around tidying up the code a little. I understand you've probably copied the code directly from the DP repo, but we'd like to make the code on LeRobot a little more accessible. This usually means tidying it up a bit by removing redundant code, adding comments, and using more self-explanatory variable names.

Thanks again!

@@ -98,7 +98,7 @@ class DiffusionConfig:

# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
horizon: int = 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change should be reverted right? To keep in line with the default PushT policy.

@@ -134,7 +134,7 @@ class DiffusionConfig:
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
n_groups: int = 8
diffusion_step_embed_dim: int = 128
diffusion_step_embed_dim: int = 256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change should be reverted right? To keep in line with the default PushT policy.

self.config = config

# compute number of tokens for main trunk and condition encoder
if config.n_obs_steps is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config.n_obs_steps is None does not seem to be allowed according to the type hinting and documentation. So perhaps it doesn't make sense to handle it here, right?

if config.n_obs_steps is None:
config.n_obs_steps = config.horizon

t = config.horizon
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, is it okay if we just leave this as config.horizon rather than binding it to another much less descriptive variable name?

config.n_obs_steps = config.horizon

t = config.horizon
t_cond = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what is t_cond? Could you please use a more descriptive variable name or if that's not appropriate, just leave a comment here?

output: (B,T,input_dim)
"""
# 1. time
timesteps = timestep
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason you assign another variable to the same object, and with such a similar name?

# 1. time
timesteps = timestep
batch_size = sample.shape[0]
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need this at the moment.

input_emb = self.input_emb(sample)

# encoder
cond_embeddings = time_emb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: mind dropping this line and just putting time_emb into torch.cat instead?

x = self.drop(cond_embeddings + position_embeddings)
x = self.encoder(x)
memory = x
# (B,T_cond,n_emb)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make comments like this either in-line, or on the line preceding the line of code of concern? I think putting code then comment on the next line is rather unconventional.

position_embeddings = self.cond_pos_emb[:, :tc, :] # each position maps to a (learnable) vector
x = self.drop(cond_embeddings + position_embeddings)
x = self.encoder(x)
memory = x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really have to add another variable into the namespace here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants