Skip to content

[BUG] MinariExperienceReplay imports non existent function from pytorch/rl #3021

Open
@marcosgalleterobbva

Description

@marcosgalleterobbva

Describe the bug

I am trying to run a MinariExperienceReplay with split trajectories applied, but I am running into an import error. Is it possible that this line ->

from torchrl.objectives.utils import split_trajectories

should be replaced with this line?

from torchrl.collectors.utils import split_trajectories

because the function that we are trying to import does not exist.

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[4], [line 147](vscode-notebook-cell:?execution_count=4&line=147)
    [144](vscode-notebook-cell:?execution_count=4&line=144)     print("✅ Training complete!")
    [146](vscode-notebook-cell:?execution_count=4&line=146) if __name__ == "__main__":
--> [147](vscode-notebook-cell:?execution_count=4&line=147)     train_cql_babyai()

Cell In[4], [line 90](vscode-notebook-cell:?execution_count=4&line=90)
     [87](vscode-notebook-cell:?execution_count=4&line=87) device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     [89](vscode-notebook-cell:?execution_count=4&line=89) # Load replay buffer
---> [90](vscode-notebook-cell:?execution_count=4&line=90) buffer = MinariExperienceReplay(
     [91](vscode-notebook-cell:?execution_count=4&line=91)     dataset_id=dataset_id,
     [92](vscode-notebook-cell:?execution_count=4&line=92)     batch_size=batch_size,
     [93](vscode-notebook-cell:?execution_count=4&line=93)     download=True,
     [94](vscode-notebook-cell:?execution_count=4&line=94)     split_trajs=True,
     [95](vscode-notebook-cell:?execution_count=4&line=95)     root=f"{os.getenv('HOME')}/.minari/datasets"
     [96](vscode-notebook-cell:?execution_count=4&line=96) )
     [98](vscode-notebook-cell:?execution_count=4&line=98) # Build vocabulary and attach transform
     [99](vscode-notebook-cell:?execution_count=4&line=99) vocab = build_mission_vocab(buffer)

File ~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:189, in MinariExperienceReplay.__init__(self, dataset_id, batch_size, root, download, sampler, writer, collate_fn, pin_memory, prefetch, transform, split_trajs)
    [187](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:187)         except FileNotFoundError:
    [188](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:188)             pass
--> [189](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:189)     storage = self._download_and_preproc()
    [190](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:190) elif self.split_trajs and not os.path.exists(self.data_path):
