32 changes: 23 additions & 9 deletions src/litdata/streaming/dataset.py
@@ -47,7 +47,7 @@ class StreamingDataset(IterableDataset):

def __init__(
self,
input_dir: Union[str, "Dir"],
input_dir: Union[str, "Dir", list[str], list["Dir"]],
cache_dir: Optional[Union[str, "Dir"]] = None,
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
@@ -93,6 +93,9 @@ def __init__(
"""
_check_version_and_prompt_upgrade(__version__)

if not isinstance(input_dir, (list, tuple)):
input_dir = [input_dir]

super().__init__()
if not isinstance(shuffle, bool):
raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")
@@ -101,16 +104,26 @@ def __init__(
raise ValueError("subsample must be a float with value greater than 0.")

fnmatch_pattern = None
if isinstance(input_dir, str) and input_dir.endswith(".parquet"):
input_dir, fnmatch_pattern = os.path.split(input_dir)
if len(input_dir) == 1 and isinstance(input_dir[0], str) and input_dir[0].endswith(".parquet"):
input_dir[0], fnmatch_pattern = os.path.split(input_dir[0])

if len(input_dir) > 1:
raise ValueError("Not implemented")

input_dir = _resolve_dir(input_dir)
input_dir = [input_dir] if not isinstance(input_dir, list) else input_dir
cache_dir = _resolve_dir(cache_dir)

if input_dir.url is not None and input_dir.url.startswith("hf://"):
if any(_dir.url is not None and _dir.url.startswith("hf://") for _dir in input_dir) and len(input_dir) > 1:
raise ValueError(
"Streaming multiple Hugging Face datasets is not supported."
"Please provide a single `hf://` dataset URL. If you need this feature, please open an issue on GitHub."
)

if len(input_dir) == 1 and input_dir[0].url is not None and input_dir[0].url.startswith("hf://"):
if index_path is None:
# No index_path was provided. Attempt to load it from cache or generate it dynamically on the fly.
index_path = index_hf_dataset(input_dir.url, cache_dir.path)
index_path = index_hf_dataset(input_dir[0].url, cache_dir.path)
if item_loader is not None and not isinstance(item_loader, ParquetLoader):
raise ValueError(
"Invalid item_loader for hf://datasets. "
@@ -122,8 +135,8 @@ def __init__(

self.input_dir = input_dir
self.cache_dir = cache_dir
self.subsampled_files: list[str] = []
self.region_of_interest: list[tuple[int, int]] = []
self.subsampled_files: list[list[str]] = []
self.region_of_interest: list[list[tuple[int, int]]] = []
self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
self.input_dir,
self.cache_dir,
@@ -235,13 +248,14 @@ def set_epoch(self, current_epoch: int) -> None:
self.current_epoch = current_epoch

def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
if _should_replace_path(self.input_dir.path):
if all(_should_replace_path(_dir.path) for _dir in self.input_dir):
cache_path = _try_create_cache_dir(
input_dir=[_dir.path if _dir.path else _dir.url for _dir in self.input_dir],
cache_dir=self.cache_dir.path,
)
if cache_path is not None:
self.input_dir.path = cache_path
for _dir in self.input_dir:
_dir.path = cache_path

cache = Cache(
input_dir=self.input_dir,
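For reviewers, a minimal sketch of the call pattern this signature change is meant to enable. The bucket paths are hypothetical, and streaming more than one directory still raises until the aggregation logic lands:

from litdata import StreamingDataset

# Existing single-directory usage keeps working unchanged.
dataset = StreamingDataset(input_dir="s3://my-bucket/optimized_dataset")

# New list form accepted by the signature; a one-element list behaves like the
# plain string, while two or more directories currently raise a ValueError.
dataset = StreamingDataset(input_dir=["s3://my-bucket/dataset1", "s3://my-bucket/dataset2"])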
5 changes: 4 additions & 1 deletion src/litdata/streaming/resolver.py
@@ -44,7 +44,10 @@ class CloudProvider(str, Enum):
GCP = "gcp"


def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
def _resolve_dir(dir_path: Optional[Union[str, Path, Dir, list[str], list[Path], list[Dir]]]) -> Union[Dir, list[Dir]]:
if isinstance(dir_path, (list, tuple)):
return [_resolve_dir(item) for item in dir_path]

if isinstance(dir_path, Dir):
return Dir(path=str(dir_path.path) if dir_path.path else None, url=str(dir_path.url) if dir_path.url else None)

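A short sketch of the list handling added to `_resolve_dir`: each element is resolved independently and a list of `Dir` objects comes back. The local paths below are hypothetical, and the exact `Dir` fields depend on how the resolver normalizes plain local paths:

from litdata.streaming.resolver import Dir, _resolve_dir

single = _resolve_dir("/data/dataset1")                        # a single Dir
multiple = _resolve_dir(["/data/dataset1", "/data/dataset2"])  # a list of Dir
assert isinstance(single, Dir)
assert isinstance(multiple, list) and all(isinstance(d, Dir) for d in multiple)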
50 changes: 40 additions & 10 deletions src/litdata/utilities/dataset_utilities.py
@@ -5,7 +5,7 @@
import shutil
import tempfile
import time
from typing import Any, Optional
from typing import Any, Optional, Union

import numpy as np

@@ -17,7 +17,7 @@


def subsample_streaming_dataset(
input_dir: Dir,
input_dir: Union[Dir, list[Dir]],
cache_dir: Optional[Dir] = None,
item_loader: Optional[BaseItemLoader] = None,
subsample: float = 1.0,
@@ -27,7 +27,7 @@ def subsample_streaming_dataset(
session_options: Optional[dict] = {},
index_path: Optional[str] = None,
fnmatch_pattern: Optional[str] = None,
) -> tuple[list[str], list[tuple[int, int]]]:
) -> tuple[Union[list[str], list[list[str]]], Union[list[tuple[int, int]], list[list[tuple[int, int]]]]]:
"""Subsample streaming dataset.

But before doing that, we will do some preprocessing:
@@ -38,6 +38,28 @@ def __init__(
- Once chunks are ready, subsample (chunk filenames, region_of_interest).

"""
if isinstance(input_dir, list):
subsampled_files_list = []
roi_list = []

for _dir in input_dir:
_subsampled_files, _roi = subsample_streaming_dataset(
input_dir=_dir,
cache_dir=cache_dir,
item_loader=item_loader,
subsample=subsample,
shuffle=shuffle,
seed=seed,
storage_options=storage_options,
session_options=session_options,
index_path=index_path,
fnmatch_pattern=fnmatch_pattern,
)

subsampled_files_list.extend(_subsampled_files)
roi_list.extend(_roi)
return subsampled_files_list, roi_list

subsampled_files: list[str] = []
roi: list[tuple[int, int]] = []

@@ -141,12 +163,15 @@ def _should_replace_path(path: Optional[str]) -> bool:


def _read_updated_at(
input_dir: Optional[Dir],
input_dir: Optional[Union[Dir, list[Dir]]],
storage_options: Optional[dict] = {},
session_options: Optional[dict] = {},
index_path: Optional[str] = None,
) -> str:
) -> Union[str, list[str]]:
"""Read last updated timestamp from index.json file."""
if isinstance(input_dir, list):
return [_read_updated_at(_dir, storage_options, session_options, index_path) for _dir in input_dir]

last_updation_timestamp = "0"
index_json_content = None
assert isinstance(input_dir, Dir)
@@ -215,21 +240,26 @@ def get_default_cache_dir() -> str:


def _try_create_cache_dir(
input_dir: Optional[str],
input_dir: Optional[Union[str, list[str]]],
cache_dir: Optional[str] = None,
storage_options: Optional[dict] = {},
session_options: Optional[dict] = {},
index_path: Optional[str] = None,
) -> Optional[str]:
"""Prepare and return the cache directory for a dataset."""
input_dir = input_dir if isinstance(input_dir, list) else [input_dir] if input_dir else []
resolved_input_dir = _resolve_dir(input_dir)
updated_at = _read_updated_at(resolved_input_dir, storage_options, session_options, index_path)

# Fallback to a hash of the input_dir if updated_at is "0"
if updated_at == "0" and input_dir is not None:
updated_at = generate_md5_hash(input_dir)
updated_at = [updated_at] if not isinstance(updated_at, list) else updated_at

for idx, _upd in enumerate(updated_at):
# Fallback to a hash of the input_dir if updated_at is "0"
if _upd == "0" and input_dir[idx] is not None:
updated_at[idx] = generate_md5_hash(resolved_input_dir[idx].path)

dir_url_hash = generate_md5_hash(resolved_input_dir.url or "")
_input_url = "_".join(_dir.url for _dir in resolved_input_dir if _dir.url is not None)
dir_url_hash = generate_md5_hash(_input_url)

# Determine the cache directory, preferring user-provided cache_dir if given
cache_dir = cache_dir if cache_dir is not None else get_default_cache_dir()
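The cache-directory key is now built from the joined URLs of every input directory plus a per-directory timestamp, falling back to a hash of the directory when no index.json timestamp is available. A self-contained illustration of that composition, using hashlib directly instead of the project's `generate_md5_hash` helper; the URLs and timestamps are made up, and how the two hashes combine into the final cache path sits outside the visible hunk:

import hashlib

def md5(text: str) -> str:
    return hashlib.md5(text.encode()).hexdigest()

# Hypothetical inputs: two remote directories and the timestamps read from their index.json files.
dir_urls = ["s3://bucket/dataset1", "s3://bucket/dataset2"]
timestamps = ["1714000000.0", "0"]  # "0" means no timestamp could be read for that directory

# Per-directory fallback: hash the directory itself when its timestamp is missing.
timestamps = [ts if ts != "0" else md5(url) for ts, url in zip(timestamps, dir_urls)]

# All URLs are joined before hashing, so the same combination of inputs maps to the same cache entry.
dir_url_hash = md5("_".join(dir_urls))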
46 changes: 46 additions & 0 deletions tests/streaming/test_dataset.py
@@ -100,6 +100,52 @@ def test_streaming_dataset(tmpdir, monkeypatch, compression):
assert len(dataloader) == 30


@pytest.mark.parametrize(
"compression",
[
pytest.param(None),
pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
],
)
@pytest.mark.timeout(30)
@pytest.mark.skip(reason="multi-directory streaming is not implemented yet")
def test_multi_streaming_dataset(tmpdir, monkeypatch, compression):
seed_everything(42)

with pytest.raises(FileNotFoundError, match="The provided dataset path"):
dataset = StreamingDataset(input_dir=[str(tmpdir.join("tmpfolder"))])

with pytest.raises(ValueError, match="The provided dataset"):
dataset = StreamingDataset(input_dir=[str(tmpdir)])

dir_name = ("dataset1", "dataset2")

for idx, _dir in enumerate(dir_name):
cache = Cache(str(tmpdir.join(_dir)), chunk_size=10, compression=compression)
for j in range(60):
cache[j] = idx * 60 + j  # distinct, consecutive values so dataset[i] == i holds across both directories
cache.done()
cache.merge()

dataset = StreamingDataset(input_dir=[str(tmpdir.join(_dir)) for _dir in dir_name])

assert len(dataset) == 120
for i in range(120):
assert dataset[i] == i

dataset_iter = iter(dataset)
assert len(dataset_iter) == 120
for i in range(120):
assert next(dataset_iter) == i

dataloader = StreamingDataLoader(dataset, num_workers=0, batch_size=1)
assert len(dataloader) == 120
dataloader = DataLoader(dataset, num_workers=2, batch_size=1)
assert len(dataloader) == 120
dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
assert len(dataloader) == 60


def _simple_optimize_fn(index):
return index
