Describe the bug
I am trying to use a MinariExperienceReplay with split trajectories enabled, but I am running into what looks like a trivial import error. The import on this line

torchrl/data/datasets/minari_data.py, line 353 (commit 350fa1d)

from torchrl.objectives.utils import split_trajectories

fails because split_trajectories does not exist in torchrl.objectives.utils. Should it be replaced by

from torchrl.collectors.utils import split_trajectories

instead?
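For what it's worth, a quick check on my install makes me think the import path is simply pointing at the wrong module (this is only a sanity check, not a fix):

import torchrl.collectors.utils as collectors_utils
import torchrl.objectives.utils as objectives_utils

# Expectation on torchrl 0.7.2: False for objectives.utils (hence the ImportError),
# True for collectors.utils, where the function appears to be defined.
print(hasattr(objectives_utils, "split_trajectories"))
print(hasattr(collectors_utils, "split_trajectories"))

The full traceback: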
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[4], line 147
    144 print("✅ Training complete!")
    146 if __name__ == "__main__":
--> 147     train_cql_babyai()

Cell In[4], line 90
     87 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     89 # Load replay buffer
---> 90 buffer = MinariExperienceReplay(
     91     dataset_id=dataset_id,
     92     batch_size=batch_size,
     93     download=True,
     94     split_trajs=True,
     95     root=f"{os.getenv('HOME')}/.minari/datasets"
     96 )
     98 # Build vocabulary and attach transform
     99 vocab = build_mission_vocab(buffer)

File ~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:189, in MinariExperienceReplay.__init__(self, dataset_id, batch_size, root, download, sampler, writer, collate_fn, pin_memory, prefetch, transform, split_trajs)
    187 except FileNotFoundError:
    188     pass
--> 189 storage = self._download_and_preproc()
    190 elif self.split_trajs and not os.path.exists(self.data_path):
...
--> 355 from torchrl.objectives.utils import split_trajectories
    357 td_data = split_trajectories(td_data).memmap_(self.data_path)
    358 with open(self.metadata_path, "w") as metadata_file:

ImportError: cannot import name 'split_trajectories' from 'torchrl.objectives.utils' (/Users/O000142/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/objectives/utils.py)
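Until the import is fixed upstream, one possible interim workaround (untested beyond the idea, so treat it as a sketch) is to alias the function into the module that minari_data.py imports from, before constructing the buffer:

import torchrl.objectives.utils as objectives_utils
from torchrl.collectors.utils import split_trajectories

# minari_data.py runs `from torchrl.objectives.utils import split_trajectories`
# at runtime (inside _download_and_preproc), so injecting the attribute here
# should make that import resolve.
objectives_utils.split_trajectories = split_trajectories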
To Reproduce
I am trying to set up a basic CQL with a discrete action space. Run this code in a notebook and you will get the error:
import torch
import torch.nn as nn
from copy import deepcopy
import os

from torchrl.data.datasets.minari_data import MinariExperienceReplay
from torchrl.objectives import DiscreteCQLLoss
from torchrl.modules import QValueActor
from torchrl.data import OneHot
from torchrl.envs import Transform, Compose
from torchrl.objectives.utils import default_value_kwargs
from tensordict import TensorDict


def soft_update(target, source, tau):
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(
            tau * source_param.data + (1.0 - tau) * target_param.data
        )


# === 1. Transform to preprocess mission and normalize image ===
class MissionImageTransform(Transform):
    def __init__(self, mission_vocab):
        super().__init__()
        self.vocab = mission_vocab

    def _call(self, td: TensorDict) -> TensorDict:
        for key in ["observation", ("next", "observation")]:
            subtd = td.get(key, None)
            if subtd is None:
                continue
            if "mission" in subtd:
                missions = subtd["mission"]
                subtd["mission"] = torch.tensor([
                    self.vocab.get(m.decode() if isinstance(m, bytes) else str(m), 0)
                    for m in missions
                ], dtype=torch.long)
            if "image" in subtd:
                subtd["image"] = subtd["image"].float() / 255.0
        return td


# === 2. Build mission vocabulary ===
def build_mission_vocab(replay_buffer, max_batches=100):
    vocab = {}
    idx = 0
    for i, batch in enumerate(replay_buffer):
        missions = batch.get("observation", {}).get("mission", None)
        if missions is None:
            continue
        for m in missions:
            m = m.decode("utf-8") if isinstance(m, bytes) else str(m)
            if m not in vocab:
                vocab[m] = idx
                idx += 1
        if i >= max_batches:
            break
    return vocab


# === 3. Define Q-network ===
class BabyAIQNetwork(nn.Module):
    def __init__(self, mission_vocab_size, direction_size, num_actions):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        self.mission_emb = nn.Embedding(mission_vocab_size, 32)
        self.direction_emb = nn.Embedding(direction_size, 8)
        self.fc = nn.Sequential(
            nn.Linear(800 + 32 + 8, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions)
        )

    def forward(self, tensordict):
        img = self.cnn(tensordict["observation"]["image"])
        mission = self.mission_emb(tensordict["observation"]["mission"])
        direction = self.direction_emb(tensordict["observation"]["direction"])
        return self.fc(torch.cat([img, mission, direction], dim=-1))


# === 4. Main training logic ===
def train_cql_babyai():
    dataset_id = "minigrid/BabyAI-Pickup/optimal-v0"
    batch_size = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load replay buffer
    buffer = MinariExperienceReplay(
        dataset_id=dataset_id,
        batch_size=batch_size,
        download=True,
        split_trajs=True,
        root=f"{os.getenv('HOME')}/.minari/datasets"
    )

    # Build vocabulary and attach transform
    vocab = build_mission_vocab(buffer)
    buffer.transform = Compose(MissionImageTransform(vocab))

    # Create networks
    q_net = BabyAIQNetwork(len(vocab), direction_size=4, num_actions=7).to(device)
    target_q_net = deepcopy(q_net).to(device)

    # Wrap in actors
    q_actor = QValueActor(q_net, in_keys=["observation"], action_space=OneHot(7))
    target_q_actor = QValueActor(target_q_net, in_keys=["observation"], action_space=OneHot(7))

    # CQL loss setup
    loss_module = DiscreteCQLLoss(q_actor).to(device)
    loss_module.make_value_estimator(
        value_type="td0",
        value_network=target_q_actor,
        gamma=0.99,
        **default_value_kwargs("td0"),
    )

    # Optimizer
    optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
    tau = 0.005
    target_update_freq = 10
    max_steps = 1000

    print("🚀 Training started...")
    for i, batch in enumerate(buffer):
        batch = batch.to(device)
        loss_dict = loss_module(batch)
        total_loss = loss_dict["loss_qvalue"] + loss_dict["loss_cql"]

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if i % target_update_freq == 0:
            soft_update(target_q_net, q_net, tau)
        if i % 10 == 0:
            print(f"Step {i:04d}: Total Loss = {total_loss.item():.4f}")
        if i >= max_steps:
            break

    print("✅ Training complete!")


if __name__ == "__main__":
    train_cql_babyai()
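Note that, as far as I can tell from the traceback, the rest of the script is not needed to hit the error; constructing the buffer with split_trajs=True should be enough on its own:

from torchrl.data.datasets.minari_data import MinariExperienceReplay

# The failing import happens inside _download_and_preproc(), which runs during
# construction when the dataset is (re)processed with split_trajs=True.
buffer = MinariExperienceReplay(
    dataset_id="minigrid/BabyAI-Pickup/optimal-v0",
    batch_size=64,
    download=True,
    split_trajs=True,
)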
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.7.2 1.26.4 3.10.12 (main, Jul 26 2023, 19:37:41) [Clang 16.0.3 ] darwin
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)