diff --git a/.github/codespell_ignore_words.txt b/.github/codespell_ignore_words.txt new file mode 100644 index 000000000..f14cd5c5a --- /dev/null +++ b/.github/codespell_ignore_words.txt @@ -0,0 +1,3 @@ +# this file is used to ignore words in the codespell check (pre-commit) + +crate diff --git a/.gitignore b/.gitignore index edf1cf67c..116c9ccd6 100644 --- a/.gitignore +++ b/.gitignore @@ -120,3 +120,16 @@ status.json # use the below name for your optimize dataset directory for examples example_optimize_dataset + + +# --- rust .gitignore --- +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7f16c074..0385691c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,7 @@ repos: hooks: - id: codespell additional_dependencies: [tomli] + args: ["--ignore-words=.github/codespell_ignore_words.txt"] exclude: > (?x)^( .*\.ipynb diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000..aeee0ae5b --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,171 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "cfg-if" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "libc" +version = "0.2.175" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "litdata" +version = "0.2.52" +dependencies = [ + "pyo3", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "proc-macro2" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = 
"0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bc3fcb250e53458e712715cf74285c1f889686520d79294a9ef3bd7aa1fc619" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..124c37bc2 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "litdata" +version = "0.2.52" +edition = "2021" +authors = ["Lightning AI et al. "] +license = "Apache-2.0" +description = "Data processing and streaming library for fast AI model training." +documentation = "https://github.com/Lightning-AI/litdata/" +homepage = "https://github.com/Lightning-AI/litdata/" +repository = "https://github.com/Lightning-AI/litdata/" +keywords = ["deep learning", "pytorch", "AI", "streaming", "cloud", "data processing"] +readme = "README.md" + +[lib] +name = "_core" +# "cdylib" is necessary to produce a shared library for Python to import from. +crate-type = ["cdylib"] + +[dependencies] +# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) +# "abi3-py39" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.9 +pyo3 = { version = "0.22.4", features = ["extension-module", "abi3-py39"] } diff --git a/Makefile b/Makefile index 6590db912..a542a8434 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ export SPHINX_MOCK_REQUIREMENTS=0 setup: install-dependencies install-pre-commit @echo "==================== Setup Finished ====================" @echo "All set! Ready to go!" 
+ uv pip install -U pyopenssl test: clean uv pip install -q -r requirements.txt diff --git a/pyproject.toml b/pyproject.toml index 14d9790d3..f00345f07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,19 @@ # limitations under the License. [build-system] +build-backend = "maturin" + requires = [ + "maturin>=1,<2", "setuptools", "wheel", ] +[tool.maturin] +module-name = "litdata._core" +python-packages = [ "litdata" ] +python-source = "src" + [tool.ruff] target-version = "py39" line-length = 120 diff --git a/requirements.txt b/requirements.txt index b2c754162..1529b3d9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ boto3 requests tifffile obstore +maturin diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 000000000..b8591112a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,21 @@ +// pip install --upgrade pyOpenSSL cryptography + +use pyo3::prelude::*; +pub mod litdata_core; + +#[pyfunction] +fn hello_from_bin() -> String { + "RUST: Hello from LitData!".to_string() +} + +/// A Python module implemented in Rust. The name of this function (`_core`) must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. 
+#[pymodule] +fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(hello_from_bin, m)?)?; + m.add_class::<litdata_core::LitDataLoaderCore>()?; + // m.add_class::<litdata_core::StreamingDataProvider>()?; + // m.add_class::<litdata_core::S3Storage>()?; + Ok(()) +} diff --git a/src/litdata/__init__.py b/src/litdata/__init__.py index e46acecda..5536d8c27 100644 --- a/src/litdata/__init__.py +++ b/src/litdata/__init__.py @@ -13,6 +13,7 @@ import warnings from litdata.__about__ import * # noqa: F403 +from litdata._core import hello_from_bin from litdata.constants import _LIGHTNING_SDK_AVAILABLE from litdata.processing.functions import map, merge_datasets, optimize, walk from litdata.raw.dataset import StreamingRawDataset @@ -47,6 +48,7 @@ "index_parquet_dataset", "index_hf_dataset", "breakpoint", + "hello_from_bin", ] if _LIGHTNING_SDK_AVAILABLE: diff --git a/src/litdata/_core.pyi b/src/litdata/_core.pyi new file mode 100644 index 000000000..6e238dbf7 --- /dev/null +++ b/src/litdata/_core.pyi @@ -0,0 +1,59 @@ +def hello_from_bin() -> str: ... + +# StreamingDataProvider +# -> on start, download x upcoming items in advance +# -> get_next_k_item() => get next k upcoming items +# +# ------ how it works ------ +# 1. ChunksConfig has a property `self.streaming_data_provider` which is an instance of StreamingDataProvider +# 2. When dataset.py __iter__() is called, it gets the chunk order and __next__() will get the sample item order. +# 3. The chunk order and sample item order is stored in `set_chunk` and `set_sample_index`. +# 4. But, we will not only get chunk and sample order for current epoch, but also for next epoch to be better prepared. +# 5. For dataset's epoch 1, we will call on_start() to download offset array for all chunk indexes in parallel. +# 6. Downloaded items returned by on_start() and in future by get_next_k_item() +# are deserialized and then stored in `config.index_to_sample_data`. +# 7. When an item read is requested, get_next_k_item() will be called to get the next k items. +# 8.
For every subsequent epoch (2, 3, ...), we will get the chunk and sample order for the next epoch +# and then call `set_chunk` and `set_sample_index` to update the chunk and sample order for next epoch. +class StreamingDataProvider: + def __init__( + self, + epoch: int, + remote_dir: str, + chunks: list[dict[str, str]], + on_start_pre_item_download_count: int, + get_next_k_item_count: int, + ) -> None: ... + def on_start(self) -> list[tuple[int, int, int, bytes]]: ... + def get_next_k_item(self) -> list[tuple[int, int, int, bytes]]: ... + def set_epoch(self, epoch: int) -> None: ... + def set_chunk_and_sample_index(self, epoch: int, chunk_index: list[int], sample_index: list[list[int]]) -> None: ... + def set_chunk( + self, epoch: int, chunk_index: list[int], chunk_index_begin: list[tuple[int, int, int, int]] + ) -> None: ... + def set_sample_index(self, epoch: int, sample_index: list[list[int]]) -> None: ... + +# S3Storage +class S3Storage: + def __init__(self, remote_dir: str) -> None: ... + def byte_range_download(self, remote_path: str, local_path: str, num_workers: int) -> None: ... + +class LitDataLoaderCore: + index: int + worker_chunks: list[int] + worker_intervals: list[tuple[int, int]] + batch_size: int + pre_download: int + prefetch_workers: int + prefetch_factor: int + + def __init__( + self, + worker_chunks: list[int], + worker_intervals: list[tuple[int, int]], + batch_size: int, + pre_download: int, + prefetch_workers: int, + prefetch_factor: int, + ) -> None: ... + def __iter__(self) -> "LitDataLoaderCore": ...
diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 07c165ef3..0b1d519ec 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -33,6 +33,7 @@ ) from torch.utils.data.sampler import BatchSampler, Sampler +from litdata._core import LitDataLoaderCore from litdata.constants import _DEFAULT_CHUNK_BYTES, _VIZ_TRACKER_AVAILABLE from litdata.debugger import _get_log_msg from litdata.streaming import Cache @@ -835,3 +836,40 @@ def _get_iterator(self) -> "_BaseDataLoaderIter": return _SingleProcessDataLoaderIter(self) self.check_worker_number_rationality() return _StreamingMultiProcessingDataLoaderIter(self) + + +class LitDataLoader: + def __init__( + self, + dataset: StreamingDataset, + batch_size: int = 1, + shuffle: bool = False, + num_workers: int = 0, + seed: int = 17, + ) -> None: + assert isinstance(dataset, StreamingDataset) + assert batch_size > 0, "batch_size should be a positive integer" + assert num_workers >= 0, "num_workers should be a non-negative integer" + + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + self.dataset.shuffle = shuffle + + self.num_workers = num_workers + self.seed = seed + worker_chunks, worker_intervals = self.dataset.get_worker_chunks_and_intervals() + print(f"Worker chunks: {worker_chunks}") + print(f"Worker intervals: {worker_intervals}") + self.lit_data_loader = LitDataLoaderCore( + worker_chunks=worker_chunks, + worker_intervals=worker_intervals, + batch_size=batch_size, + pre_download=1, + prefetch_workers=1, + prefetch_factor=1, + ) + + def __iter__(self) -> Any: + for batch in self.lit_data_loader: + yield {"data": batch} diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 7b55172d1..a04c7ee3e 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -294,6 +294,21 @@ def get_len(self, num_workers: int, batch_size: int) -> int: self.shuffler = 
self._create_shuffler(cache) return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch) + def get_worker_chunks_and_intervals(self) -> tuple[list[int], list[list[int]]]: + if self.worker_env is None: + self.worker_env = _WorkerEnv.detect() + if self.cache is None: + self.cache = self._create_cache(worker_env=self.worker_env) + if self.shuffler is None: + self.shuffler = self._create_shuffler(self.cache) + + workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers( + self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch + ) + + worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank + return workers_chunks[worker_rank], workers_intervals[worker_rank] + def __iter__(self) -> "StreamingDataset": # When the StreamingDataset is used within map or optimize, let's refetch the distributed env. logger.debug(_get_log_msg({"name": "iterating_dataset", "ph": "B"})) diff --git a/src/litdata_core/downloader.rs b/src/litdata_core/downloader.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/litdata_core/mod.rs b/src/litdata_core/mod.rs new file mode 100644 index 000000000..9c54a0f29 --- /dev/null +++ b/src/litdata_core/mod.rs @@ -0,0 +1,50 @@ +use pyo3::prelude::*; +pub mod downloader; + +#[pyclass] +pub struct LitDataLoaderCore { + index: usize, + worker_chunks: Vec<usize>, + worker_intervals: Vec<Vec<usize>>, + batch_size: u32, // number of chunks to be processed in a batch + pre_download: u32, // number of chunks to pre-download ahead of current chunk + prefetch_workers: u32, // number of workers to be used for download & decompressing chunk files + prefetch_factor: u32, // number of batches to prefetch ahead of current batch +} + +#[pymethods] +impl LitDataLoaderCore { + #[new] + fn new( + worker_chunks: Vec<usize>, + worker_intervals: Vec<Vec<usize>>, + batch_size: u32, + pre_download: u32, + prefetch_workers: u32, + prefetch_factor: u32, + )
-> Self { + LitDataLoaderCore { + index: 0, + worker_chunks, + worker_intervals, + batch_size, + pre_download, + prefetch_workers, + prefetch_factor, + } + } + + fn __iter__(slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<usize> { + if slf.index < slf.worker_chunks.len() { + let item = slf.worker_chunks[slf.index]; + slf.index += 1; + Some(item) + } else { + None // signals StopIteration in Python + } + } +}