From d8800480cd120a8f0e37674f737bdbb102c2a1b0 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 1 Dec 2022 18:16:42 -0600 Subject: [PATCH] Fix incorrect use of rewrite DB position values Re-positionings of the current rewrites cannot be applied due to #207. --- aeppl/censoring.py | 2 -- aeppl/cumsum.py | 1 - aeppl/mixture.py | 1 - aeppl/rewriting.py | 16 +++++----------- aeppl/scan.py | 6 ++---- aeppl/tensor.py | 7 +++---- 6 files changed, 10 insertions(+), 23 deletions(-) diff --git a/aeppl/censoring.py b/aeppl/censoring.py index 4d46a83f..addba74b 100644 --- a/aeppl/censoring.py +++ b/aeppl/censoring.py @@ -75,7 +75,6 @@ def find_measurable_clips( measurable_ir_rewrites_db.register( "find_measurable_clips", find_measurable_clips, - 0, "basic", "censoring", ) @@ -192,7 +191,6 @@ def find_measurable_roundings( measurable_ir_rewrites_db.register( "find_measurable_roundings", find_measurable_roundings, - 0, "basic", "censoring", ) diff --git a/aeppl/cumsum.py b/aeppl/cumsum.py index ded2f4c1..f66fc3e9 100644 --- a/aeppl/cumsum.py +++ b/aeppl/cumsum.py @@ -83,7 +83,6 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]: measurable_ir_rewrites_db.register( "find_measurable_cumsums", find_measurable_cumsums, - 0, "basic", "cumsum", ) diff --git a/aeppl/mixture.py b/aeppl/mixture.py index 99bcbd78..457fa976 100644 --- a/aeppl/mixture.py +++ b/aeppl/mixture.py @@ -423,7 +423,6 @@ def logprob_MixtureRV( [mixture_replace, switch_mixture_replace], max_use_ratio=aesara.config.optdb__max_use_ratio, ), - 0, "basic", "mixture", ) diff --git a/aeppl/rewriting.py b/aeppl/rewriting.py index bc6802cb..7326359f 100644 --- a/aeppl/rewriting.py +++ b/aeppl/rewriting.py @@ -218,9 +218,7 @@ def incsubtensor_rv_replace(fgraph, node): logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" -logprob_rewrites_db.register( - "pre-canonicalize", optdb.query("+canonicalize"), -10, "basic" -) +logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic") # These rewrites convert un-measurable variables into their measurable forms, # but they need to be reapplied, because some of the measurable forms require @@ -229,22 +227,18 @@ def incsubtensor_rv_replace(fgraph, node): measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db" logprob_rewrites_db.register( - "measurable_ir_rewrites", measurable_ir_rewrites_db, -10, "basic" + "measurable_ir_rewrites", measurable_ir_rewrites_db, "basic" ) # These rewrites push random/measurable variables "down", making them closer to # (or eventually) the graph outputs. Often this is done by lifting other `Op`s # "up" through the random/measurable variables and into their inputs. +measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic") measurable_ir_rewrites_db.register( - "subtensor_lift", local_subtensor_rv_lift, -5, "basic" -) -measurable_ir_rewrites_db.register( - "incsubtensor_lift", incsubtensor_rv_replace, -5, "basic" + "incsubtensor_lift", incsubtensor_rv_replace, "basic" ) -logprob_rewrites_db.register( - "post-canonicalize", optdb.query("+canonicalize"), 10, "basic" -) +logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic") def construct_ir_fgraph( diff --git a/aeppl/scan.py b/aeppl/scan.py index 653b22ef..8b38fd33 100644 --- a/aeppl/scan.py +++ b/aeppl/scan.py @@ -513,7 +513,6 @@ def _get_measurable_outputs_MeasurableScan(op, node): # out2in( # add_opts_to_inner_graphs, name="add_opts_to_inner_graphs", ignore_newtrees=True # ), - -100, "basic", "scan", ) @@ -521,11 +520,10 @@ def _get_measurable_outputs_MeasurableScan(op, node): measurable_ir_rewrites_db.register( "find_measurable_scans", find_measurable_scans, - 0, "basic", "scan", ) # Add scan canonicalizations that aren't in the canonicalization DB -logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, -9, "basic", "scan") -logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, -9, "basic", "scan") +logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, "basic", "scan") +logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, "basic", "scan") diff --git a/aeppl/tensor.py b/aeppl/tensor.py index 0fac7f2b..5292b4ef 100644 --- a/aeppl/tensor.py +++ b/aeppl/tensor.py @@ -273,25 +273,24 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf measurable_ir_rewrites_db.register( - "dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor" + "dimshuffle_lift", local_dimshuffle_rv_lift, "basic", "tensor" ) # We register this later than `dimshuffle_lift` so that it is only applied as a fallback measurable_ir_rewrites_db.register( - "find_measurable_dimshuffles", find_measurable_dimshuffles, 0, "basic", "tensor" + "find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor" ) measurable_ir_rewrites_db.register( - "broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor" + "broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor" ) measurable_ir_rewrites_db.register( "find_measurable_stacks", find_measurable_stacks, - 0, "basic", "tensor", )