Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conditional Push In #713

Merged
merged 5 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub fn prologue() -> String {
include_str!("optimizations/loop_strength_reduction.egg"),
include_str!("optimizations/ivt.egg"),
include_str!("optimizations/conditional_invariant_code_motion.egg"),
include_str!("optimizations/conditional_push_in.egg"),
include_str!("utility/debug-helper.egg"),
&rulesets(),
]
Expand Down
161 changes: 161 additions & 0 deletions dag_in_context/src/optimizations/conditional_push_in.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
(ruleset push-in)

; new version of the rule where one side of bop is constant
(rule (
(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Bop o (Const c ty outer_ctx) x) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType x (Base x_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single x)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Bop(c, x) in the new regions
(union (Get arg_tr i) (Bop o (Const c new_ins_ty if_tr) (Get arg_tr orig_ins_len)))
(union (Get arg_fa i) (Bop o (Const c new_ins_ty if_fa) (Get arg_fa orig_ins_len)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

(rule (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't interval analysis handle pushing constants in already?

(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Bop o x (Const c ty outer_ctx)) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType x (Base x_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single x)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Bop(x, c) in the new regions
(union (Get arg_tr i) (Bop o (Get arg_tr orig_ins_len) (Const c new_ins_ty if_tr)))
(union (Get arg_fa i) (Bop o (Get arg_fa orig_ins_len) (Const c new_ins_ty if_fa)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

(rule (
(= if_e (If pred orig_inputs thn els))
(ContextOf if_e outer_ctx)
(= (Uop o x) (Get orig_inputs i))
(HasArgType thn (TupleT tylist))
(HasArgType els (TupleT tylist))
(HasType x (Base x_ty))
)
(
; New inputs
(let new_ins (Concat orig_inputs (Single x)))
(let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TNil)))))

; New contexts
(let if_tr (InIf true pred new_ins))
(let if_fa (InIf false pred new_ins))

; New args
(let arg_tr (Arg new_ins_ty if_tr))
(let arg_fa (Arg new_ins_ty if_fa))

; SubTuple
(let orig_ins_len (TypeList-length tylist))
(let st_tr (SubTuple arg_tr 0 orig_ins_len))
(let st_fa (SubTuple arg_fa 0 orig_ins_len))

; New regions
(let new_thn (Subst if_tr st_tr thn))
(let new_els (Subst if_fa st_fa els))

; Union the original input with Uop(x) in the new regions
(union (Get arg_tr i) (Uop o (Get arg_tr orig_ins_len)))
(union (Get arg_fa i) (Uop o (Get arg_fa orig_ins_len)))

; Union the ifs
(union if_e (If pred new_ins new_thn new_els))
)
:ruleset push-in)

; OLD VERSION - Too slow for now
; ; push bop input into region
; (rule (
; (= if_e (If pred orig_inputs thn els))
; (ContextOf if_e outer_ctx)
; (= (Bop o x y) (Get orig_inputs i))
; (HasArgType thn (TupleT tylist))
; (HasArgType els (TupleT tylist))
; (HasType x (Base x_ty))
; (HasType y (Base y_ty))
; )
; (
; ; New inputs
; (let new_ins (Concat orig_inputs (Concat (Single x) (Single y))))
; (let new_ins_ty (TupleT (TLConcat tylist (TCons x_ty (TCons y_ty (TNil))))))

; ; New contexts
; (let if_tr (InIf true pred new_ins))
; (let if_fa (InIf false pred new_ins))

; ; New args
; (let arg_tr (Arg new_ins_ty if_tr))
; (let arg_fa (Arg new_ins_ty if_fa))

; ; SubTuple
; (let orig_ins_len (TypeList-length tylist))
; (let st_tr (SubTuple arg_tr 0 orig_ins_len))
; (let st_fa (SubTuple arg_fa 0 orig_ins_len))

; ; New regions
; (let new_thn (Subst if_tr st_tr thn))
; (let new_els (Subst if_fa st_fa els))

; ; Union the original input with Bop(x, y) in the new regions
; (union (Get arg_tr i) (Bop o (Get arg_tr orig_ins_len) (Get arg_tr (+ orig_ins_len 1))))
; (union (Get arg_fa i) (Bop o (Get arg_fa orig_ins_len) (Get arg_fa (+ orig_ins_len 1))))

; ; Union the ifs
; (union if_e (If pred new_ins new_thn new_els))
; )
; :ruleset push-in)
1 change: 1 addition & 0 deletions dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ fn optimizations() -> Vec<String> {
"loop-inv-motion",
"loop-strength-reduction",
"cicm",
"push-in",
]
.iter()
.map(|opt| opt.to_string())
Expand Down
17 changes: 17 additions & 0 deletions tests/passing/small/if_push_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// ARGS: 20
fn main(input: i64) {
let a: i64 = input * 2;
if (input > 0) {
let x: i64 = a * 5;
} else {
let x: i64 = a * -3;
}

if (x >= 0) {
let res: i64 = x * 10;
} else {
let res: i64 = x * 37;
}

println!("{}", res);
}
34 changes: 17 additions & 17 deletions tests/snapshots/files__block-diamond-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@ expression: visualization.result
---
# ARGS: 1
@main(v0: int) {
c1_: int = const 2;
v2_: bool = lt v0 c1_;
c3_: int = const 0;
c4_: int = const 1;
c1_: int = const 1;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
c4_: int = const 0;
c5_: int = const 5;
v6_: int = id c4_;
v7_: int = id c4_;
v8_: int = id c1_;
br v2_ .b9_ .b10_;
v6_: int = id c1_;
v7_: int = id c1_;
v8_: int = id c2_;
br v3_ .b9_ .b10_;
.b9_:
c11_: bool = const true;
c12_: int = const 4;
v13_: int = select c11_ c12_ c1_;
v13_: int = select c11_ c12_ c2_;
v6_: int = id v13_;
v7_: int = id c4_;
v8_: int = id c1_;
v14_: int = add c1_ v6_;
v15_: int = select v2_ v6_ v14_;
v16_: int = add c4_ v15_;
v7_: int = id c1_;
v8_: int = id c2_;
v14_: int = add c2_ v6_;
v15_: int = select v3_ v6_ v14_;
v16_: int = add c1_ v15_;
print v16_;
ret;
jmp .b17_;
.b10_:
v14_: int = add c1_ v6_;
v15_: int = select v2_ v6_ v14_;
v16_: int = add c4_ v15_;
v14_: int = add c2_ v6_;
v15_: int = select v3_ v6_ v14_;
v16_: int = add c1_ v15_;
print v16_;
ret;
.b17_:
Expand Down
Loading