From 63b79ba752f8b1dd4da0eb1e23d07e290c5d5136 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Mon, 7 Jul 2025 18:25:47 +0200 Subject: [PATCH 1/8] [Feature] MinariExperienceReplay now can load text data as tensors --- setup.py | 3 +- torchrl/data/datasets/minari_data.py | 232 ++++++++++++------ .../sphinx-tutorials/minari_data_loading.py | 43 ++++ 3 files changed, 197 insertions(+), 81 deletions(-) create mode 100644 tutorials/sphinx-tutorials/minari_data_loading.py diff --git a/setup.py b/setup.py index 48dc7214cff..35f978e7435 100644 --- a/setup.py +++ b/setup.py @@ -251,8 +251,7 @@ def _main(argv): "einops", # For tensor operations "safetensors", # For model loading ], - } - extra_requires["all"] = set() + "all": set()} for key in list(extra_requires.keys()): extra_requires["all"] = extra_requires["all"].union(extra_requires[key]) extra_requires["all"] = sorted(extra_requires["all"]) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 158449afca4..db9af3cd348 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -45,6 +45,12 @@ "int64": torch.int64, "int32": torch.int32, "uint8": torch.uint8, + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.int64": torch.int64, + "torch.int32": torch.int32, + "torch.uint8": torch.uint8 } @@ -88,25 +94,24 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay): it is assumed that any ``truncated`` or ``terminated`` signal is equivalent to the end of a trajectory. Defaults to ``False``. + string_to_tensor_map (dict[str, Callable[[str], Tensor]]): Optional mapping from a string key path + (e.g., 'observations/mission') to a function that converts unsupported NonTensorData to a tensor. + This allows customization of how string-based fields (e.g., language missions) + are encoded during dataset writing. Attributes: available_datasets: a list of accepted entries to be downloaded. - .. note:: - Text data is currenrtly discarded from the wrapped dataset, as there is not - PyTorch native way of representing text data. - If this feature is required, please post an issue on TorchRL's GitHub - repository. - Examples: >>> from torchrl.data.datasets.minari_data import MinariExperienceReplay - >>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force") + >>> data = MinariExperienceReplay("D4RL/door/human-v2", batch_size=32, download="force") >>> for sample in data: ... torchrl_logger.info(sample) ... break TensorDict( fields={ action: Tensor(shape=torch.Size([32, 28]), device=cpu, dtype=torch.float32, is_shared=False), + episode: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False), info: TensorDict( fields={ @@ -125,28 +130,12 @@ class MinariExperienceReplay(BaseDatasetExperienceReplay): is_shared=False), observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False), - state: TensorDict( - fields={ - door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False), - qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), - qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, - batch_size=torch.Size([32]), - device=cpu, - is_shared=False), terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([32]), device=cpu, is_shared=False), - observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False), - state: TensorDict( - fields={ - door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False), - qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False), - qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)}, - batch_size=torch.Size([32]), - device=cpu, - is_shared=False)}, + observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False)}, batch_size=torch.Size([32]), device=cpu, is_shared=False) @@ -159,7 +148,7 @@ def __init__( batch_size: int, *, root: str | Path | None = None, - download: bool = True, + download: bool | str = True, sampler: Sampler | None = None, writer: Writer | None = None, collate_fn: Callable | None = None, @@ -167,6 +156,7 @@ def __init__( prefetch: int | None = None, transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, + string_to_tensor_map: dict[str, Callable[[any], torch.Tensor]] | None = None ): self.dataset_id = dataset_id if root is None: @@ -175,6 +165,7 @@ def __init__( self.root = root self.split_trajs = split_trajs self.download = download + self._string_to_tensor_map = string_to_tensor_map or {} if self.download == "force" or (self.download and not self._is_downloaded()): if self.download == "force": try: @@ -243,49 +234,64 @@ def _download_and_preproc(self): minari.download_dataset(dataset_id=self.dataset_id) parent_dir = Path(tmpdir) / self.dataset_id / "data" - td_data = TensorDict() + td_data = TensorDict({}, batch_size=[]) total_steps = 0 torchrl_logger.info("first read through data to create data structure...") - h5_data = PersistentTensorDict.from_h5(parent_dir / "main_data.hdf5") + h5_data = PersistentTensorDict.from_h5(str(parent_dir / "main_data.hdf5")) # populate the tensordict episode_dict = {} - for i, (episode_key, episode) in enumerate(h5_data.items()): - episode_num = int(episode_key[len("episode_") :]) + for episode_key, episode in h5_data.items(): + episode_num = int(episode_key[len("episode_"):]) episode_len = episode["actions"].shape[0] episode_dict[episode_num] = (episode_key, episode_len) # Get the total number of steps for the dataset total_steps += episode_len - if i == 0: - td_data.set("episode", 0) - for key, val in episode.items(): - match = _NAME_MATCH[key] - if key in ("observations", "state", "infos"): - if ( - not val.shape - ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: - if val.is_empty(): - continue - val = _patch_info(val) - td_data.set(("next", match), torch.zeros_like(val[0])) - td_data.set(match, torch.zeros_like(val[0])) - if key not in ("terminations", "truncations", "rewards"): - td_data.set(match, torch.zeros_like(val[0])) + + # Use first episode to allocate structure + ref_episode = h5_data.get(episode_dict[0][0]) + + td_data.set("episode", torch.zeros((total_steps,), dtype=torch.int64)) + + field_max_tracker = {} + + for key, val in ref_episode.items(): + match = _NAME_MATCH[key] + + if key == "observations": + for subkey, subval in val.items(): + path_key = f"{key}/{subkey}" + if path_key in self._string_to_tensor_map: + encoded = self._string_to_tensor_map[path_key](subval.data[0]) + shape = (total_steps, *encoded.shape) + td_data.set(("observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) + td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) else: - td_data.set( - ("next", match), - torch.zeros_like(val[0].unsqueeze(-1)), - ) + shape = (total_steps,) if subval.dim() == 1 else (total_steps, *subval[0].shape) + td_data.set(("observation", subkey), torch.zeros(shape, dtype=subval.dtype)) + td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=subval.dtype)) + + elif key in ("terminations", "truncations", "rewards"): + td_data.set(("next", match), torch.zeros((total_steps, 1), dtype=val.dtype)) + + else: + shape = (total_steps,) if val.dim() == 1 else (total_steps, *val[0].shape) + td_data.set((match,), torch.zeros(shape, dtype=val.dtype)) + if key in ("state", "infos"): + td_data.set(("next", match), torch.zeros(shape, dtype=val.dtype)) - # give it the proper size + # Set batch size + td_data.batch_size = [total_steps] + + # Set 'done' placeholder td_data["next", "done"] = ( - td_data["next", "truncated"] | td_data["next", "terminated"] + td_data["next", "truncated"] | td_data["next", "terminated"] ) if "terminated" in td_data.keys(): td_data["done"] = td_data["truncated"] | td_data["terminated"] - td_data = td_data.expand(total_steps) + # save to designated location torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ") - td_data = td_data.memmap_like(self.data_path_root) + td_data = td_data.memmap_like(str(self.data_path_root)) torchrl_logger.info(f"tensordict structure: {td_data}") torchrl_logger.info(f"Reading data from {max(*episode_dict) + 1} episodes") @@ -300,11 +306,31 @@ def _download_and_preproc(self): data_view.fill_("episode", episode_num) for key, val in episode.items(): match = _NAME_MATCH[key] - if key in ( - "observations", - "state", - "infos", - ): + if key == "observations": + for subkey, subval in val.items(): + path_key = f"{key}/{subkey}" + if path_key in self._string_to_tensor_map: + encoder = self._string_to_tensor_map[path_key] + try: + encoded = [encoder(s) for s in subval.data[:-1]] + next_encoded = [encoder(s) for s in subval.data[1:]] + except Exception as e: + raise RuntimeError( + f"Failed to encode string at {path_key}: {e}" + ) from e + combined = torch.stack(encoded + next_encoded).flatten() + if combined.dtype in (torch.int64, torch.int32, torch.uint8): + max_val = combined.max().item() + field_max_tracker[path_key] = max( + max_val, field_max_tracker.get(path_key, 0) + ) + data_view["observation", subkey].copy_(torch.stack(encoded)) + data_view["next", "observation", subkey].copy_(torch.stack(next_encoded)) + else: + data_view["observation", subkey].copy_(subval[:-1]) + data_view["next", "observation", subkey].copy_(subval[1:]) + + elif key in ("state", "infos"): if not val.shape or steps != val.shape[0] - 1: if val.is_empty(): continue @@ -315,24 +341,21 @@ def _download_and_preproc(self): ) data_view["next", match].copy_(val[1:]) data_view[match].copy_(val[:-1]) - elif key not in ("terminations", "truncations", "rewards"): - if steps is None: - steps = val.shape[0] - else: - if steps != val.shape[0]: - raise RuntimeError( - f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." - ) - data_view[match].copy_(val) + + elif key in ("terminations", "truncations", "rewards"): + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + ) + data_view["next", match].copy_(val.unsqueeze(-1)) + else: - if steps is None: - steps = val.shape[0] - else: - if steps != val.shape[0]: - raise RuntimeError( - f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." - ) - data_view[("next", match)].copy_(val.unsqueeze(-1)) + if steps != val.shape[0]: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}." + ) + data_view[match].copy_(val) + data_view["next", "done"].copy_( data_view["next", "terminated"] | data_view["next", "truncated"] ) @@ -346,19 +369,48 @@ def _download_and_preproc(self): f"index={index} - episode num {episode_num}" ) index += steps + + # Extract real encoder input examples before closing h5_data + sampled_values = {} + for path_key in self._string_to_tensor_map: + if path_key.startswith("observations/"): + field = path_key.split("/", 1)[1] + try: + # Use the first episode + episode_key = episode_dict[0][0] + subval = h5_data.get(episode_key)["observations"][field] + sampled_values[path_key] = subval[0] # eager copy + except Exception as e: + torchrl_logger.warning(f"Could not extract value for {path_key}: {e}") + h5_data.close() # Add a "done" entry if self.split_trajs: with td_data.unlock_(): from torchrl.collectors.utils import split_trajectories + td_data = split_trajectories(td_data).memmap_(str(self.data_path)) - td_data = split_trajectories(td_data).memmap_(self.data_path) with open(self.metadata_path, "w") as metadata_file: dataset = minari.load_dataset(self.dataset_id) self.metadata = asdict(dataset.spec) - self.metadata["observation_space"] = _spec_to_dict( - self.metadata["observation_space"] - ) + + obs_space = _spec_to_dict(self.metadata["observation_space"]) + if "subspaces" in obs_space: + for path_key, encoder in self._string_to_tensor_map.items(): + if path_key.startswith("observations/"): + field = path_key.split("/", 1)[1] + try: + example_value = sampled_values[path_key] + example_tensor = encoder(example_value) + max_val = field_max_tracker.get(path_key, None) + obs_space["subspaces"][field] = self._tensor_to_spec_dict(example_tensor, + max_value=max_val) + except Exception as e: + torchrl_logger.warning( + f"Could not encode observation field '{field}' for metadata: {e}" + ) + + self.metadata["observation_space"] = obs_space self.metadata["action_space"] = _spec_to_dict( self.metadata["action_space"] ) @@ -366,12 +418,34 @@ def _download_and_preproc(self): self._load_and_proc_metadata() return td_data + @staticmethod + def _tensor_to_spec_dict(tensor: torch.Tensor, max_value: int | None = None) -> dict: + if tensor.dtype in (torch.float32, torch.float64, torch.float16): + return { + "type": "Box", + "low": [-float("inf")] * tensor.numel(), + "high": [float("inf")] * tensor.numel(), + "dtype": str(tensor.dtype), + "shape": list(tensor.shape), + } + elif tensor.dtype in (torch.uint8, torch.int64, torch.int32): + if max_value is None: + raise ValueError("max_value must be provided for Discrete tensors.") + return { + "type": "Discrete", + "dtype": str(tensor.dtype), + "n": int(max_value) + 1, + "shape": list(tensor.shape), + } + else: + raise TypeError(f"Unsupported dtype {tensor.dtype} for spec conversion.") + def _make_split(self): from torchrl.collectors.utils import split_trajectories self._load_and_proc_metadata() td_data = TensorDict.load_memmap(self.data_path_root) - td_data = split_trajectories(td_data).memmap_(self.data_path) + td_data = split_trajectories(td_data).memmap_(str(self.data_path)) return td_data def _load(self): diff --git a/tutorials/sphinx-tutorials/minari_data_loading.py b/tutorials/sphinx-tutorials/minari_data_loading.py new file mode 100644 index 00000000000..60c8ea60aa9 --- /dev/null +++ b/tutorials/sphinx-tutorials/minari_data_loading.py @@ -0,0 +1,43 @@ +import torch +from torchrl.data.datasets.minari_data import MinariExperienceReplay +from torchrl._utils import logger as torchrl_logger + +# Define the possible missions +COLORS = ["red", "green", "blue", "purple", "yellow", "grey"] +OBJECT_TYPES = ["box", "ball", "key"] +MISSION_TO_IDX = { + f"pick up {color} {obj}": i + for i, (color, obj) in enumerate((c, o) for c in COLORS for o in OBJECT_TYPES) +} +NUM_MISSIONS = len(MISSION_TO_IDX) + + +# Define the encoder function +def encode_mission_string(mission: bytes) -> torch.Tensor: + mission = mission.decode("utf-8") + clean_mission = mission.replace(" a", "").replace(" the", "") + idx = MISSION_TO_IDX.get(clean_mission, -1) + if idx == -1: + raise ValueError(f"Unknown mission string: {clean_mission}") + return torch.nn.functional.one_hot(torch.tensor(idx), num_classes=NUM_MISSIONS).to(torch.uint8) + + +def main(): + # Download the dataset and apply the string-to-tensor map + data = MinariExperienceReplay( + dataset_id="minigrid/BabyAI-Pickup/optimal-v0", + batch_size=1, + download="force", # uncomment if you want to force redownload + string_to_tensor_map={ + "observations/mission": encode_mission_string # ✅ map mission to one-hot + } + ) + + # Sample and inspect + print("Sampling data with transformed mission field:") + sample = data.sample(batch_size=1) + torchrl_logger.info(sample) + + +if __name__ == "__main__": + main() From 256aeeb41f892cc4aaf8f08c04d13994fe2267d3 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Wed, 9 Jul 2025 17:07:51 +0200 Subject: [PATCH 2/8] Small dumb change --- tutorials/sphinx-tutorials/minari_data_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/sphinx-tutorials/minari_data_loading.py b/tutorials/sphinx-tutorials/minari_data_loading.py index 60c8ea60aa9..07291f61fc5 100644 --- a/tutorials/sphinx-tutorials/minari_data_loading.py +++ b/tutorials/sphinx-tutorials/minari_data_loading.py @@ -27,9 +27,9 @@ def main(): data = MinariExperienceReplay( dataset_id="minigrid/BabyAI-Pickup/optimal-v0", batch_size=1, - download="force", # uncomment if you want to force redownload + download="force", string_to_tensor_map={ - "observations/mission": encode_mission_string # ✅ map mission to one-hot + "observations/mission": encode_mission_string } ) From bc075b1a741863a0dca4a9ccb12fcdf52e78fe11 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Wed, 9 Jul 2025 18:00:25 +0200 Subject: [PATCH 3/8] Documentation and tests --- CONTRIBUTING.md | 3 +++ test/test_libs.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 74edc823196..27a8a232e07 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,6 +32,9 @@ If the generation of this artifact in MacOs M1 doesn't work correctly or in the ARCHFLAGS="-arch arm64" python setup.py develop ``` +In some MacOs devices, the installation of the required libraries errors if the correct version of +clang is not used. Using `llvm@16` (installable with brew), may fix your issues. + ## Formatting your code **Type annotation** diff --git a/test/test_libs.py b/test/test_libs.py index 77d22f2d5b0..af3329f93b4 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3272,7 +3272,6 @@ def _minari_selected_datasets(): @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found") -@pytest.mark.slow class TestMinari: @pytest.mark.parametrize("split", [False, True]) @pytest.mark.parametrize("selected_dataset", _MINARI_DATASETS) @@ -3337,6 +3336,37 @@ def fn(data): assert sample["data"].shape == torch.Size([32, 8]) assert sample["next", "data"].shape == torch.Size([32, 8]) + @pytest.mark.parametrize("selected_dataset", ["minigrid/BabyAI-Pickup/optimal-v0"]) + def test_minari_string_to_tensor(self, selected_dataset): + # Define the possible missions + colors = ["red", "green", "blue", "purple", "yellow", "grey"] + object_types = ["box", "ball", "key"] + mission_to_idx = { + f"pick up {color} {obj}": i + for i, (color, obj) in enumerate((c, o) for c in colors for o in object_types) + } + num_missions = len(mission_to_idx) + + def encode_mission_string(mission: bytes) -> torch.Tensor: + mission = mission.decode("utf-8") + clean_mission = mission.replace(" a", "").replace(" the", "") + idx = mission_to_idx.get(clean_mission, -1) + if idx == -1: + raise ValueError(f"Unknown mission string: {clean_mission}") + return torch.nn.functional.one_hot(torch.tensor(idx), num_classes=num_missions).to(torch.uint8) + + data = MinariExperienceReplay( + dataset_id=selected_dataset, + batch_size=1, + download="force", + string_to_tensor_map={ + "observations/mission": encode_mission_string + } + ) + sample = data.sample(batch_size=1) + assert isinstance(sample["observation", "mission"], torch.Tensor) + assert sample["observation", "mission"].shape == (1, num_missions) + @pytest.mark.slow class TestRoboset: From 6cec6bc9f881e1018297c342d8bee833fb4f9d45 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Sat, 12 Jul 2025 00:18:40 +0200 Subject: [PATCH 4/8] Final changes --- torchrl/data/datasets/minari_data.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index d0077ec9aa2..f7a3b6c8b7d 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -250,7 +250,7 @@ def _download_and_preproc(self): # Use first episode to allocate structure ref_episode = h5_data.get(episode_dict[0][0]) - td_data.set("episode", torch.zeros((total_steps,), dtype=torch.int64)) + td_data.set("episode", 0) field_max_tracker = {} @@ -262,33 +262,29 @@ def _download_and_preproc(self): path_key = f"{key}/{subkey}" if path_key in self._string_to_tensor_map: encoded = self._string_to_tensor_map[path_key](subval.data[0]) - shape = (total_steps, *encoded.shape) + shape = encoded.shape td_data.set(("observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) else: - shape = (total_steps,) if subval.dim() == 1 else (total_steps, *subval[0].shape) - td_data.set(("observation", subkey), torch.zeros(shape, dtype=subval.dtype)) - td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=subval.dtype)) + td_data.set(("observation", subkey), torch.zeros_like(subval[0])) + td_data.set(("next", "observation", subkey), torch.zeros_like(subval[0])) elif key in ("terminations", "truncations", "rewards"): - td_data.set(("next", match), torch.zeros((total_steps, 1), dtype=val.dtype)) + td_data.set(("next", match), torch.zeros_like(val[0].unsqueeze(-1))) else: - shape = (total_steps,) if val.dim() == 1 else (total_steps, *val[0].shape) - td_data.set((match,), torch.zeros(shape, dtype=val.dtype)) - if key in ("state", "infos"): - td_data.set(("next", match), torch.zeros(shape, dtype=val.dtype)) + td_data.set(match, torch.zeros_like(val[0])) - # Set batch size - td_data.batch_size = [total_steps] + if key in ("state", "infos"): + td_data.set(("next", match), torch.zeros_like(val[0])) - # Set 'done' placeholder + # give it the proper size td_data["next", "done"] = ( td_data["next", "truncated"] | td_data["next", "terminated"] ) if "terminated" in td_data.keys(): td_data["done"] = td_data["truncated"] | td_data["terminated"] - + td_data = td_data.expand(total_steps) # save to designated location torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ") td_data = td_data.memmap_like(str(self.data_path_root)) From ab31a2b89f7de61e433ab2e94fc14472688dfaa4 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Sat, 12 Jul 2025 02:16:29 +0200 Subject: [PATCH 5/8] Adapted for non tensordict observations --- torchrl/data/datasets/minari_data.py | 78 ++++++++++++++++------------ 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index f7a3b6c8b7d..15604cc9c75 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -258,16 +258,22 @@ def _download_and_preproc(self): match = _NAME_MATCH[key] if key == "observations": - for subkey, subval in val.items(): - path_key = f"{key}/{subkey}" - if path_key in self._string_to_tensor_map: - encoded = self._string_to_tensor_map[path_key](subval.data[0]) - shape = encoded.shape - td_data.set(("observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) - td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) - else: - td_data.set(("observation", subkey), torch.zeros_like(subval[0])) - td_data.set(("next", "observation", subkey), torch.zeros_like(subval[0])) + if isinstance(val, PersistentTensorDict): + for subkey, subval in val.items(): + path_key = f"{key}/{subkey}" + if path_key in self._string_to_tensor_map: + encoded = self._string_to_tensor_map[path_key](subval.data[0]) + shape = encoded.shape + td_data.set(("observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) + td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) + else: + td_data.set(("observation", subkey), torch.zeros_like(subval[0])) + td_data.set(("next", "observation", subkey), torch.zeros_like(subval[0])) + elif isinstance(val, torch.Tensor): + td_data.set(("next", match), torch.zeros_like(val[0])) + td_data.set(match, torch.zeros_like(val[0])) + else: + raise Exception("Unexpected data type in observation.") elif key in ("terminations", "truncations", "rewards"): td_data.set(("next", match), torch.zeros_like(val[0].unsqueeze(-1))) @@ -303,28 +309,36 @@ def _download_and_preproc(self): for key, val in episode.items(): match = _NAME_MATCH[key] if key == "observations": - for subkey, subval in val.items(): - path_key = f"{key}/{subkey}" - if path_key in self._string_to_tensor_map: - encoder = self._string_to_tensor_map[path_key] - try: - encoded = [encoder(s) for s in subval.data[:-1]] - next_encoded = [encoder(s) for s in subval.data[1:]] - except Exception as e: - raise RuntimeError( - f"Failed to encode string at {path_key}: {e}" - ) from e - combined = torch.stack(encoded + next_encoded).flatten() - if combined.dtype in (torch.int64, torch.int32, torch.uint8): - max_val = combined.max().item() - field_max_tracker[path_key] = max( - max_val, field_max_tracker.get(path_key, 0) - ) - data_view["observation", subkey].copy_(torch.stack(encoded)) - data_view["next", "observation", subkey].copy_(torch.stack(next_encoded)) - else: - data_view["observation", subkey].copy_(subval[:-1]) - data_view["next", "observation", subkey].copy_(subval[1:]) + if isinstance(val, PersistentTensorDict): + for subkey, subval in val.items(): + path_key = f"{key}/{subkey}" + if path_key in self._string_to_tensor_map: + encoder = self._string_to_tensor_map[path_key] + try: + encoded = [encoder(s) for s in subval.data[:-1]] + next_encoded = [encoder(s) for s in subval.data[1:]] + except Exception as e: + raise RuntimeError( + f"Failed to encode string at {path_key}: {e}" + ) from e + combined = torch.stack(encoded + next_encoded).flatten() + if combined.dtype in (torch.int64, torch.int32, torch.uint8): + max_val = combined.max().item() + field_max_tracker[path_key] = max( + max_val, field_max_tracker.get(path_key, 0) + ) + data_view["observation", subkey].copy_(torch.stack(encoded)) + data_view["next", "observation", subkey].copy_(torch.stack(next_encoded)) + else: + data_view["observation", subkey].copy_(subval[:-1]) + data_view["next", "observation", subkey].copy_(subval[1:]) + elif isinstance(val, torch.Tensor): + if steps != val.shape[0] - 1: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." + ) + data_view["next", match].copy_(val[1:]) + data_view[match].copy_(val[:-1]) elif key in ("state", "infos"): if not val.shape or steps != val.shape[0] - 1: From 9ec0427b9cdb6d3b19303358bfc2fbb60ce0000a Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Tue, 15 Jul 2025 13:08:37 +0200 Subject: [PATCH 6/8] Final adaptation for NonTensorData population --- torchrl/data/datasets/minari_data.py | 39 +++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 15604cc9c75..14197d0396a 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -286,7 +286,7 @@ def _download_and_preproc(self): # give it the proper size td_data["next", "done"] = ( - td_data["next", "truncated"] | td_data["next", "terminated"] + td_data["next", "truncated"] | td_data["next", "terminated"] ) if "terminated" in td_data.keys(): td_data["done"] = td_data["truncated"] | td_data["terminated"] @@ -308,6 +308,15 @@ def _download_and_preproc(self): data_view.fill_("episode", episode_num) for key, val in episode.items(): match = _NAME_MATCH[key] + if key in ("observations", "state", "infos"): + if not val.shape or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) + if steps != val.shape[0] - 1: + raise RuntimeError( + f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." + ) if key == "observations": if isinstance(val, PersistentTensorDict): for subkey, subval in val.items(): @@ -330,8 +339,32 @@ def _download_and_preproc(self): data_view["observation", subkey].copy_(torch.stack(encoded)) data_view["next", "observation", subkey].copy_(torch.stack(next_encoded)) else: - data_view["observation", subkey].copy_(subval[:-1]) - data_view["next", "observation", subkey].copy_(subval[1:]) + if isinstance(data_view["observation", subkey], TensorDict): + data_view["observation"].copy_({subkey: subval[:-1]}) + data_view["next", "observation"].copy_({subkey: subval[1:]}) + elif isinstance(data_view["observation", subkey], list): + # TODO: Unfortunately the copy_ method fais when dealing with + # subvals of NonTensorData. It fails with this + # RuntimeError: Cannot update a leaf NonTensorDataBase from a memmaped + # parent NonTensorStack. To update this leaf node, please update the + # NonTensorStack with the proper index. + # Unfortunately, this following method also fails, as lists do not + # have copy_ method + # data_view["observation", subkey].copy_(subval[:-1]) + # The only approach that seems to be working it unlocking the + # Tensordict. I would prefer something like the following: + # for i in range(len(subval) - 1): + # data_view[i].set(("observation", subkey), subval[i]) + # data_view[i].set(("next", "observation", subkey), subval[i + 1]) + # But this three previous lines give this error: + # RuntimeError: Cannot modify locked TensorDict. For in-place + # modification, consider using the `set_()` method and make + # sure the key is present. + # But this current approach takes incredibly long to complete, maybe + # I should do something different? + with data_view.unlock_(): + data_view.set(("observation", subkey), subval[:-1]) + elif isinstance(val, torch.Tensor): if steps != val.shape[0] - 1: raise RuntimeError( From 7e53301a4c60fba2fb89b2be1a7d2a90ee9d3859 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Wed, 16 Jul 2025 11:02:40 +0200 Subject: [PATCH 7/8] Reversed small test change and fixed patching --- test/test_libs.py | 1 + torchrl/data/datasets/minari_data.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 85e8ba4b641..d98cde540f0 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3379,6 +3379,7 @@ def _minari_init(): @pytest.mark.skipif(not _has_minari or not _has_gymnasium, reason="Minari not found") +@pytest.mark.slow class TestMinari: @pytest.mark.parametrize("split", [False, True]) @pytest.mark.parametrize("dataset_idx", range(20)) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 14197d0396a..25a0d676086 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -256,7 +256,13 @@ def _download_and_preproc(self): for key, val in ref_episode.items(): match = _NAME_MATCH[key] - + if key in ("observations", "state", "infos"): + if ( + not val.shape + ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: + if val.is_empty(): + continue + val = _patch_info(val) if key == "observations": if isinstance(val, PersistentTensorDict): for subkey, subval in val.items(): @@ -374,14 +380,6 @@ def _download_and_preproc(self): data_view[match].copy_(val[:-1]) elif key in ("state", "infos"): - if not val.shape or steps != val.shape[0] - 1: - if val.is_empty(): - continue - val = _patch_info(val) - if steps != val.shape[0] - 1: - raise RuntimeError( - f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}." - ) data_view["next", match].copy_(val[1:]) data_view[match].copy_(val[:-1]) From c5c075d6eb1da858f580392d7567bc78a8679ce0 Mon Sep 17 00:00:00 2001 From: Marcos Galletero Date: Wed, 16 Jul 2025 11:56:27 +0200 Subject: [PATCH 8/8] Pre-commit linting changes --- setup.py | 3 +- test/test_libs.py | 12 ++- torchrl/data/datasets/minari_data.py | 98 ++++++++++++++----- .../sphinx-tutorials/minari_data_loading.py | 10 +- 4 files changed, 85 insertions(+), 38 deletions(-) diff --git a/setup.py b/setup.py index 35f978e7435..3658653806d 100644 --- a/setup.py +++ b/setup.py @@ -251,7 +251,8 @@ def _main(argv): "einops", # For tensor operations "safetensors", # For model loading ], - "all": set()} + "all": set(), + } for key in list(extra_requires.keys()): extra_requires["all"] = extra_requires["all"].union(extra_requires[key]) extra_requires["all"] = sorted(extra_requires["all"]) diff --git a/test/test_libs.py b/test/test_libs.py index d98cde540f0..7ef75cdd8bf 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3460,7 +3460,9 @@ def test_minari_string_to_tensor(self, selected_dataset): object_types = ["box", "ball", "key"] mission_to_idx = { f"pick up {color} {obj}": i - for i, (color, obj) in enumerate((c, o) for c in colors for o in object_types) + for i, (color, obj) in enumerate( + (c, o) for c in colors for o in object_types + ) } num_missions = len(mission_to_idx) @@ -3470,15 +3472,15 @@ def encode_mission_string(mission: bytes) -> torch.Tensor: idx = mission_to_idx.get(clean_mission, -1) if idx == -1: raise ValueError(f"Unknown mission string: {clean_mission}") - return torch.nn.functional.one_hot(torch.tensor(idx), num_classes=num_missions).to(torch.uint8) + return torch.nn.functional.one_hot( + torch.tensor(idx), num_classes=num_missions + ).to(torch.uint8) data = MinariExperienceReplay( dataset_id=selected_dataset, batch_size=1, download="force", - string_to_tensor_map={ - "observations/mission": encode_mission_string - } + string_to_tensor_map={"observations/mission": encode_mission_string}, ) sample = data.sample(batch_size=1) assert isinstance(sample["observation", "mission"], torch.Tensor) diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py index 25a0d676086..4498da2b176 100644 --- a/torchrl/data/datasets/minari_data.py +++ b/torchrl/data/datasets/minari_data.py @@ -50,7 +50,7 @@ "torch.float64": torch.float64, "torch.int64": torch.int64, "torch.int32": torch.int32, - "torch.uint8": torch.uint8 + "torch.uint8": torch.uint8, } @@ -156,7 +156,7 @@ def __init__( prefetch: int | None = None, transform: torchrl.envs.Transform | None = None, # noqa-F821 split_trajs: bool = False, - string_to_tensor_map: dict[str, Callable[[any], torch.Tensor]] | None = None + string_to_tensor_map: dict[str, Callable[[any], torch.Tensor]] | None = None, ): self.dataset_id = dataset_id if root is None: @@ -241,7 +241,7 @@ def _download_and_preproc(self): # populate the tensordict episode_dict = {} for episode_key, episode in h5_data.items(): - episode_num = int(episode_key[len("episode_"):]) + episode_num = int(episode_key[len("episode_") :]) episode_len = episode["actions"].shape[0] episode_dict[episode_num] = (episode_key, episode_len) # Get the total number of steps for the dataset @@ -257,9 +257,7 @@ def _download_and_preproc(self): for key, val in ref_episode.items(): match = _NAME_MATCH[key] if key in ("observations", "state", "infos"): - if ( - not val.shape - ): # no need for this, we don't need the proper length: or steps != val.shape[0] - 1: + if not val.shape: if val.is_empty(): continue val = _patch_info(val) @@ -268,13 +266,26 @@ def _download_and_preproc(self): for subkey, subval in val.items(): path_key = f"{key}/{subkey}" if path_key in self._string_to_tensor_map: - encoded = self._string_to_tensor_map[path_key](subval.data[0]) + encoded = self._string_to_tensor_map[path_key]( + subval.data[0] + ) shape = encoded.shape - td_data.set(("observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) - td_data.set(("next", "observation", subkey), torch.zeros(shape, dtype=encoded.dtype)) + td_data.set( + ("observation", subkey), + torch.zeros(shape, dtype=encoded.dtype), + ) + td_data.set( + ("next", "observation", subkey), + torch.zeros(shape, dtype=encoded.dtype), + ) else: - td_data.set(("observation", subkey), torch.zeros_like(subval[0])) - td_data.set(("next", "observation", subkey), torch.zeros_like(subval[0])) + td_data.set( + ("observation", subkey), torch.zeros_like(subval[0]) + ) + td_data.set( + ("next", "observation", subkey), + torch.zeros_like(subval[0]), + ) elif isinstance(val, torch.Tensor): td_data.set(("next", match), torch.zeros_like(val[0])) td_data.set(match, torch.zeros_like(val[0])) @@ -330,25 +341,48 @@ def _download_and_preproc(self): if path_key in self._string_to_tensor_map: encoder = self._string_to_tensor_map[path_key] try: - encoded = [encoder(s) for s in subval.data[:-1]] - next_encoded = [encoder(s) for s in subval.data[1:]] + encoded = [ + encoder(s) for s in subval.data[:-1] + ] + next_encoded = [ + encoder(s) for s in subval.data[1:] + ] except Exception as e: raise RuntimeError( f"Failed to encode string at {path_key}: {e}" ) from e - combined = torch.stack(encoded + next_encoded).flatten() - if combined.dtype in (torch.int64, torch.int32, torch.uint8): + combined = torch.stack( + encoded + next_encoded + ).flatten() + if combined.dtype in ( + torch.int64, + torch.int32, + torch.uint8, + ): max_val = combined.max().item() field_max_tracker[path_key] = max( - max_val, field_max_tracker.get(path_key, 0) + max_val, + field_max_tracker.get(path_key, 0), ) - data_view["observation", subkey].copy_(torch.stack(encoded)) - data_view["next", "observation", subkey].copy_(torch.stack(next_encoded)) + data_view["observation", subkey].copy_( + torch.stack(encoded) + ) + data_view["next", "observation", subkey].copy_( + torch.stack(next_encoded) + ) else: - if isinstance(data_view["observation", subkey], TensorDict): - data_view["observation"].copy_({subkey: subval[:-1]}) - data_view["next", "observation"].copy_({subkey: subval[1:]}) - elif isinstance(data_view["observation", subkey], list): + if isinstance( + data_view["observation", subkey], TensorDict + ): + data_view["observation"].copy_( + {subkey: subval[:-1]} + ) + data_view["next", "observation"].copy_( + {subkey: subval[1:]} + ) + elif isinstance( + data_view["observation", subkey], list + ): # TODO: Unfortunately the copy_ method fais when dealing with # subvals of NonTensorData. It fails with this # RuntimeError: Cannot update a leaf NonTensorDataBase from a memmaped @@ -369,7 +403,9 @@ def _download_and_preproc(self): # But this current approach takes incredibly long to complete, maybe # I should do something different? with data_view.unlock_(): - data_view.set(("observation", subkey), subval[:-1]) + data_view.set( + ("observation", subkey), subval[:-1] + ) elif isinstance(val, torch.Tensor): if steps != val.shape[0] - 1: @@ -422,13 +458,16 @@ def _download_and_preproc(self): subval = h5_data.get(episode_key)["observations"][field] sampled_values[path_key] = subval[0] # eager copy except Exception as e: - torchrl_logger.warning(f"Could not extract value for {path_key}: {e}") + torchrl_logger.warning( + f"Could not extract value for {path_key}: {e}" + ) h5_data.close() # Add a "done" entry if self.split_trajs: with td_data.unlock_(): from torchrl.collectors.utils import split_trajectories + td_data = split_trajectories(td_data).memmap_(str(self.data_path)) with open(self.metadata_path, "w") as metadata_file: @@ -444,8 +483,11 @@ def _download_and_preproc(self): example_value = sampled_values[path_key] example_tensor = encoder(example_value) max_val = field_max_tracker.get(path_key, None) - obs_space["subspaces"][field] = self._tensor_to_spec_dict(example_tensor, - max_value=max_val) + obs_space["subspaces"][ + field + ] = self._tensor_to_spec_dict( + example_tensor, max_value=max_val + ) except Exception as e: torchrl_logger.warning( f"Could not encode observation field '{field}' for metadata: {e}" @@ -460,7 +502,9 @@ def _download_and_preproc(self): return td_data @staticmethod - def _tensor_to_spec_dict(tensor: torch.Tensor, max_value: int | None = None) -> dict: + def _tensor_to_spec_dict( + tensor: torch.Tensor, max_value: int | None = None + ) -> dict: if tensor.dtype in (torch.float32, torch.float64, torch.float16): return { "type": "Box", diff --git a/tutorials/sphinx-tutorials/minari_data_loading.py b/tutorials/sphinx-tutorials/minari_data_loading.py index 07291f61fc5..069cacdf5eb 100644 --- a/tutorials/sphinx-tutorials/minari_data_loading.py +++ b/tutorials/sphinx-tutorials/minari_data_loading.py @@ -1,6 +1,6 @@ import torch -from torchrl.data.datasets.minari_data import MinariExperienceReplay from torchrl._utils import logger as torchrl_logger +from torchrl.data.datasets.minari_data import MinariExperienceReplay # Define the possible missions COLORS = ["red", "green", "blue", "purple", "yellow", "grey"] @@ -19,7 +19,9 @@ def encode_mission_string(mission: bytes) -> torch.Tensor: idx = MISSION_TO_IDX.get(clean_mission, -1) if idx == -1: raise ValueError(f"Unknown mission string: {clean_mission}") - return torch.nn.functional.one_hot(torch.tensor(idx), num_classes=NUM_MISSIONS).to(torch.uint8) + return torch.nn.functional.one_hot(torch.tensor(idx), num_classes=NUM_MISSIONS).to( + torch.uint8 + ) def main(): @@ -28,9 +30,7 @@ def main(): dataset_id="minigrid/BabyAI-Pickup/optimal-v0", batch_size=1, download="force", - string_to_tensor_map={ - "observations/mission": encode_mission_string - } + string_to_tensor_map={"observations/mission": encode_mission_string}, ) # Sample and inspect