diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 07c165ef3..75044e98c 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -14,7 +14,9 @@
 import asyncio
 import inspect
 import logging
+import multiprocessing as mp
 import os
+from collections import defaultdict
 from copy import deepcopy
 from importlib import reload
 from itertools import cycle
@@ -39,7 +41,8 @@
 from litdata.streaming.combined import CombinedStreamingDataset
 from litdata.streaming.dataset import StreamingDataset
 from litdata.streaming.parallel import ParallelStreamingDataset
-from litdata.streaming.sampler import CacheBatchSampler
+from litdata.streaming.reader import PrepareChunksThread
+from litdata.streaming.sampler import CacheBatchSampler, ChunkedIndex
 from litdata.utilities._pytree import tree_flatten
 from litdata.utilities.base import (
     __NUM_CYCLES_KEY__,
@@ -835,3 +838,88 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
             return _SingleProcessDataLoaderIter(self)
         self.check_worker_number_rationality()
         return _StreamingMultiProcessingDataLoaderIter(self)
+
+
+class LightningDataloader:
+    # get the order of chunk indices and item indices in which they will be iterated by the StreamingDataset
+    # start an async downloader for downloading and decompressing chunk files
+    # start multiple worker processes whose job will be to read the relevant bytes, deserialize and
+    # unflatten them, and store the result in a dict to be read later and cleared after reading
+    def __init__(self, ds: StreamingDataset, batch_size: int = 1, shuffle: bool = False, num_workers: int = 1):
+        self.ds = ds
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.num_workers = num_workers
+
+        # State consumed by `setup_thread_and_download_chunk`. NOTE: `_config`, `_item_loader`,
+        # `_distributed_env`, `_max_cache_size`, `_max_pre_download` and `_rank` are still expected
+        # to come from the dataset's cache reader and are not initialized here yet.
+        self._prepare_thread = None
+        self._chunks_queued_for_download = False
+        self._last_chunk_index = None
+
+        # Queues
+        self.download_queue = asyncio.Queue()  # chunk indices to download
+        self.read_queue = mp.Queue()  # (chunk_index, item_index)
+        self.result_queue = mp.Queue()  # (chunk_index, item_index, data)
+
+        # Shared dict to hold ready items
+        self.data_store = defaultdict(dict)
+
+        # Start CPU workers
+        self.workers = [
+            mp.Process(target=self._reader_worker, args=(self.read_queue, self.result_queue))
+            for _ in range(num_workers)
+        ]
+        for w in self.workers:
+            w.daemon = True
+            w.start()
+
+        # Start result collector
+        # NOTE: a plain dict is not shared across processes, so the collector will need to run in
+        # this process (e.g. as a thread) or `data_store` must become a managed/shared structure.
+        self.collector_proc = mp.Process(target=self._collector, args=(self.result_queue,))
+        self.collector_proc.daemon = True
+        self.collector_proc.start()
+
+    def setup_thread_and_download_chunk(self, index: ChunkedIndex) -> None:
+        if self._config and (self._config._remote_dir or self._config._compressor):
+            # Create and start the prepare chunks thread
+            if self._prepare_thread is None and self._config:
+                self._prepare_thread = PrepareChunksThread(
+                    self._config,
+                    self._item_loader,
+                    self._distributed_env,
+                    self._max_cache_size,
+                    self._max_pre_download,
+                    self._rank,
+                )
+                # Attach the force download queue
+                self._item_loader._force_download_queue = self._prepare_thread._force_download_queue  # type: ignore
+                self._prepare_thread.start()
+                if index.chunk_indexes:
+                    self._prepare_thread.download(index.chunk_indexes)
+                    self._chunks_queued_for_download = True
+
+            # Only request an individual chunk download if:
+            # 1. We haven't already queued all chunks for download
+            # 2. We're processing a new chunk (different from the last one)
+            if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index:
+                assert self._prepare_thread
+                self._prepare_thread.download([index.chunk_index])
+
+        if self._last_chunk_index is None:
+            self._last_chunk_index = index.chunk_index
+
+    async def downloader(self) -> None:
+        # TODO: consume `download_queue`, then download and decompress chunk files asynchronously
+        pass
+
+    def _reader_worker(self, read_queue: mp.Queue, result_queue: mp.Queue) -> None:
+        # TODO: read the relevant bytes for each (chunk_index, item_index), deserialize and
+        # unflatten them, then push the result onto `result_queue`
+        pass
+
+    def _collector(self, result_queue: mp.Queue) -> None:
+        # TODO: move finished items from `result_queue` into `data_store`
+        pass
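
Note: `downloader`, `_reader_worker`, and `_collector` are left as stubs above. Below is a minimal, self-contained sketch of the intended queue topology, with a hypothetical `read_item(chunk_index, item_index)` standing in for the real byte-reading/deserialization step; the sentinel-based shutdown is illustrative only, not the final design. The collector runs as a thread here because a plain `defaultdict` written from a child process would never be visible to the parent.

import multiprocessing as mp
import threading
from collections import defaultdict

_SENTINEL = "__done__"

def read_item(chunk_index: int, item_index: int) -> str:
    # Hypothetical stand-in for "read the relevant bytes, deserialize, unflatten".
    return f"item-{chunk_index}-{item_index}"

def reader_worker(read_queue: mp.Queue, result_queue: mp.Queue) -> None:
    # Pull (chunk_index, item_index) pairs, materialize the item, forward the result.
    while True:
        task = read_queue.get()
        if task == _SENTINEL:
            break
        chunk_index, item_index = task
        result_queue.put((chunk_index, item_index, read_item(chunk_index, item_index)))

def collector(result_queue: mp.Queue, data_store: dict) -> None:
    # Drain results into the per-chunk dict owned by the parent process.
    while True:
        result = result_queue.get()
        if result == _SENTINEL:
            break
        chunk_index, item_index, data = result
        data_store[chunk_index][item_index] = data

if __name__ == "__main__":
    read_queue = mp.Queue()
    result_queue = mp.Queue()
    data_store = defaultdict(dict)

    workers = [mp.Process(target=reader_worker, args=(read_queue, result_queue), daemon=True) for _ in range(2)]
    for w in workers:
        w.start()
    collector_thread = threading.Thread(target=collector, args=(result_queue, data_store), daemon=True)
    collector_thread.start()

    for task in [(0, 0), (0, 1), (1, 0)]:
        read_queue.put(task)
    for _ in workers:
        read_queue.put(_SENTINEL)  # one sentinel per worker
    for w in workers:
        w.join()
    result_queue.put(_SENTINEL)
    collector_thread.join()
    print(dict(data_store))  # all three items, grouped per chunk index

If the collector must remain a separate process instead, `data_store` would need to become a shared structure such as an `mp.Manager().dict()`, at the cost of proxy overhead on every access.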