From 73b56a68534d57134eb33c612097c368ae34651c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 6 Feb 2025 17:28:08 -0600 Subject: [PATCH] LoopyKeyBuilder: improve BasicSet handling --- loopy/tools.py | 16 ++++++++++------ test/test_misc.py | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/loopy/tools.py b/loopy/tools.py index e9f9932b7..a4fa95e8a 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -73,13 +73,17 @@ class LoopyKeyBuilder(KeyBuilderBase): update_for_dict = KeyBuilderBase.update_for_constantdict update_for_defaultdict = KeyBuilderBase.update_for_constantdict - def update_for_BasicSet(self, key_hash, key): # noqa - from islpy import Printer - prn = Printer.to_str(key.get_ctx()) - getattr(prn, "print_"+key._base_name)(key) - key_hash.update(prn.get_str().encode("utf8")) + def update_for_BasicSet(self, key_hash, key): # noqa: N802 + key_hash.update(str(type(key)).encode("utf-8")) + self.rec(key_hash, frozenset(key.get_var_dict().keys())) - def update_for_Map(self, key_hash, key): # noqa + constraints = set() + for constraint in key.get_constraints(): + constraints.add(str(constraint).partition("->")[-1]) + + self.rec(key_hash, frozenset(constraints)) + + def update_for_Map(self, key_hash, key): # noqa: N802 if isinstance(key, isl.Map): self.update_for_BasicSet(key_hash, key) else: diff --git a/test/test_misc.py b/test/test_misc.py index cee8431c3..a5a264ce7 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -341,6 +341,31 @@ def test_memoize_on_disk_with_pym_expr(): assert cached_result == uncached_result +def test_basicset_keybuilder(): + # See https://github.com/inducer/loopy/issues/912 for context + import islpy as isl + + # Both sets have the same variables and constraints, but in different order. + # These sets are generated in test_convolution() in test_apps.py + a = isl.BasicSet("[im_w, im_h, nimgs, nfeats] -> " + "{ : im_w >= 7 and im_h >= 7 and nimgs >= 0 and nfeats > 0 }") + + b = isl.BasicSet("[nfeats, nimgs, im_h, im_w] -> " + "{ : nfeats > 0 and nimgs >= 0 and im_h >= 7 and im_w >= 7 }") + + from loopy.tools import LoopyKeyBuilder + + # Equality + assert a == b + assert a.is_equal(b) + assert not a.plain_is_equal(b) + + # Hashing + assert hash(a) != hash(b) + assert a.get_hash() != b.get_hash() + assert LoopyKeyBuilder()(a) == LoopyKeyBuilder()(b) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])