Skip to content

Commit

Permalink
dev(narugo): add max count
Browse files Browse the repository at this point in the history
narugo1992 committed Sep 23, 2024
1 parent 2442371 commit 86bd537
Showing 2 changed files with 45 additions and 9 deletions.
28 changes: 22 additions & 6 deletions cheesechaser/datapool/base.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
import logging
import os
import shutil
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
@@ -135,12 +136,13 @@ def mock_resource(self, resource_id, resource_info, silent: bool = False) -> Con

def batch_download_to_directory(self, resource_ids, dst_dir: str, max_workers: int = 12,
save_metainfo: bool = True, metainfo_fmt: str = '{resource_id}_metainfo.json',
silent: bool = False):
max_downloads: Optional[int] = None, silent: bool = False):
"""
Download multiple resources to a directory.
This method downloads a batch of resources to a specified directory, optionally saving metadata for each resource.
It uses a thread pool to parallelize downloads for improved performance.
This method downloads a batch of resources to a specified directory,
optionally saving metadata for each resource. It uses a thread pool to parallelize
downloads for improved performance.
:param resource_ids: List of resource IDs or tuples of (resource_id, resource_info) to download.
:type resource_ids: Iterable[Union[str, Tuple[str, Any]]]
@@ -152,6 +154,8 @@ def batch_download_to_directory(self, resource_ids, dst_dir: str, max_workers: i
:type save_metainfo: bool
:param metainfo_fmt: Format string for metadata filenames.
:type metainfo_fmt: str
:param max_downloads: Max download number of this task, unlimited when not given.
:type max_downloads: Optional[int]
:param silent: If True, suppresses progress bar of each standalone files during the mocking process.
:type silent: bool
@@ -162,10 +166,18 @@ def batch_download_to_directory(self, resource_ids, dst_dir: str, max_workers: i
>>> data_pool.batch_download_to_directory(['resource1', 'resource2'], '/path/to/destination')
"""
pg_res = tqdm(resource_ids, desc='Batch Downloading')
pg_downloaded = tqdm(desc='Files Downloaded')
pg_file_download = tqdm(desc='Files Downloaded')
pg_download = tqdm(desc='Download Count', total=max_downloads)
os.makedirs(dst_dir, exist_ok=True)

is_completed = threading.Event()
downloaded_count = 0

def _func(resource_id, resource_info):
nonlocal downloaded_count
if is_completed.is_set():
return

try:
with self.mock_resource(resource_id, resource_info, silent=silent) as (td, resource_info):
copied = False
@@ -180,11 +192,15 @@ def _func(resource_id, resource_info):
meta_file = os.path.join(td, metainfo_fmt.format(resource_id=resource_id))
with open(meta_file, 'w') as f:
json.dump(resource_info, f, indent=4, sort_keys=True, ensure_ascii=False)

pg_downloaded.update()
pg_file_download.update()
copied = True

if not copied:
logging.warning(f'No files found for resource {resource_id!r}.')
downloaded_count += 1
if max_downloads is not None and downloaded_count >= max_downloads:
is_completed.set()
pg_download.update()
except ResourceNotFoundError:
logging.warning(f'Resource {resource_id!r} not found, skipped.')
except Exception as err:
26 changes: 23 additions & 3 deletions cheesechaser/pipe/base.py
Original file line number Diff line number Diff line change
@@ -80,13 +80,18 @@ class PipeSession:
:type is_stopped: Event
:param is_finished: An event indicating whether the session has finished.
:type is_finished: Event
:param max_count: Max item count for iterating from the data source. Unlimited when not given.
:type max_count: Optional[int]
"""

def __init__(self, queue: Queue, is_start: Event, is_stopped: Event, is_finished: Event):
def __init__(self, queue: Queue, is_start: Event, is_stopped: Event, is_finished: Event,
max_count: Optional[int] = None):
self.queue = queue
self.is_start = is_start
self.is_stopped = is_stopped
self.is_finished = is_finished
self.max_count: Optional[int] = max_count
self._current_count: int = 0

def next(self, block: bool = True, timeout: Optional[float] = None) -> PipeItem:
"""
@@ -104,18 +109,27 @@ def next(self, block: bool = True, timeout: Optional[float] = None) -> PipeItem:
self.is_start.set()
return self.queue.get(block=block, timeout=timeout)

def _count_update(self, n: int = 1):
self._current_count += n
if self.max_count is not None and self._current_count >= self.max_count:
self.is_stopped.set()

def __iter__(self) -> Iterator[PipeItem]:
"""
Iterate over the items in the pipeline.
:return: An iterator of PipeItems.
:rtype: Iterator[PipeItem]
"""
pg = tqdm(desc='Piped Items', total=self.max_count)
self._count_update(0)
while not (self.is_stopped.is_set() and self.queue.empty()):
try:
data = self.next(block=True, timeout=1.0)
if isinstance(data, PipeItem):
pg.update()
yield data
self._count_update()
except Empty:
pass

@@ -177,13 +191,18 @@ def retrieve(self, resource_id, resource_metainfo, silent: bool = False):
"""
raise NotImplementedError # pragma: no cover

def batch_retrieve(self, resource_ids, max_workers: int = 12, silent: bool = False) -> PipeSession:
def batch_retrieve(self, resource_ids, max_workers: int = 12, max_count: Optional[int] = None,
silent: bool = False) -> PipeSession:
"""
Retrieve multiple resources in parallel using a thread pool.
:param resource_ids: An iterable of resource IDs or (ID, metainfo) tuples to retrieve.
:param max_workers: The maximum number of worker threads to use.
:type max_workers: int
:param max_count: Max item count for iterating from the data source. Unlimited when not given.
:type max_count: Optional[int]
:param silent: If True, suppresses progress bar of each standalone files during the mocking process.
:type silent: bool
:return: A PipeSession object for iterating over the retrieved items.
:rtype: PipeSession
"""
@@ -272,5 +291,6 @@ def _productor():
queue=queue,
is_start=is_started,
is_stopped=is_stopped,
is_finished=is_finished
is_finished=is_finished,
max_count=max_count,
)

0 comments on commit 86bd537

Please sign in to comment.