-
-
Notifications
You must be signed in to change notification settings - Fork 82
Implement ReiterableLazyIterable and use it in to_multi_transformable_collection #537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,8 @@ | |
import multiprocessing as mp | ||
import random | ||
import numpy as np | ||
from collections.abc import Iterable | ||
from typing import Callable | ||
from collections.abc import Iterable, Iterator | ||
from typing import Callable, List | ||
|
||
import abc | ||
import pipeline_dp.combiners as dp_combiners | ||
|
@@ -474,11 +474,39 @@ def to_list(self, col, stage_name: str): | |
raise NotImplementedError("to_list is not implement in SparkBackend.") | ||
|
||
|
||
class ReiterableLazyIterable(Iterable):
    """A lazy iterable that can be iterated multiple times.

    Elements are pulled from the wrapped iterable only on demand and are
    stored as they are produced. Every iteration first replays the stored
    elements and then, if the wrapped iterable is not exhausted yet,
    continues pulling (and storing) new elements from it.

    Because all iterations share one underlying iterator and one cache, the
    wrapped iterable is consumed at most once. This keeps the result correct
    even when the wrapped iterable is a one-shot generator and an earlier
    iteration was abandoned part-way, or when several iterations are
    interleaved (e.g. ``zip(col, col)``).
    """

    def __init__(self, iterable: Iterable):
        """Initializes the ReiterableLazyIterable.

        Args:
            iterable: Iterable to make reiterable. It is not touched until
              the first element is requested.
        """
        self._iterable = iterable
        # Iterator over self._iterable; created on the first element request
        # so that construction stays free of side effects.
        self._source_iterator = None
        # Per-instance list (assigned in __init__, so not shared between
        # instances) holding every element produced so far, in order.
        self._cache: List = []
        # True once the wrapped iterable has been fully consumed.
        self._first_run_complete = False

    def __iter__(self) -> Iterator:
        """Yields all elements: cached ones first, then newly pulled ones."""
        if self._source_iterator is None:
            self._source_iterator = iter(self._iterable)
        position = 0
        while True:
            if position < len(self._cache):
                # Replay an element that some iteration already produced.
                yield self._cache[position]
                position += 1
            elif self._first_run_complete:
                return
            else:
                # Cache is exhausted for this iteration: pull one more
                # element from the shared source and store it; the next loop
                # turn yields it from the cache.
                try:
                    item = next(self._source_iterator)
                except StopIteration:
                    self._first_run_complete = True
                    return
                self._cache.append(item)
|
||
|
||
class LocalBackend(PipelineBackend): | ||
"""Local Pipeline adapter.""" | ||
|
||
def to_multi_transformable_collection(self, col): | ||
return list(col) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not use list? For laziness? let's add it in the comment of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, for laziness. Comment added |
||
return ReiterableLazyIterable(col) | ||
|
||
def map(self, col, fn, stage_name: typing.Optional[str] = None): | ||
return map(fn, col) | ||
|
@@ -520,6 +548,8 @@ def filter_by_key( | |
keys_to_keep, | ||
stage_name: typing.Optional[str] = None, | ||
): | ||
if not isinstance(keys_to_keep, set): | ||
keys_to_keep = set(keys_to_keep) | ||
return (kv for kv in col if kv[0] in keys_to_keep) | ||
|
||
def keys(self, col, stage_name: typing.Optional[str] = None): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -146,10 +146,6 @@ def filter_by_key_with_sharding(backend: pipeline_backend.PipelineBackend, col, | |
lambda p: tuple((p, i) for i in range(sharding_factor)), | ||
f"Shard partitions into {sharding_factor} keys", | ||
) | ||
# to_multi_transformable_collection is no-op for not LocalMode. For | ||
# local mode it is transform iterable to list, which is neded because | ||
# filter_by_key requires list. | ||
keys_to_keep = backend.to_multi_transformable_collection(keys_to_keep) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't we need it anymore? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we moved this logic inside LocalBackend.filter_by_key |
||
|
||
col_filtered = backend.filter_by_key(col, keys_to_keep, stage_name) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -427,6 +427,11 @@ def setUpClass(cls): | |
privacy_id_extractor=lambda x: x[0], | ||
value_extractor=lambda x: x[2]) | ||
|
||
def test_to_multi_transformable_collection(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add local to have aligned names |
||
col = self.backend.to_multi_transformable_collection(range(5)) | ||
self.assertEqual(list(col), [0, 1, 2, 3, 4]) | ||
self.assertEqual(list(col), [0, 1, 2, 3, 4]) | ||
|
||
def test_local_map(self): | ||
self.assertEqual(list(self.backend.map([], lambda x: x / 0)), []) | ||
|
||
|
@@ -588,7 +593,7 @@ def assert_laziness(operator, *args): | |
assert_laziness(self.backend.sum_per_key) | ||
assert_laziness(self.backend.flat_map, str) | ||
assert_laziness(self.backend.sample_fixed_per_key, int) | ||
assert_laziness(self.backend.filter_by_key, list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why so? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need something that can be converted to a set. |
||
assert_laziness(self.backend.filter_by_key, [1, 2]) | ||
assert_laziness(self.backend.distinct, str) | ||
|
||
def test_local_sample_fixed_per_key_requires_no_discarding(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not initialize it to
[]
immediately? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self._cache = [] wouldn't work, because this list will be shared between instances