Skip to content

Commit 78f5e2d

Browse files
committed
tests: FilterSet methods; added extra filter type check
1 parent 4244d3b commit 78f5e2d

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

dill/_utils.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,12 @@ def __repr__(self):
166166
Filter = Union[str, Pattern[str], int, type, FilterFunction]
167167
Rule = Tuple[RuleType, Union[Filter, Iterable[Filter]]]
168168

169-
def _iter(filters):
170-
if isinstance(filters, str):
169+
def _iter(obj):
170+
"""return iterator of object if it's not a string"""
171+
if isinstance(obj, (str, bytes)):
171172
return None
172173
try:
173-
return iter(filters)
174+
return iter(obj)
174175
except TypeError:
175176
return None
176177

@@ -199,7 +200,7 @@ def _match_type(self, filter: Filter) -> Tuple[filter, str]:
199200
else:
200201
filter = re.compile(filter)
201202
field = 'regexes'
202-
elif filter_type is re.Pattern:
203+
elif filter_type is re.Pattern and type(filter.pattern) is str:
203204
field = 'regexes'
204205
elif filter_type is int:
205206
field = 'ids'

dill/tests/test_filtering.py

+57-1
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,64 @@
1111
import dill
1212
from dill import _dill
1313
from dill.session import (
14-
EXCLUDE, INCLUDE, FilterRules, RuleType, ipython_filter, size_filter, settings
14+
EXCLUDE, INCLUDE, FilterRules, FilterSet, RuleType, ipython_filter, size_filter, settings
1515
)
1616

17+
def test_filterset():
18+
import re
19+
20+
name = 'test'
21+
regex1 = re.compile(r'\w+\d+')
22+
regex2 = r'_\w+'
23+
id_ = id(FilterSet)
24+
type1 = FilterSet
25+
type2 = 'type:List'
26+
func = lambda obj: obj.name == 'Arthur'
27+
28+
empty_filters = FilterSet()
29+
assert bool(empty_filters) is False
30+
assert len(empty_filters) == 0
31+
assert len([*empty_filters]) == 0
32+
33+
# also tests add() and __ior__() for non-FilterSet other
34+
filters = FilterSet._from_iterable([name, regex1, regex2, id_, type1, type2, func])
35+
assert filters.names == {name}
36+
assert filters.regexes == {regex1, re.compile(regex2)}
37+
assert filters.ids == {id_}
38+
assert filters.types == {type1, list}
39+
assert filters.funcs == {func}
40+
41+
assert bool(filters) is True
42+
assert len(filters) == 7
43+
assert all(x in filters for x in [name, regex1, id_, type1, func])
44+
45+
try:
46+
filters.add(re.compile(b'an 8-bit string regex'))
47+
except ValueError:
48+
pass
49+
else:
50+
raise AssertionError("adding invalid filter should raise error")
51+
52+
filters_copy = filters.copy()
53+
for field in FilterSet._fields:
54+
original, copy = getattr(filters, field), getattr(filters_copy, field)
55+
assert copy is not original
56+
assert copy == original
57+
58+
filters.remove(re.compile(regex2))
59+
assert filters.regexes == {regex1}
60+
filters.discard(list)
61+
filters.discard(list) # should not raise error
62+
assert filters.types == {type1}
63+
assert [*filters] == [name, regex1, id_, type1, func]
64+
65+
# also tests __ior__() for FilterSet other
66+
filters.update(filters_copy)
67+
assert filters.types == {type1, list}
68+
69+
filters.clear()
70+
assert len(filters) == 0
71+
1772
NS = {
1873
'a': 1,
1974
'aa': 2,
@@ -184,6 +239,7 @@ def test_size_filter():
184239
assert did_exclude(NS_copy, filter_size, excluded_subset={'large'})
185240

186241
if __name__ == '__main__':
242+
test_filterset()
187243
test_basic_filtering()
188244
test_exclude_include()
189245
test_add_type()

0 commit comments

Comments
 (0)