74 changes: 73 additions & 1 deletion src/litdata/streaming/dataloader.py
@@ -14,7 +14,9 @@
import asyncio
import inspect
import logging
import multiprocessing as mp
import os
from collections import defaultdict
from copy import deepcopy
from importlib import reload
from itertools import cycle
@@ -39,7 +41,8 @@
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.parallel import ParallelStreamingDataset
from litdata.streaming.sampler import CacheBatchSampler
from litdata.streaming.reader import PrepareChunksThread
from litdata.streaming.sampler import CacheBatchSampler, ChunkedIndex
from litdata.utilities._pytree import tree_flatten
from litdata.utilities.base import (
__NUM_CYCLES_KEY__,
@@ -835,3 +838,72 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
            return _SingleProcessDataLoaderIter(self)
        self.check_worker_number_rationality()
        return _StreamingMultiProcessingDataLoaderIter(self)


class LightningDataloader:
    # Plan for this draft:
    # - get the order of chunk indices and item indices in which they will be iterated in the StreamingDataset
    # - start an async downloader for downloading and uncompressing chunk files
    # - start multiple worker processes whose job will be to read the relevant bytes, deserialize, unflatten,
    #   and store items in a dict to be read later and cleared after reading
    # (a minimal sketch of the reader-worker loop follows this class)
    def __init__(self, ds: StreamingDataset, batch_size: int = 1, shuffle: bool = False, num_workers: int = 1):
        self.ds = ds
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

# Queues
self.download_queue = asyncio.Queue() # chunk indices to download
self.read_queue = mp.Queue() # (chunk_index, item_index)
self.result_queue = mp.Queue() # (chunk_index, item_index, data)

        # Per-process store of ready items, keyed by chunk then item index
        # (note: a plain dict is not shared across processes; true sharing would need e.g. mp.Manager().dict())
        self.data_store = defaultdict(dict)

        # Start CPU workers
        self.workers = [
            mp.Process(target=self._reader_worker, args=(self.read_queue, self.result_queue))
            for _ in range(num_workers)
        ]
        for w in self.workers:
            w.daemon = True
            w.start()

        # Start result collector
        self.collector_proc = mp.Process(target=self._collector, args=(self.result_queue,))
        self.collector_proc.daemon = True
        self.collector_proc.start()

    def setup_thread_and_download_chunk(self, index: ChunkedIndex) -> None:
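        # NOTE: this method relies on reader-style state (_config, _prepare_thread, _item_loader,
        # _distributed_env, _max_cache_size, _max_pre_download, _rank, _chunks_queued_for_download,
        # _last_chunk_index) that this draft does not yet initialize in __init__.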
        if self._config and (self._config._remote_dir or self._config._compressor):
            # Create and start the prepare chunks thread
            if self._prepare_thread is None and self._config:
                self._prepare_thread = PrepareChunksThread(
                    self._config,
                    self._item_loader,
                    self._distributed_env,
                    self._max_cache_size,
                    self._max_pre_download,
                    self._rank,
                )
                # Attach the force download queue
                self._item_loader._force_download_queue = self._prepare_thread._force_download_queue  # type: ignore
                self._prepare_thread.start()
                if index.chunk_indexes:
                    self._prepare_thread.download(index.chunk_indexes)
                    self._chunks_queued_for_download = True

            # Only request an individual chunk download if:
            # 1. we haven't already queued all chunks for download
            # 2. we're processing a new chunk (different from the last one)
            if not self._chunks_queued_for_download and index.chunk_index != self._last_chunk_index:
                assert self._prepare_thread
                self._prepare_thread.download([index.chunk_index])

        # Track the last requested chunk so the same chunk isn't re-queued on the next call
        self._last_chunk_index = index.chunk_index

    async def downloader(self) -> None:
        pass  # TODO: consume chunk indices from self.download_queue and download/uncompress them

    def _reader_worker(self, read_queue: "mp.Queue", result_queue: "mp.Queue") -> None:
        pass  # TODO: read bytes for each (chunk_index, item_index), deserialize, unflatten

    def _collector(self, result_queue: "mp.Queue") -> None:
        pass  # TODO: drain result_queue into self.data_store
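

# Minimal sketch (an illustration, not part of this diff) of the reader-worker loop the
# comments on LightningDataloader describe: pull (chunk_index, item_index) tasks off the
# read queue, load and deserialize the item, and push the result to the result queue.
# `load_item` is a hypothetical stand-in for the real "read bytes, deserialize, unflatten" step.
def load_item(chunk_index: int, item_index: int) -> bytes:
    # Hypothetical placeholder for the actual byte-range read + deserialization.
    return b""


def _example_reader_loop(read_queue: "mp.Queue", result_queue: "mp.Queue") -> None:
    while True:
        task = read_queue.get()
        if task is None:  # sentinel: a None task tells this worker to shut down
            break
        chunk_index, item_index = task
        data = load_item(chunk_index, item_index)
        result_queue.put((chunk_index, item_index, data))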