...
--> [355](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:355)         from torchrl.objectives.utils import split_trajectories
    [357](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:357)         td_data = split_trajectories(td_data).memmap_(self.data_path)
    [358](https://file+.vscode-resource.vscode-cdn.net/Users/O000142/Projects/mercury-rl/mercury/rl/tutorials/algorithms_comparison/torchrl_demos/~/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/data/datasets/minari_data.py:358) with open(self.metadata_path, "w") as metadata_file:

ImportError: cannot import name 'split_trajectories' from 'torchrl.objectives.utils' (/Users/O000142/Projects/mercury-rl/.venv/lib/python3.10/site-packages/torchrl/objectives/utils.py)

To Reproduce

I am trying to create a basic CQL with a discrete action space. Run this code in a notebook and you will get the error.

import torch
import torch.nn as nn
from copy import deepcopy
import os
from torchrl.data.datasets.minari_data import MinariExperienceReplay
from torchrl.objectives import DiscreteCQLLoss
from torchrl.modules import QValueActor
from torchrl.data import OneHot
from torchrl.envs import Transform, Compose
from torchrl.objectives.utils import default_value_kwargs
from tensordict import TensorDict


def soft_update(target, source, tau):
    """Polyak-average *source* parameters into *target*, in place.

    Per parameter: target <- tau * source + (1 - tau) * target.
    `tau` is the interpolation factor (small values change the target slowly).
    """
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        blended = tau * src_param.data + (1.0 - tau) * tgt_param.data
        tgt_param.data.copy_(blended)


# === 1. Transform to preprocess mission and normalize image ===
# === 1. Transform to preprocess mission and normalize image ===
class MissionImageTransform(Transform):
    """Preprocess BabyAI observations in a tensordict.

    Replaces each string/bytes mission with its integer id from
    *mission_vocab* (unknown missions map to 0) and rescales images
    from uint8 [0, 255] to float [0, 1].
    """

    def __init__(self, mission_vocab):
        super().__init__()
        self.vocab = mission_vocab

    def _call(self, td: TensorDict) -> TensorDict:
        # Apply to both the current and the next observation, if present.
        for key in ("observation", ("next", "observation")):
            obs = td.get(key, None)
            if obs is None:
                continue
            if "mission" in obs:
                ids = []
                for raw in obs["mission"]:
                    text = raw.decode() if isinstance(raw, bytes) else str(raw)
                    # Unknown missions fall back to id 0.
                    ids.append(self.vocab.get(text, 0))
                obs["mission"] = torch.tensor(ids, dtype=torch.long)
            if "image" in obs:
                obs["image"] = obs["image"].float() / 255.0
        return td

# === 2. Build mission vocabulary ===
def build_mission_vocab(replay_buffer, max_batches=100):
    """Scan up to *max_batches* batches and assign each unique mission an id.

    Bytes missions are decoded as UTF-8; anything else is stringified.
    Returns a dict mapping mission text to an integer id in first-seen order.
    """
    vocab = {}
    for batch_idx, batch in enumerate(replay_buffer):
        missions = batch.get("observation", {}).get("mission", None)
        if missions is None:
            # NOTE: skips the max_batches check on mission-less batches,
            # matching the sampling behavior the vocabulary was built with.
            continue
        for raw in missions:
            text = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
            # setdefault assigns the next free id only on first sight.
            vocab.setdefault(text, len(vocab))
        if batch_idx >= max_batches:
            break
    return vocab

# === 3. Define Q-network ===
class BabyAIQNetwork(nn.Module):
    """Q-network over BabyAI observations (image + mission id + direction id).

    The 3-channel image goes through a small CNN, mission and direction
    indices through embeddings; the concatenated features feed an MLP that
    outputs one Q-value per discrete action.
    """

    def __init__(self, mission_vocab_size, direction_size, num_actions):
        super().__init__()
        # Two conv/pool stages. The flattened size is hard-coded to 800,
        # i.e. 32 * 5 * 5 — presumably 20x20 input images; verify against
        # the dataset's observation spec.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
        )
        self.mission_emb = nn.Embedding(mission_vocab_size, 32)
        self.direction_emb = nn.Embedding(direction_size, 8)
        self.fc = nn.Sequential(
            nn.Linear(800 + 32 + 8, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def forward(self, tensordict):
        obs = tensordict["observation"]
        features = torch.cat(
            [
                self.cnn(obs["image"]),
                self.mission_emb(obs["mission"]),
                self.direction_emb(obs["direction"]),
            ],
            dim=-1,
        )
        return self.fc(features)

# === 4. Main training logic ===
def train_cql_babyai():
    """Offline discrete-CQL training loop on a Minari BabyAI dataset.

    Downloads the dataset on first use, builds a mission vocabulary,
    then runs up to 1000 gradient steps with periodic soft target updates.
    """
    dataset_id = "minigrid/BabyAI-Pickup/optimal-v0"
    batch_size = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load replay buffer (this triggers the split_trajectories import bug).
    buffer = MinariExperienceReplay(
        dataset_id=dataset_id,
        batch_size=batch_size,
        download=True,
        split_trajs=True,
        root=f"{os.getenv('HOME')}/.minari/datasets"
    )

    # Build the vocabulary from the raw buffer, then attach preprocessing.
    vocab = build_mission_vocab(buffer)
    buffer.transform = Compose(MissionImageTransform(vocab))

    # Online and target Q-networks.
    q_net = BabyAIQNetwork(len(vocab), direction_size=4, num_actions=7).to(device)
    target_q_net = deepcopy(q_net).to(device)

    # Wrap both networks as Q-value actors over a 7-way discrete action space.
    q_actor = QValueActor(q_net, in_keys=["observation"], action_space=OneHot(7))
    target_q_actor = QValueActor(target_q_net, in_keys=["observation"], action_space=OneHot(7))

    # Conservative Q-learning loss with a TD(0) value estimator.
    loss_module = DiscreteCQLLoss(q_actor).to(device)
    loss_module.make_value_estimator(
        value_type="td0",
        value_network=target_q_actor,
        gamma=0.99,
        **default_value_kwargs("td0"),
    )

    # Optimizer and training hyperparameters.
    optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
    tau = 0.005
    target_update_freq = 10
    max_steps = 1000

    print("🚀 Training started...")
    for step, sample in enumerate(buffer):
        sample = sample.to(device)
        losses = loss_module(sample)
        total_loss = losses["loss_qvalue"] + losses["loss_cql"]

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Periodic Polyak update of the target network.
        if step % target_update_freq == 0:
            soft_update(target_q_net, q_net, tau)

        if step % 10 == 0:
            print(f"Step {step:04d}: Total Loss = {total_loss.item():.4f}")

        if step >= max_steps:
            break

    print("✅ Training complete!")

if __name__ == "__main__":
    train_cql_babyai()

System info

Describe the characteristic of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

0.7.2 1.26.4 3.10.12 (main, Jul 26 2023, 19:37:41) [Clang 16.0.3 ] darwin

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions