Skip to content

Commit

Permalink
Fix incorrect use of rewrite DB position values
Browse files Browse the repository at this point in the history
Re-positionings of the current rewrites cannot be applied due to aesara-devs#207.
  • Loading branch information
brandonwillard committed Dec 2, 2022
1 parent f793027 commit 3ce46e4
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 23 deletions.
2 changes: 0 additions & 2 deletions aeppl/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def find_measurable_clips(
measurable_ir_rewrites_db.register(
"find_measurable_clips",
find_measurable_clips,
0,
"basic",
"censoring",
)
Expand Down Expand Up @@ -192,7 +191,6 @@ def find_measurable_roundings(
measurable_ir_rewrites_db.register(
"find_measurable_roundings",
find_measurable_roundings,
0,
"basic",
"censoring",
)
Expand Down
1 change: 0 additions & 1 deletion aeppl/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
measurable_ir_rewrites_db.register(
"find_measurable_cumsums",
find_measurable_cumsums,
0,
"basic",
"cumsum",
)
1 change: 0 additions & 1 deletion aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def logprob_MixtureRV(
[mixture_replace, switch_mixture_replace],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
0,
"basic",
"mixture",
)
16 changes: 5 additions & 11 deletions aeppl/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ def incsubtensor_rv_replace(fgraph, node):

logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
"pre-canonicalize", optdb.query("+canonicalize"), -10, "basic"
)
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")

# These rewrites convert un-measurable variables into their measurable forms,
# but they need to be reapplied, because some of the measurable forms require
Expand All @@ -229,22 +227,18 @@ def incsubtensor_rv_replace(fgraph, node):
measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db"

logprob_rewrites_db.register(
"measurable_ir_rewrites", measurable_ir_rewrites_db, -10, "basic"
"measurable_ir_rewrites", measurable_ir_rewrites_db, "basic"
)

# These rewrites push random/measurable variables "down", making them closer to
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
measurable_ir_rewrites_db.register(
"subtensor_lift", local_subtensor_rv_lift, -5, "basic"
)
measurable_ir_rewrites_db.register(
"incsubtensor_lift", incsubtensor_rv_replace, -5, "basic"
"incsubtensor_lift", incsubtensor_rv_replace, "basic"
)

logprob_rewrites_db.register(
"post-canonicalize", optdb.query("+canonicalize"), 10, "basic"
)
logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")


def construct_ir_fgraph(
Expand Down
6 changes: 2 additions & 4 deletions aeppl/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,19 +513,17 @@ def _get_measurable_outputs_MeasurableScan(op, node):
# out2in(
# add_opts_to_inner_graphs, name="add_opts_to_inner_graphs", ignore_newtrees=True
# ),
-100,
"basic",
"scan",
)

measurable_ir_rewrites_db.register(
"find_measurable_scans",
find_measurable_scans,
0,
"basic",
"scan",
)

# Add scan canonicalizations that aren't in the canonicalization DB
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, -9, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, -9, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt1", scan_eqopt1, "basic", "scan")
logprob_rewrites_db.register("scan_eqopt2", scan_eqopt2, "basic", "scan")
7 changes: 3 additions & 4 deletions aeppl/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,25 +273,24 @@ def find_measurable_dimshuffles(fgraph, node) -> Optional[List[MeasurableDimShuf


measurable_ir_rewrites_db.register(
"dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic", "tensor"
"dimshuffle_lift", local_dimshuffle_rv_lift, "basic", "tensor"
)


# We register this later than `dimshuffle_lift` so that it is only applied as a fallback
measurable_ir_rewrites_db.register(
"find_measurable_dimshuffles", find_measurable_dimshuffles, 0, "basic", "tensor"
"find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
)


measurable_ir_rewrites_db.register(
"broadcast_to_lift", naive_bcast_rv_lift, -5, "basic", "tensor"
"broadcast_to_lift", naive_bcast_rv_lift, "basic", "tensor"
)


measurable_ir_rewrites_db.register(
"find_measurable_stacks",
find_measurable_stacks,
0,
"basic",
"tensor",
)

0 comments on commit 3ce46e4

Please sign in to comment.