From 55b8229f39812a94b0a1d8804ae5a5140c5644bb Mon Sep 17 00:00:00 2001 From: Deependu Jha Date: Fri, 15 Aug 2025 15:00:03 +0530 Subject: [PATCH 1/3] rust setup done --- .github/codespell_ignore_words.txt | 3 + .gitignore | 13 +++ .pre-commit-config.yaml | 1 + Cargo.lock | 171 +++++++++++++++++++++++++++++ Cargo.toml | 22 ++++ pyproject.toml | 8 ++ requirements.txt | 1 + src/lib.rs | 17 +++ src/litdata/__init__.py | 2 + src/litdata/_core.pyi | 39 +++++++ 10 files changed, 277 insertions(+) create mode 100644 .github/codespell_ignore_words.txt create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 src/lib.rs create mode 100644 src/litdata/_core.pyi diff --git a/.github/codespell_ignore_words.txt b/.github/codespell_ignore_words.txt new file mode 100644 index 000000000..f14cd5c5a --- /dev/null +++ b/.github/codespell_ignore_words.txt @@ -0,0 +1,3 @@ +# this file is used to ignore words in the codespell check (pre-commit) + +crate diff --git a/.gitignore b/.gitignore index edf1cf67c..116c9ccd6 100644 --- a/.gitignore +++ b/.gitignore @@ -120,3 +120,16 @@ status.json # use the below name for your optimize dataset directory for examples example_optimize_dataset + + +# --- rust .gitignore --- +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7f16c074..0385691c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,7 @@ repos: hooks: - id: codespell additional_dependencies: [tomli] + args: ["--ignore-words=.github/codespell_ignore_words.txt"] exclude: > (?x)^( .*\.ipynb diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000..aeee0ae5b --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,171 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "cfg-if" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "libc" +version = "0.2.175" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "litdata" +version = "0.2.52" +dependencies = [ + "pyo3", +] + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "proc-macro2" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bc3fcb250e53458e712715cf74285c1f889686520d79294a9ef3bd7aa1fc619" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..124c37bc2 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "litdata" +version = "0.2.52" +edition = "2021" +authors = ["Lightning AI et al. "] +license = "Apache-2.0" +description = "Data processing and streaming library for fast AI model training." +documentation = "https://github.com/Lightning-AI/litdata/" +homepage = "https://github.com/Lightning-AI/litdata/" +repository = "https://github.com/Lightning-AI/litdata/" +keywords = ["deep learning", "pytorch", "AI", "streaming", "cloud", "data processing"] +readme = "README.md" + +[lib] +name = "_core" +# "cdylib" is necessary to produce a shared library for Python to import from. +crate-type = ["cdylib"] + +[dependencies] +# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) +# "abi3-py39" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.9 +pyo3 = { version = "0.22.4", features = ["extension-module", "abi3-py39"] } diff --git a/pyproject.toml b/pyproject.toml index 14d9790d3..f00345f07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,19 @@ # limitations under the License. [build-system] +build-backend = "maturin" + requires = [ + "maturin>=1,<2", "setuptools", "wheel", ] +[tool.maturin] +module-name = "litdata._core" +python-packages = [ "litdata" ] +python-source = "src" + [tool.ruff] target-version = "py39" line-length = 120 diff --git a/requirements.txt b/requirements.txt index b2c754162..1529b3d9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ boto3 requests tifffile obstore +maturin diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 000000000..2b71bd22c --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,17 @@ +use pyo3::prelude::*; + +#[pyfunction] +fn hello_from_bin() -> String { + "RUST: Hello from LitData!".to_string() +} + +/// A Python module implemented in Rust. The name of this function (`_core`) must match +/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to +/// import the module. 
+#[pymodule]
+fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(hello_from_bin, m)?)?;
+    // m.add_class::<StreamingDataProvider>()?;
+    // m.add_class::<S3Storage>()?;
+    Ok(())
+}
diff --git a/src/litdata/__init__.py b/src/litdata/__init__.py
index e46acecda..5536d8c27 100644
--- a/src/litdata/__init__.py
+++ b/src/litdata/__init__.py
@@ -13,6 +13,7 @@
 import warnings
 
 from litdata.__about__ import *  # noqa: F403
+from litdata._core import hello_from_bin
 from litdata.constants import _LIGHTNING_SDK_AVAILABLE
 from litdata.processing.functions import map, merge_datasets, optimize, walk
 from litdata.raw.dataset import StreamingRawDataset
@@ -47,6 +48,7 @@
     "index_parquet_dataset",
     "index_hf_dataset",
     "breakpoint",
+    "hello_from_bin",
 ]
 
 if _LIGHTNING_SDK_AVAILABLE:
diff --git a/src/litdata/_core.pyi b/src/litdata/_core.pyi
new file mode 100644
index 000000000..dc8f3a87c
--- /dev/null
+++ b/src/litdata/_core.pyi
@@ -0,0 +1,39 @@
+def hello_from_bin() -> str: ...
+
+# StreamingDataProvider
+# -> on start, download x upcoming items in advance
+# -> get_next_k_item() => get the next k upcoming items
+#
+# ------ how it works ------
+# 1. ChunksConfig has a property `self.streaming_data_provider`, an instance of StreamingDataProvider.
+# 2. When dataset.py's __iter__() is called, it gets the chunk order; __next__() gets the sample item order.
+# 3. The chunk order and sample item order are recorded via `set_chunk` and `set_sample_index`.
+# 4. We fetch the chunk and sample order not only for the current epoch but also for the next one, to be better prepared.
+# 5. For the dataset's epoch 1, we call on_start() to download the offset array for all chunk indexes in parallel.
+# 6. Items returned by on_start(), and later by get_next_k_item(),
+#    are deserialized and then stored in `config.index_to_sample_data`.
+# 7. When an item read is requested, get_next_k_item() is called to fetch the next k items.
+# 8. For every subsequent epoch (2, 3, ...), we compute the chunk and sample order for the next epoch
+#    and call `set_chunk` and `set_sample_index` to update them.
+class StreamingDataProvider:
+    def __init__(
+        self,
+        epoch: int,
+        remote_dir: str,
+        chunks: list[dict[str, str]],
+        on_start_pre_item_download_count: int,
+        get_next_k_item_count: int,
+    ) -> None: ...
+    def on_start(self) -> list[tuple[int, int, int, bytes]]: ...
+    def get_next_k_item(self) -> list[tuple[int, int, int, bytes]]: ...
+    def set_epoch(self, epoch: int) -> None: ...
+    def set_chunk_and_sample_index(self, epoch: int, chunk_index: list[int], sample_index: list[list[int]]) -> None: ...
+    def set_chunk(
+        self, epoch: int, chunk_index: list[int], chunk_index_begin: list[tuple[int, int, int, int]]
+    ) -> None: ...
+    def set_sample_index(self, epoch: int, sample_index: list[list[int]]) -> None: ...
+
+# S3Storage
+class S3Storage:
+    def __init__(self, remote_dir: str) -> None: ...
+    def byte_range_download(self, remote_path: str, local_path: str, num_workers: int) -> None: ...
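
A quick sanity check for the extension wiring above — a minimal sketch, assuming the module has first been built into the current environment (e.g. via `maturin develop`); the expected string comes from `hello_from_bin` in src/lib.rs:

    # Build the extension in-place first, e.g.: maturin develop
    from litdata._core import hello_from_bin

    # `hello_from_bin` is the only symbol the Rust module exports at this point.
    assert hello_from_bin() == "RUST: Hello from LitData!"
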
From d3e1d6e3f071615ca9f5f269b9e51cd890545e23 Mon Sep 17 00:00:00 2001
From: Deependu Jha
Date: Mon, 25 Aug 2025 17:45:32 +0530
Subject: [PATCH 2/3] update

---
 src/lib.rs                              |   2 +
 src/litdata_core/mod.rs                 |   1 +
 src/litdata_core/serializers.rs         | 141 ++++++++++++++++++++++++
 src/litdata_core/serializers/numbers.rs |   0
 4 files changed, 144 insertions(+)
 create mode 100644 src/litdata_core/mod.rs
 create mode 100644 src/litdata_core/serializers.rs
 create mode 100644 src/litdata_core/serializers/numbers.rs

diff --git a/src/lib.rs b/src/lib.rs
index 2b71bd22c..065ef9bbd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,7 @@
 use pyo3::prelude::*;
 
+pub mod litdata_core;
+
 #[pyfunction]
 fn hello_from_bin() -> String {
     "RUST: Hello from LitData!".to_string()
diff --git a/src/litdata_core/mod.rs b/src/litdata_core/mod.rs
new file mode 100644
index 000000000..486063662
--- /dev/null
+++ b/src/litdata_core/mod.rs
@@ -0,0 +1 @@
+pub mod serializers;
diff --git a/src/litdata_core/serializers.rs b/src/litdata_core/serializers.rs
new file mode 100644
index 000000000..3dc8ffa01
--- /dev/null
+++ b/src/litdata_core/serializers.rs
@@ -0,0 +1,141 @@
+use pyo3::prelude::*;
+use pyo3::types::{PyAnyMethods, PyBool, PyFloat, PyLong};
+use std::fs;
+
+pub trait Serializer {
+    /// Serialize data into a byte vector and optional metadata string
+    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)>;
+
+    /// Deserialize from bytes back to Python object
+    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject>;
+
+    /// Check if the data can be serialized by this serializer
+    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool;
+
+    /// Optional setup hook (e.g., for metadata)
+    fn setup(&mut self, _metadata: Option<&PyAny>) -> PyResult<()> {
+        Ok(())
+    }
+}
+
+pub struct IntegerSerializer;
+
+impl IntegerSerializer {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl Serializer for IntegerSerializer {
+    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
+        let val: i64 = data.extract()?; // now uses extract_bound internally
+        Ok((val.to_le_bytes().to_vec(), None))
+    }
+
+    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
+        if data.len() != 8 {
+            return Err(pyo3::exceptions::PyValueError::new_err(
+                "Invalid byte length for i64",
+            ));
+        }
+        let val = i64::from_le_bytes(data.try_into().unwrap());
+        Ok(val.into_py(py))
+    }
+
+    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
+        data.is_instance_of::<PyLong>()
+    }
+}
+
+pub struct FloatSerializer;
+
+impl FloatSerializer {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl Serializer for FloatSerializer {
+    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
+        let val: f64 = data.extract()?; // Bound supports extract directly
+        Ok((val.to_le_bytes().to_vec(), None))
+    }
+
+    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
+        if data.len() != 8 {
+            return Err(pyo3::exceptions::PyValueError::new_err(
+                "Invalid byte length for f64",
+            ));
+        }
+        let val = f64::from_le_bytes(data.try_into().unwrap());
+        Ok(val.into_py(py))
+    }
+
+    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
+        data.is_instance_of::<PyFloat>()
+    }
+}
+
+/// String Serializer
+
+pub struct StringSerializer;
+
+impl StringSerializer {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl Serializer for StringSerializer {
+    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
+        // Own the string to avoid any lifetime ties to 'py
+        let s: String = data.extract()?;
+        Ok((s.into_bytes(), None))
+    }
+
+    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
+        let s = std::str::from_utf8(data)
+            .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid UTF-8 data"))?;
+        Ok(s.into_py(py))
+    }
+
+    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
+        // Extract owned String; if it's a path to a file, reject
+        if let Ok(s) = data.extract::<String>() {
+            fs::metadata(&s).is_err()
+        } else {
+            false
+        }
+    }
+}
+
+/// Boolean Serializer
+pub struct BooleanSerializer;
+
+impl BooleanSerializer {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl Serializer for BooleanSerializer {
+    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
+        let val: bool = data.extract()?;
+        // Single-byte representation, 1 for true, 0 for false
+        Ok((vec![val as u8], None))
+    }
+
+    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
+        if data.len() != 1 {
+            return Err(pyo3::exceptions::PyValueError::new_err(
+                "Invalid byte length for bool",
+            ));
+        }
+        let val = data[0] != 0;
+        Ok(val.into_py(py))
+    }
+
+    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
+        data.is_instance_of::<PyBool>()
+    }
+}
diff --git a/src/litdata_core/serializers/numbers.rs b/src/litdata_core/serializers/numbers.rs
new file mode 100644
index 000000000..e69de29bb

From 851472cc743ed1f9dc11369c961beed4a2794ecd Mon Sep 17 00:00:00 2001
From: Deependu Jha
Date: Wed, 1 Oct 2025 17:27:37 +0530
Subject: [PATCH 3/3] who knows

---
 Makefile                                      |   1 +
 src/lib.rs                                    |   4 +-
 src/litdata/_core.pyi                         |  20 +++
 src/litdata/streaming/dataloader.py           |  38 +++++
 src/litdata/streaming/dataset.py              |  15 ++
 .../{serializers/numbers.rs => downloader.rs} |   0
 src/litdata_core/mod.rs                       |  51 ++++++-
 src/litdata_core/serializers.rs               | 141 ------------------
 8 files changed, 127 insertions(+), 143 deletions(-)
 rename src/litdata_core/{serializers/numbers.rs => downloader.rs} (100%)
 delete mode 100644 src/litdata_core/serializers.rs

diff --git a/Makefile b/Makefile
index 6590db912..a542a8434 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,7 @@ export SPHINX_MOCK_REQUIREMENTS=0
 setup: install-dependencies install-pre-commit
 	@echo "==================== Setup Finished ===================="
 	@echo "All set! Ready to go!"
+	uv pip install -U pyopenssl
 
 test: clean
 	uv pip install -q -r requirements.txt
diff --git a/src/lib.rs b/src/lib.rs
index 065ef9bbd..b8591112a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,6 @@
-use pyo3::prelude::*;
+// pip install --upgrade pyOpenSSL cryptography
+use pyo3::prelude::*;
 pub mod litdata_core;
@@ -13,6 +14,7 @@ fn hello_from_bin() -> String {
 #[pymodule]
 fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(hello_from_bin, m)?)?;
+    m.add_class::<litdata_core::LitDataLoaderCore>()?;
     // m.add_class::<StreamingDataProvider>()?;
     // m.add_class::<S3Storage>()?;
     Ok(())
diff --git a/src/litdata/_core.pyi b/src/litdata/_core.pyi
index dc8f3a87c..6e238dbf7 100644
--- a/src/litdata/_core.pyi
+++ b/src/litdata/_core.pyi
@@ -37,3 +37,23 @@ class StreamingDataProvider:
     def __init__(self, remote_dir: str) -> None: ...
     def byte_range_download(self, remote_path: str, local_path: str, num_workers: int) -> None: ...
+
+class LitDataLoaderCore:
+    index: int
+    worker_chunks: list[int]
+    worker_intervals: list[tuple[int, int]]
+    batch_size: int
+    pre_download: int
+    prefetch_workers: int
+    prefetch_factor: int
+
+    def __init__(
+        self,
+        worker_chunks: list[int],
+        worker_intervals: list[tuple[int, int]],
+        batch_size: int,
+        pre_download: int,
+        prefetch_workers: int,
+        prefetch_factor: int,
+    ) -> None: ...
+    def __iter__(self) -> "LitDataLoaderCore": ...
diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 07c165ef3..0b1d519ec 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -33,6 +33,7 @@
 )
 from torch.utils.data.sampler import BatchSampler, Sampler
 
+from litdata._core import LitDataLoaderCore
 from litdata.constants import _DEFAULT_CHUNK_BYTES, _VIZ_TRACKER_AVAILABLE
 from litdata.debugger import _get_log_msg
 from litdata.streaming import Cache
@@ -835,3 +836,40 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
             return _SingleProcessDataLoaderIter(self)
         self.check_worker_number_rationality()
         return _StreamingMultiProcessingDataLoaderIter(self)
+
+
+class LitDataLoader:
+    def __init__(
+        self,
+        dataset: StreamingDataset,
+        batch_size: int = 1,
+        shuffle: bool = False,
+        num_workers: int = 0,
+        seed: int = 17,
+    ) -> None:
+        assert isinstance(dataset, StreamingDataset)
+        assert batch_size > 0, "batch_size should be a positive integer"
+        assert num_workers >= 0, "num_workers should be a non-negative integer"
+
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.dataset.shuffle = shuffle
+
+        self.num_workers = num_workers
+        self.seed = seed
+        worker_chunks, worker_intervals = self.dataset.get_worker_chunks_and_intervals()
+        print(f"Worker chunks: {worker_chunks}")
+        print(f"Worker intervals: {worker_intervals}")
+        self.lit_data_loader = LitDataLoaderCore(
+            worker_chunks=worker_chunks,
+            worker_intervals=worker_intervals,
+            batch_size=batch_size,
+            pre_download=1,
+            prefetch_workers=1,
+            prefetch_factor=1,
+        )
+
+    def __iter__(self) -> Any:
+        for batch in self.lit_data_loader:
+            yield {"data": batch}
diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py
index 7b55172d1..a04c7ee3e 100644
--- a/src/litdata/streaming/dataset.py
+++ b/src/litdata/streaming/dataset.py
@@ -294,6 +294,21 @@ def get_len(self, num_workers: int, batch_size: int) -> int:
             self.shuffler = self._create_shuffler(cache)
         return self.shuffler.get_len(self.distributed_env, self.num_workers, self.batch_size, self.current_epoch)
 
+    def get_worker_chunks_and_intervals(self) -> tuple[list[int], list[list[int]]]:
+        if self.worker_env is None:
+            self.worker_env = _WorkerEnv.detect()
+        if self.cache is None:
+            self.cache = self._create_cache(worker_env=self.worker_env)
+        if self.shuffler is None:
+            self.shuffler = self._create_shuffler(self.cache)
+
+        workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers(
+            self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch
+        )
+
+        worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+        return workers_chunks[worker_rank], workers_intervals[worker_rank]
+
     def __iter__(self) -> "StreamingDataset":
         # When the StreamingDataset is used within map or optimize, let's refetch the distributed env.
        logger.debug(_get_log_msg({"name": "iterating_dataset", "ph": "B"}))
diff --git a/src/litdata_core/serializers/numbers.rs b/src/litdata_core/downloader.rs
similarity index 100%
rename from src/litdata_core/serializers/numbers.rs
rename to src/litdata_core/downloader.rs
diff --git a/src/litdata_core/mod.rs b/src/litdata_core/mod.rs
index 486063662..9c54a0f29 100644
--- a/src/litdata_core/mod.rs
+++ b/src/litdata_core/mod.rs
@@ -1 +1,50 @@
-pub mod serializers;
+use pyo3::prelude::*;
+pub mod downloader;
+
+#[pyclass]
+pub struct LitDataLoaderCore {
+    index: usize,
+    worker_chunks: Vec<usize>,
+    worker_intervals: Vec<Vec<usize>>,
+    batch_size: u32,       // number of chunks to be processed in a batch
+    pre_download: u32,     // number of chunks to pre-download ahead of the current chunk
+    prefetch_workers: u32, // number of workers used to download & decompress chunk files
+    prefetch_factor: u32,  // number of batches to prefetch ahead of the current batch
+}
+
+#[pymethods]
+impl LitDataLoaderCore {
+    #[new]
+    fn new(
+        worker_chunks: Vec<usize>,
+        worker_intervals: Vec<Vec<usize>>,
+        batch_size: u32,
+        pre_download: u32,
+        prefetch_workers: u32,
+        prefetch_factor: u32,
+    ) -> Self {
+        LitDataLoaderCore {
+            index: 0,
+            worker_chunks,
+            worker_intervals,
+            batch_size,
+            pre_download,
+            prefetch_workers,
+            prefetch_factor,
+        }
+    }
+
+    fn __iter__(slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
+        slf
+    }
+
+    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<usize> {
+        if slf.index < slf.worker_chunks.len() {
+            let item = slf.worker_chunks[slf.index];
+            slf.index += 1;
+            Some(item)
+        } else {
+            None // signals StopIteration in Python
+        }
+    }
+}
diff --git a/src/litdata_core/serializers.rs b/src/litdata_core/serializers.rs
deleted file mode 100644
index 3dc8ffa01..000000000
--- a/src/litdata_core/serializers.rs
+++ /dev/null
@@ -1,141 +0,0 @@
-use pyo3::prelude::*;
-use pyo3::types::{PyAnyMethods, PyBool, PyFloat, PyLong};
-use std::fs;
-
-pub trait Serializer {
-    /// Serialize data into a byte vector and optional metadata string
-    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)>;
-
-    /// Deserialize from bytes back to Python object
-    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject>;
-
-    /// Check if the data can be serialized by this serializer
-    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool;
-
-    /// Optional setup hook (e.g., for metadata)
-    fn setup(&mut self, _metadata: Option<&PyAny>) -> PyResult<()> {
-        Ok(())
-    }
-}
-
-pub struct IntegerSerializer;
-
-impl IntegerSerializer {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-impl Serializer for IntegerSerializer {
-    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
-        let val: i64 = data.extract()?; // now uses extract_bound internally
-        Ok((val.to_le_bytes().to_vec(), None))
-    }
-
-    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
-        if data.len() != 8 {
-            return Err(pyo3::exceptions::PyValueError::new_err(
-                "Invalid byte length for i64",
-            ));
-        }
-        let val = i64::from_le_bytes(data.try_into().unwrap());
-        Ok(val.into_py(py))
-    }
-
-    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
-        data.is_instance_of::<PyLong>()
-    }
-}
-
-pub struct FloatSerializer;
-
-impl FloatSerializer {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-impl Serializer for FloatSerializer {
-    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
-        let val: f64 = data.extract()?; // Bound supports extract directly
-        Ok((val.to_le_bytes().to_vec(), None))
-    }
-
-    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
-        if data.len() != 8 {
-            return Err(pyo3::exceptions::PyValueError::new_err(
-                "Invalid byte length for f64",
-            ));
-        }
-        let val = f64::from_le_bytes(data.try_into().unwrap());
-        Ok(val.into_py(py))
-    }
-
-    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
-        data.is_instance_of::<PyFloat>()
-    }
-}
-
-/// String Serializer
-
-pub struct StringSerializer;
-
-impl StringSerializer {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-impl Serializer for StringSerializer {
-    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
-        // Own the string to avoid any lifetime ties to 'py
-        let s: String = data.extract()?;
-        Ok((s.into_bytes(), None))
-    }
-
-    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
-        let s = std::str::from_utf8(data)
-            .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid UTF-8 data"))?;
-        Ok(s.into_py(py))
-    }
-
-    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
-        // Extract owned String; if it's a path to a file, reject
-        if let Ok(s) = data.extract::<String>() {
-            fs::metadata(&s).is_err()
-        } else {
-            false
-        }
-    }
-}
-
-/// Boolean Serializer
-pub struct BooleanSerializer;
-
-impl BooleanSerializer {
-    pub fn new() -> Self {
-        Self
-    }
-}
-
-impl Serializer for BooleanSerializer {
-    fn serialize<'py>(&self, data: Bound<'py, PyAny>) -> PyResult<(Vec<u8>, Option<String>)> {
-        let val: bool = data.extract()?;
-        // Single-byte representation, 1 for true, 0 for false
-        Ok((vec![val as u8], None))
-    }
-
-    fn deserialize(&self, data: &[u8], py: Python) -> PyResult<PyObject> {
-        if data.len() != 1 {
-            return Err(pyo3::exceptions::PyValueError::new_err(
-                "Invalid byte length for bool",
-            ));
-        }
-        let val = data[0] != 0;
-        Ok(val.into_py(py))
-    }
-
-    fn can_serialize<'py>(&self, data: Bound<'py, PyAny>) -> bool {
-        data.is_instance_of::<PyBool>()
-    }
-}
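
For reference, the iteration contract `LitDataLoaderCore` exposes above can be modelled in pure Python — a sketch of the equivalent behaviour only, not the extension itself (`_LoaderModel` is an illustrative name, not part of the patch):

    # Pure-Python model of the Rust LitDataLoaderCore iterator:
    # __next__ walks worker_chunks one index at a time, then stops.
    class _LoaderModel:
        def __init__(self, worker_chunks: list[int]) -> None:
            self.worker_chunks = worker_chunks
            self.index = 0

        def __iter__(self) -> "_LoaderModel":
            return self  # mirrors the Rust __iter__, which returns slf

        def __next__(self) -> int:
            if self.index >= len(self.worker_chunks):
                # Returning None from the Rust __next__ raises StopIteration in Python.
                raise StopIteration
            item = self.worker_chunks[self.index]
            self.index += 1
            return item

    # LitDataLoader.__iter__ then wraps each yielded chunk index as {"data": batch}.
    for batch in _LoaderModel([3, 0, 2]):
        print({"data": batch})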