Skip to content

Commit

Permalink
Merge pull request #713 from egraphs-good/ajpal-push-in
Browse files Browse the repository at this point in the history
Conditional Push In
  • Loading branch information
ajpal authored Jan 28, 2025
2 parents 717da5e + d8336d9 commit 867cc9e
Show file tree
Hide file tree
Showing 13 changed files with 581 additions and 347 deletions.
1 change: 1 addition & 0 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub fn prologue() -> String {
include_str!("optimizations/loop_strength_reduction.egg"),
include_str!("optimizations/ivt.egg"),
include_str!("optimizations/conditional_invariant_code_motion.egg"),
include_str!("optimizations/conditional_push_in.egg"),
include_str!("utility/debug-helper.egg"),
&rulesets(),
]
Expand Down
161 changes: 161 additions & 0 deletions dag_in_context/src/optimizations/conditional_push_in.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
(ruleset push-in)

; new version of the rule where one side of bop is constant
(rule (
(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Bop o (Const c ty outer_ctx) x) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType x (Base x_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single x)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Bop(c, x) in the new regions
(union (Get arg_tr i) (Bop o (Const c new_ins_ty if_tr) (Get arg_tr orig_ins_len)))
(union (Get arg_fa i) (Bop o (Const c new_ins_ty if_fa) (Get arg_fa orig_ins_len)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

(rule (
(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Bop o x (Const c ty outer_ctx)) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType x (Base x_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single x)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Bop(x, c) in the new regions
(union (Get arg_tr i) (Bop o (Get arg_tr orig_ins_len) (Const c new_ins_ty if_tr)))
(union (Get arg_fa i) (Bop o (Get arg_fa orig_ins_len) (Const c new_ins_ty if_fa)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

(rule (
(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Uop o x) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType x (Base x_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single x)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Uop(x) in the new regions
(union (Get arg_tr i) (Uop o (Get arg_tr orig_ins_len)))
(union (Get arg_fa i) (Uop o (Get arg_fa orig_ins_len)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

; OLD VERSION - Too slow for now
; ; push bop input into region
; (rule (
; (= if_e (If pred orig_inputs thn els))
; (ContextOf if_e outer_ctx)
; (= (Bop o x y) (Get orig_inputs i))
; (HasArgType thn (TupleT tylist))
; (HasArgType els (TupleT tylist))
; (HasType x (Base x_ty))
; (HasType y (Base y_ty))
; )
; (
; ; New inputs
; (let new_ins (Concat orig_inputs (Concat (Single x) (Single y))))
; (let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TCons y_ty (TNil))))))

; ; New contexts
; (let if_tr (InIf true pred new_ins))
; (let if_fa (InIf false pred new_ins))

; ; New args
; (let arg_tr (Arg new_ins_ty if_tr))
; (let arg_fa (Arg new_ins_ty if_fa))

; ; SubTuple
; (let orig_ins_len (TypeList-length tylist))
; (let st_tr (SubTuple arg_tr 0 orig_ins_len))
; (let st_fa (SubTuple arg_fa 0 orig_ins_len))

; ; New regions
; (let new_thn (Subst if_tr st_tr thn))
; (let new_els (Subst if_fa st_fa els))

; ; Union the original input with Bop(x, y) in the new regions
; (union (Get arg_tr i) (Bop o (Get arg_tr orig_ins_len) (Get arg_tr (+ orig_ins_len 1))))
; (union (Get arg_fa i) (Bop o (Get arg_fa orig_ins_len) (Get arg_fa (+ orig_ins_len 1))))

; ; Union the ifs
; (union if_e (If pred new_ins new_thn new_els))
; )
; :ruleset push-in)
1 change: 1 addition & 0 deletions dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ fn optimizations() -> Vec<String> {
"loop-inv-motion",
"loop-strength-reduction",
"cicm",
"push-in",
]
.iter()
.map(|opt| opt.to_string())
Expand Down
17 changes: 17 additions & 0 deletions tests/passing/small/if_push_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// ARGS: 20
fn main(input: i64) {
let a: i64 = input * 2;
if (input > 0) {
let x: i64 = a * 5;
} else {
let x: i64 = a * -3;
}

if (x >= 0) {
let res: i64 = x * 10;
} else {
let res: i64 = x * 37;
}

println!("{}", res);
}
34 changes: 17 additions & 17 deletions tests/snapshots/files__block-diamond-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@ expression: visualization.result
---
# ARGS: 1
@main(v0: int) {
c1_: int = const 2;
v2_: bool = lt v0 c1_;
c3_: int = const 0;
c4_: int = const 1;
c1_: int = const 1;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
c4_: int = const 0;
c5_: int = const 5;
v6_: int = id c4_;
v7_: int = id c4_;
v8_: int = id c1_;
br v2_ .b9_ .b10_;
v6_: int = id c1_;
v7_: int = id c1_;
v8_: int = id c2_;
br v3_ .b9_ .b10_;
.b9_:
c11_: bool = const true;
c12_: int = const 4;
v13_: int = select c11_ c12_ c1_;
v13_: int = select c11_ c12_ c2_;
v6_: int = id v13_;
v7_: int = id c4_;
v8_: int = id c1_;
v14_: int = add c1_ v6_;
v15_: int = select v2_ v6_ v14_;
v16_: int = add c4_ v15_;
v7_: int = id c1_;
v8_: int = id c2_;
v14_: int = add c2_ v6_;
v15_: int = select v3_ v6_ v14_;
v16_: int = add c1_ v15_;
print v16_;
ret;
jmp .b17_;
.b10_:
v14_: int = add c1_ v6_;
v15_: int = select v2_ v6_ v14_;
v16_: int = add c4_ v15_;
v14_: int = add c2_ v6_;
v15_: int = select v3_ v6_ v14_;
v16_: int = add c1_ v15_;
print v16_;
ret;
.b17_:
Expand Down
Loading

0 comments on commit 867cc9e

Please sign in to comment.