Skip to content

Commit

Permalink
FlattenMapper: guard simplifications that only hold for integers
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 7, 2024
1 parent fa79c6c commit 8da6fb5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
21 changes: 18 additions & 3 deletions pymbolic/mapper/flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,19 @@ class FlattenMapper(IdentityMapper[[]]):
This parallels what was done implicitly in the expression node
constructors.
.. automethod:: is_expr_integral
"""

def is_expr_integer_valued(self, expr: ExpressionT) -> bool:
"""A user-supplied method to indicate whether a given *expr* is integer-
valued. This enables additional simplifications that are not valid in
general. The default implementation simply returns *False*.
.. versionadded :: 2024.1
"""
return False

def map_sum(self, expr: p.Sum) -> ExpressionT:
from pymbolic.primitives import flattened_sum
return flattened_sum([
Expand Down Expand Up @@ -77,7 +89,9 @@ def map_floor_div(self, expr: p.FloorDiv) -> ExpressionT:
if p.is_zero(r_num):
return 0
if p.is_zero(r_den - 1):
return r_num
# It's the floor function in this case.
if self.is_expr_integer_valued(r_num):
return r_num

return expr.__class__(r_num, r_den)

Expand All @@ -88,8 +102,9 @@ def map_remainder(self, expr: p.Remainder) -> ExpressionT:
if p.is_zero(r_num):
return 0
if p.is_zero(r_den - 1):
# mod 1 is zero
return 0
# mod 1 is zero for integers, however 3.1 % 1 == .1
if self.is_expr_integer_valued(r_num):
return 0

return expr.__class__(r_num, r_den)

Expand Down
29 changes: 29 additions & 0 deletions test/test_pymbolic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from pymbolic.mapper.evaluator import evaluate_kw
from pymbolic.mapper.flattener import FlattenMapper
from pymbolic.mapper.stringifier import StringifyMapper
from pymbolic.typing import ExpressionT

Expand Down Expand Up @@ -1053,6 +1055,33 @@ def test_derived_stringifier() -> None:
# }}}


# {{{ test_flatten

class IntegerFlattenMapper(FlattenMapper):
def is_expr_integer_valued(self, expr: ExpressionT) -> bool:
return True


def test_flatten():
expr = parse("(3 + x) % 1")

assert IntegerFlattenMapper()(expr) != expr
assert FlattenMapper()(expr) == expr

assert evaluate_kw(IntegerFlattenMapper()(expr), x=1) == 0
assert abs(evaluate_kw(FlattenMapper()(expr), x=1.1) - 0.1) < 1e-12

expr = parse("(3 + x) // 1")

assert IntegerFlattenMapper()(expr) != expr
assert FlattenMapper()(expr) == expr

assert evaluate_kw(IntegerFlattenMapper()(expr), x=1) == 4
assert abs(evaluate_kw(FlattenMapper()(expr), x=1.1) - 4) < 1e-12

# }}}


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down

0 comments on commit 8da6fb5

Please sign in to comment.