Skip to content

Commit cd25393

Browse files
committed
fix(expressions): return fresh residual evaluator instances
1 parent 1bccb5c commit cd25393

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import copy
1718
import math
1819
import threading
1920
from abc import ABC, abstractmethod
@@ -1988,11 +1989,23 @@ def _residual_evaluator_cache_key(
19881989
key=_residual_evaluator_cache_key,
19891990
lock=threading.RLock(),
19901991
)
1991-
def residual_evaluator_of(
1992+
def _cached_residual_evaluator_template(
19921993
spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
19931994
) -> ResidualEvaluator:
19941995
return (
19951996
UnpartitionedResidualEvaluator(schema=schema, expr=expr)
19961997
if spec.is_unpartitioned()
19971998
else ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive)
19981999
)
2000+
2001+
2002+
def residual_evaluator_of(
2003+
spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
2004+
) -> ResidualEvaluator:
2005+
"""Create a residual evaluator.
2006+
2007+
Always returns a fresh evaluator instance because evaluators are stateful
2008+
(they set `self.struct` during evaluation) and may be used from multiple
2009+
threads.
2010+
"""
2011+
return copy.copy(_cached_residual_evaluator_template(spec=spec, expr=expr, case_sensitive=case_sensitive, schema=schema))

tests/expressions/test_residual_evaluator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def test_identity_transform_residual() -> None:
8888
assert residual == AlwaysFalse()
8989

9090

91+
def test_residual_evaluator_of_returns_fresh_instance() -> None:
92+
schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
93+
spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part"))
94+
predicate = LessThan("dateint", 20170815)
95+
96+
res_eval_1 = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
97+
res_eval_2 = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
98+
99+
assert res_eval_1 is not res_eval_2
100+
assert res_eval_1.residual_for(Record(20170814)) == AlwaysTrue()
101+
assert res_eval_2.residual_for(Record(20170816)) == AlwaysFalse()
102+
103+
91104
def test_case_insensitive_identity_transform_residuals() -> None:
92105
schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
93106

0 commit comments

Comments
 (0)