A refactored modeling_act for CPU and memory optimization #569

Open · wants to merge 1 commit into base: main
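This PR refactors modeling_act.py to reduce per-forward device transfers and allocations: constant tensors (the expected-image-key mask, ensemble weights, zero latent, decoder and padding templates, sinusoidal frequency terms) are registered as module buffers and broadcast with expand() instead of being rebuilt or moved with .to() on every call.

A minimal sketch of the pattern, for illustration only (the toy module and names below are not from this diff):

```
import torch
from torch import nn


class EnsembleWeights(nn.Module):
    """Toy module contrasting a plain tensor attribute with a registered buffer."""

    def __init__(self, chunk_size: int, use_buffer: bool = True):
        super().__init__()
        weights = torch.exp(-0.01 * torch.arange(chunk_size))
        if use_buffer:
            # Buffer: follows .to()/.cuda() with the module and is saved in state_dict.
            self.register_buffer("weights", weights)
        else:
            # Plain attribute: stays on its original device, so every forward()
            # pays a host-to-device copy when the inputs live on the GPU.
            self.weights = weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # .to() is a no-op when the devices already match (the buffer case).
        return x * self.weights.to(x.device)


device = "cuda" if torch.cuda.is_available() else "cpu"
module = EnsembleWeights(chunk_size=4).to(device)  # the buffer moves with the module
print(module(torch.ones(4, device=device)))
```

One trade-off worth noting: buffers are persistent by default and therefore enter the checkpoint's state_dict; register_buffer(..., persistent=False) keeps a buffer out of it.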
148 changes: 107 additions & 41 deletions lerobot/common/policies/act/modeling_act.py
@@ -81,7 +81,12 @@ def __init__(

self.model = ACT(config)

self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Pre-compute a boolean mask over config.input_shapes marking the image keys, registered as a buffer
self.register_buffer(
"expected_image_keys",
torch.tensor([k.startswith("observation.image") for k in config.input_shapes])
)
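# Buffers follow the module across .to()/.cuda() calls and, by default, are saved in the state_dict.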
# self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
@@ -106,9 +111,14 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
self.eval()

batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
if self.expected_image_keys.any():
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
keys = [k for k, is_img in zip(self.config.input_shapes.keys(), self.expected_image_keys) if is_img]
batch["observation.images"] = torch.stack([batch[k] for k in keys], dim=-4)

# if len(self.expected_image_keys) > 0:
# batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -134,9 +144,15 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
if self.expected_image_keys.any():
batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
keys = [k for k, is_img in zip(self.config.input_shapes.keys(), self.expected_image_keys) if is_img]
batch["observation.images"] = torch.stack([batch[k] for k in keys], dim=-4)

# if len(self.expected_image_keys) > 0:
# batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)

batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

@@ -151,7 +167,9 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
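# For a diagonal Gaussian posterior against a standard normal prior: KL = -0.5 * sum(1 + log σ² - μ² - σ²).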
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp()))
.sum(-1)
.mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
@@ -161,7 +179,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
return loss_dict


class ACTTemporalEnsembler:
class ACTTemporalEnsembler(nn.Module):
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

@@ -204,9 +222,27 @@ def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
print("online", avg)
```
"""
super().__init__()
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
# These assignments are redundant now that the same tensors are registered as buffers below:
# self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
# self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)

# Register weights as buffers instead of plain attributes so they move with the module and avoid per-call device transfers
self.register_buffer(
"ensemble_weights",
torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
)
self.register_buffer(
"ensemble_weights_cumsum",
torch.cumsum(self.ensemble_weights, dim=0)
)
self.register_buffer(
"ones_template",
torch.ones((chunk_size, 1), dtype=torch.long)
)


self.reset()

def reset(self):
@@ -220,17 +256,20 @@ def update(self, actions: Tensor) -> Tensor:
Takes a (batch, chunk_size, action_dim) sequence of actions, updates the temporal ensemble for all
time steps, and pops/returns the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)

# TODO: Remove once the buffer-registered tensors are confirmed to work.
# self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
# self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self.ensembled_actions = actions.clone()
self.ensembled_actions = actions
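# Note: dropping .clone() avoids a copy but aliases the caller's tensor; if later updates write
# in-place, the caller must not reuse `actions` after this call.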
# Note: The last dimension is unsqueezed to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
self.ensembled_actions_count = self.ones_template.to(self.ensembled_actions.device)
# self.ensembled_actions_count = torch.ones(
# (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
# )
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
@@ -367,6 +406,13 @@ def __init__(self, config: ACTConfig):
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])

self.register_buffer("zero_latent", torch.zeros(1, config.latent_dim))
self.register_buffer("decoder_template", torch.zeros(config.chunk_size, 1, config.dim_model))
self.register_buffer(
"cls_pad_template",
torch.full((1, 2 if self.use_robot_state else 1), False)
)
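# These templates are expanded into per-batch broadcast views in forward(), avoiding fresh allocations.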

self._reset_parameters()

def _reset_parameters(self):
@@ -424,16 +470,19 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso

# Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
pos_embed = self.vae_encoder_pos_enc # (1, S+2, D)

# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
# sequence, depending on whether we use the input states or not (cls and robot state tokens).
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_robot_state else 1),
False,
device=batch["observation.state"].device,
)

cls_joint_is_pad = self.cls_pad_template.expand(batch_size, -1)

# cls_joint_is_pad = torch.full(
# (batch_size, 2 if self.use_robot_state else 1),
# False,
# device=batch["observation.state"].device,
# )
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
) # (bs, seq+1 or 2)
@@ -455,9 +504,10 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# Note: the old TODO to precompute a buffer and avoid `.to` is addressed below via `zero_latent`.
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
latent_sample = self.zero_latent.expand(batch_size, -1)
# latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
# batch["observation.state"].device
# )

# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
@@ -480,7 +530,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
# Note: the `.to(dtype=...)` cast has been removed; the positional embedding is computed
# directly from the features, so it is already on their device.
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
@@ -498,11 +548,12 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
# Note: the old TODO to precompute a buffer is addressed below via `decoder_template`.
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=encoder_in_pos_embed.dtype,
device=encoder_in_pos_embed.device,
)
decoder_in = self.decoder_template.expand(-1, batch_size, -1)
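# Note: expand() returns a read-only broadcast view of the buffer; this assumes the decoder
# does not write to its input in-place.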
# decoder_in = torch.zeros(
# (self.config.chunk_size, batch_size, self.config.dim_model),
# dtype=encoder_in_pos_embed.dtype,
# device=encoder_in_pos_embed.device,
# )
decoder_out = self.decoder(
decoder_in,
encoder_out,
@@ -708,6 +759,20 @@ def __init__(self, dimension: int):
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
self._temperature = 10000

# Register constant tensors as buffers so they follow the module's device and avoid per-forward transfers
self.register_buffer("_two_pi_tensor", torch.tensor([self._two_pi]))
self.register_buffer("_eps_tensor", torch.tensor([self._eps]))
self.register_buffer("_temp_tensor", torch.tensor([self._temperature]))
self.register_buffer(
"dim_arange",
torch.arange(dimension, dtype=torch.float32)
)
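# Precompute inverse_frequency[i] = temperature ** (2 * (i // 2) / dimension) once,
# instead of on every forward pass.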
self.register_buffer(
"inverse_frequency",
self._temp_tensor ** (2 * (self.dim_arange // 2) / self.dimension)
)


def forward(self, x: Tensor) -> Tensor:
"""
Args:
@@ -718,21 +783,22 @@ def forward(self, x: Tensor) -> Tensor:
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it as-is to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
y_range = not_mask.cumsum(1)
x_range = not_mask.cumsum(2)

# "Normalize" the position index such that it ranges in [0, 2π].
# Note: Adding epsilon on the denominator should not be needed as all values of y_range and x_range
# are non-zero by construction. This is an artifact of the original code.
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
y_range = y_range / (y_range[:, -1:, :] + self._eps_tensor) * self._two_pi_tensor
x_range = x_range / (x_range[:, :, -1:] + self._eps_tensor) * self._two_pi_tensor

inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)

# inverse_frequency = self._temperature ** (
# 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
# )

x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
x_range = x_range.unsqueeze(-1) / self.inverse_frequency # (1, H, W, 1)
y_range = y_range.unsqueeze(-1) / self.inverse_frequency # (1, H, W, 1)

# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).