diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 7b55172d1..89a208490 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -47,7 +47,7 @@ class StreamingDataset(IterableDataset):
     def __init__(
         self,
-        input_dir: Union[str, "Dir"],
+        input_dir: Union[str, "Dir", list[str], list["Dir"]],
         cache_dir: Optional[Union[str, "Dir"]] = None,
         item_loader: Optional[BaseItemLoader] = None,
         shuffle: bool = False,
@@ -93,6 +93,9 @@ def __init__(
         """
         _check_version_and_prompt_upgrade(__version__)
 
+        if not isinstance(input_dir, (list, tuple)):
+            input_dir = [input_dir]
+
         super().__init__()
         if not isinstance(shuffle, bool):
            raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")
@@ -101,16 +104,26 @@ def __init__(
            raise ValueError("subsample must be a float with value greater than 0.")
 
         fnmatch_pattern = None
-        if isinstance(input_dir, str) and input_dir.endswith(".parquet"):
-            input_dir, fnmatch_pattern = os.path.split(input_dir)
+        if len(input_dir) > 1:
+            raise ValueError("Streaming from multiple input directories is not supported yet.")
+
+        if len(input_dir) == 1 and isinstance(input_dir[0], str) and input_dir[0].endswith(".parquet"):
+            input_dir, fnmatch_pattern = os.path.split(input_dir[0])
 
         input_dir = _resolve_dir(input_dir)
+        input_dir = [input_dir] if not isinstance(input_dir, list) else input_dir
         cache_dir = _resolve_dir(cache_dir)
 
-        if input_dir.url is not None and input_dir.url.startswith("hf://"):
+        if any(_dir.url is not None and _dir.url.startswith("hf://") for _dir in input_dir) and len(input_dir) > 1:
+            raise ValueError(
+                "Streaming multiple Hugging Face datasets is not supported. "
+                "Please provide a single `hf://` dataset URL. If you need this feature, please open an issue on GitHub."
+            )
+
+        if len(input_dir) == 1 and input_dir[0].url is not None and input_dir[0].url.startswith("hf://"):
             if index_path is None:
                 # No index_path was provided. Attempt to load it from cache or generate it dynamically on the fly.
-                index_path = index_hf_dataset(input_dir.url, cache_dir.path)
+                index_path = index_hf_dataset(input_dir[0].url, cache_dir.path)
             if item_loader is not None and not isinstance(item_loader, ParquetLoader):
                 raise ValueError(
                     "Invalid item_loader for hf://datasets. "
@@ -122,8 +135,8 @@ def __init__(
         self.input_dir = input_dir
         self.cache_dir = cache_dir
 
-        self.subsampled_files: list[str] = []
-        self.region_of_interest: list[tuple[int, int]] = []
+        self.subsampled_files: list[list[str]] = []
+        self.region_of_interest: list[list[tuple[int, int]]] = []
         self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
             self.input_dir,
             self.cache_dir,
@@ -235,13 +248,14 @@ def set_epoch(self, current_epoch: int) -> None:
         self.current_epoch = current_epoch
 
     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
-        if _should_replace_path(self.input_dir.path):
+        if all(_should_replace_path(_dir.path) for _dir in self.input_dir):
             cache_path = _try_create_cache_dir(
-                input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url,
+                input_dir=[_dir.path if _dir.path else _dir.url for _dir in self.input_dir],
                 cache_dir=self.cache_dir.path,
             )
             if cache_path is not None:
-                self.input_dir.path = cache_path
+                for _dir in self.input_dir:
+                    _dir.path = cache_path
 
         cache = Cache(
             input_dir=self.input_dir,
diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py
index 51db59528..9285a81e2 100644
--- a/src/litdata/streaming/resolver.py
+++ b/src/litdata/streaming/resolver.py
@@ -44,7 +44,10 @@ class CloudProvider(str, Enum):
     GCP = "gcp"
 
 
-def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
+def _resolve_dir(dir_path: Optional[Union[str, Path, Dir, list[str], list[Path], list[Dir]]]) -> Union[Dir, list[Dir]]:
+    if isinstance(dir_path, (list, tuple)):
+        return [_resolve_dir(item) for item in dir_path]
+
     if isinstance(dir_path, Dir):
         return Dir(path=str(dir_path.path) if dir_path.path else None, url=str(dir_path.url) if dir_path.url else None)
""" + if isinstance(input_dir, list): + subsampled_files_list = [] + roi_list = [] + + for _dir in input_dir: + _subsampled_files, _roi = subsample_streaming_dataset( + input_dir=_dir, + cache_dir=cache_dir, + item_loader=item_loader, + subsample=subsample, + shuffle=shuffle, + seed=seed, + storage_options=storage_options, + session_options=session_options, + index_path=index_path, + fnmatch_pattern=fnmatch_pattern, + ) + + subsampled_files_list.extend(_subsampled_files) + roi_list.extend(_roi) + return subsampled_files_list, roi_list + subsampled_files: list[str] = [] roi: list[tuple[int, int]] = [] @@ -141,12 +163,15 @@ def _should_replace_path(path: Optional[str]) -> bool: def _read_updated_at( - input_dir: Optional[Dir], + input_dir: Optional[Union[Dir, list[Dir]]], storage_options: Optional[dict] = {}, session_options: Optional[dict] = {}, index_path: Optional[str] = None, -) -> str: +) -> Union[str, list[str]]: """Read last updated timestamp from index.json file.""" + if isinstance(input_dir, list): + return [_read_updated_at(_dir, storage_options, session_options, index_path) for _dir in input_dir] + last_updation_timestamp = "0" index_json_content = None assert isinstance(input_dir, Dir) @@ -215,21 +240,26 @@ def get_default_cache_dir() -> str: def _try_create_cache_dir( - input_dir: Optional[str], + input_dir: Optional[Union[str, list[str]]], cache_dir: Optional[str] = None, storage_options: Optional[dict] = {}, session_options: Optional[dict] = {}, index_path: Optional[str] = None, ) -> Optional[str]: """Prepare and return the cache directory for a dataset.""" + input_dir = input_dir if isinstance(input_dir, list) else [input_dir] if input_dir else [] resolved_input_dir = _resolve_dir(input_dir) updated_at = _read_updated_at(resolved_input_dir, storage_options, session_options, index_path) - # Fallback to a hash of the input_dir if updated_at is "0" - if updated_at == "0" and input_dir is not None: - updated_at = generate_md5_hash(input_dir) + updated_at = [updated_at] if not isinstance(updated_at, list) else updated_at + + for idx, _upd in enumerate(updated_at): + # Fallback to a hash of the input_dir if updated_at is "0" + if _upd == "0" and input_dir[idx] is not None: + updated_at[idx] = generate_md5_hash(resolved_input_dir[idx].path) - dir_url_hash = generate_md5_hash(resolved_input_dir.url or "") + _input_url = "_".join(_dir.url for _dir in resolved_input_dir if _dir.url is not None) + dir_url_hash = generate_md5_hash(_input_url) # Determine the cache directory, preferring user-provided cache_dir if given cache_dir = cache_dir if cache_dir is not None else get_default_cache_dir() diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py index 4332bec79..7e7c54eaa 100644 --- a/tests/streaming/test_dataset.py +++ b/tests/streaming/test_dataset.py @@ -100,6 +100,52 @@ def test_streaming_dataset(tmpdir, monkeypatch, compression): assert len(dataloader) == 30 +@pytest.mark.parametrize( + "compression", + [ + pytest.param(None), + pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")), + ], +) +@pytest.mark.timeout(30) +@pytest.mark.skip(reason="not implemented") +def test_multi_streaming_dataset(tmpdir, monkeypatch, compression): + seed_everything(42) + + with pytest.raises(FileNotFoundError, match="The provided dataset path"): + dataset = StreamingDataset(input_dir=[str(tmpdir.join("tmpfolder"))]) + + with pytest.raises(ValueError, match="The provided dataset"): + dataset = 
StreamingDataset(input_dir=[str(tmpdir)]) + + dir_name = ("dataset1", "dataset2") + + for idx, _dir in enumerate(dir_name): + cache = Cache(str(tmpdir.join(_dir)), chunk_size=10, compression=compression) + for j in range(60): + cache[j] = idx * j + cache.done() + cache.merge() + + dataset = StreamingDataset(input_dir=[str(tmpdir.join(_dir)) for _dir in dir_name]) + + assert len(dataset) == 120 + for i in range(120): + assert dataset[i] == i + + dataset_iter = iter(dataset) + assert len(dataset_iter) == 120 + for i in range(120): + assert next(dataset_iter) == i + + dataloader = StreamingDataLoader(dataset, num_workers=0, batch_size=1) + assert len(dataloader) == 120 + dataloader = DataLoader(dataset, num_workers=2, batch_size=1) + assert len(dataloader) == 120 + dataloader = DataLoader(dataset, num_workers=2, batch_size=2) + assert len(dataloader) == 60 + + def _simple_optimize_fn(index): return index
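
Side note on the `_try_create_cache_dir` change above: the cache key is now derived from every resolved directory URL joined with `_` before hashing, so a combination of inputs maps to a single cache directory. The sketch below mirrors that logic with the standard library only; it assumes `generate_md5_hash` is a plain MD5 hexdigest of its input string, which is an assumption about a helper not shown in this diff, and the URLs are placeholders.

import hashlib


def md5_hexdigest(text: str) -> str:
    # Stand-in for litdata's generate_md5_hash helper (assumed behaviour).
    return hashlib.md5(text.encode("utf-8")).hexdigest()


# Mirror the new hashing in _try_create_cache_dir: join all resolved URLs with "_" and hash once.
resolved_urls = ["s3://my-bucket/optimized/dataset1", "s3://my-bucket/optimized/dataset2"]
dir_url_hash = md5_hexdigest("_".join(url for url in resolved_urls if url is not None))
print(dir_url_hash)
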