Skip to content

Commit

Permalink
Merge pull request #727 from egraphs-good/mem-simple
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa authored Feb 12, 2025
2 parents 15e3228 + 200c6a7 commit 85bb784
Show file tree
Hide file tree
Showing 29 changed files with 925 additions and 231 deletions.
1 change: 1 addition & 0 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub fn prologue() -> String {
include_str!("optimizations/peepholes.egg"),
&optimizations::memory::rules(),
include_str!("optimizations/memory.egg"),
include_str!("optimizations/mem_simple.egg"),
&optimizations::loop_invariant::rules().join("\n"),
include_str!("optimizations/loop_simplify.egg"),
include_str!("optimizations/loop_unroll.egg"),
Expand Down
127 changes: 127 additions & 0 deletions dag_in_context/src/optimizations/mem_simple.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@

(ruleset mem-simple)

; ============================
; NoAlias analysis
; ============================

(relation NoAlias (Expr Expr))

(rule ((Bop (PtrAdd) e i)
(= (lo-bound i) (IntB lo))
(> lo 0))
((NoAlias e (Bop (PtrAdd) e i)))
:ruleset mem-simple)

(rule ((Bop (PtrAdd) e i)
(= (hi-bound i) (IntB hi))
(< hi 0))
((NoAlias e (Bop (PtrAdd) e i)))
:ruleset mem-simple)

(rule ((= p1 (Bop (PtrAdd) p i))
(= p2 (Bop (PtrAdd) p (Bop (Add) i diff)))
(= (lo-bound diff) (IntB lo))
(> lo 0))
((NoAlias p1 p2))
:ruleset mem-simple)

(rule ((= p1 (Bop (PtrAdd) p i))
(= p2 (Bop (PtrAdd) p (Bop (Add) i diff)))
(= (hi-bound diff) (IntB hi))
(< hi 0))
((NoAlias p1 p2))
:ruleset mem-simple)

(rule ((= p1 (Bop (PtrAdd) p i))
(= p2 (Bop (PtrAdd) p (Bop (Sub) i diff)))
(= (lo-bound diff) (IntB lo))
(> lo 0))
((NoAlias p1 p2))
:ruleset mem-simple)

(rule ((= p1 (Bop (PtrAdd) p i))
(= p2 (Bop (PtrAdd) p (Bop (Sub) i diff)))
(= (hi-bound diff) (IntB hi))
(< hi 0))
((NoAlias p1 p2))
:ruleset mem-simple)

(rule ((NoAlias x y))
((NoAlias y x))
:ruleset mem-simple)

; ============================
; Memory optimizations
; ============================

(relation DidMemOptimization (String))

; A write then a load to different addresses can be swapped
(rule ((NoAlias write-addr load-addr)
(= write (Top (Write) write-addr write-val state))
(= load (Bop (Load) load-addr write)))
((let new-load (Bop (Load) load-addr state))
(union
(Get load 1)
(Top (Write) write-addr write-val (Get new-load 1)))
(union (Get load 0) (Get new-load 0))
(DidMemOptimization "commute write then load")
)
:ruleset mem-simple)

; A load then a write to different addresses can be swapped
; Actually, does this break WeaklyLinear if the stored value depends on the
; loaded value? Commenting this out for now.
; (rule ((NoAlias load-addr write-addr)
; (= load (Bop (Load) load-addr state))
; (= write (Top (Write) write-addr write-val (Get load 1))))
; ((let new-write (Top (Write) write-addr write-val state))
; (let new-load (Bop (Load) load-addr new-write))
; (union write (Get new-load 1))
; (union (Get load 0) (Get new-load 0))
; (DidMemOptimization "commute load then write")
; )
; :ruleset mem-simple)

; Two loads to the same address can be compressed
(rule ((= first-load (Bop (Load) addr state))
(= second-load (Bop (Load) addr first-load)))
((union (Get first-load 0) (Get second-load 0))
(union (Get first-load 1) (Get second-load 1))
(DidMemOptimization "duplicate load")
)
:ruleset mem-simple)

; A write and a load to the same address can be forwarded
(rule ((= write (Top (Write) addr write-val state))
(= load (Bop (Load) addr write)))
((union (Get load 0) write-val)
(union (Get load 1) write)
(DidMemOptimization "store forward")
)
:ruleset mem-simple)

; Two writes of the same value to the same address can be compressed
(rule ((= first-write (Top (Write) addr write-val state))
(= second-write (Top (Write) addr write-val first-write)))
((union first-write second-write)
(DidMemOptimization "duplicate write"))
:ruleset mem-simple)

; A write shadows a previous write to the same address
(rule ((= first-write (Top (Write) addr shadowed-val state))
(= second-write (Top (Write) addr write-val first-write)))
((union second-write (Top (Write) addr write-val state))
(DidMemOptimization "shadowed write"))
:ruleset mem-simple)

; A load doesn't change the state
; TODO: why does this break weaklylinear?
; (rule ((= load (Bop (Load) addr state)))
; ((union (Get load 1) state))
; :ruleset mem-simple)

; (rule ((DidMemOptimization _))
; ((panic "DidMemOptimization"))
; :ruleset mem-simple)
10 changes: 10 additions & 0 deletions dag_in_context/src/optimizations/peepholes.egg
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,13 @@
:ruleset peepholes)

(rewrite (Top (Select) pred x x) x :ruleset peepholes)

; constant fold `(x + const1) + const2` even when x is not constant
(rewrite (Bop (Add) (Bop (Add) x (Const (Int i) ty ctx)) (Const (Int j) ty ctx))
(Bop (Add) x (Const (Int (+ i j)) ty ctx))
:ruleset peepholes)

; ptradd(ptradd(p, x), y) => ptradd(p, x + y)
(rewrite (Bop (PtrAdd) (Bop (PtrAdd) p x) y)
(Bop (PtrAdd) p (Bop (Add) x y))
:ruleset peepholes)
2 changes: 2 additions & 0 deletions dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub(crate) fn helpers() -> String {
(saturate
terms-helpers
(saturate terms-helpers-helpers)))
(saturate mem-simple)
;; cicm index
cicm-index
Expand Down Expand Up @@ -215,6 +216,7 @@ pub fn parallel_schedule() -> Vec<CompilerPass> {
)
{helpers}
add-to-debug-expr
)
"
)),
Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/utility/util.egg
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@
((subsume (If a b c d)))
:ruleset subsume-after-helpers)


(ruleset add-to-debug-expr)
42 changes: 42 additions & 0 deletions tests/passing/small/count.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# ARGS: 100000

@main(halfinput : int) {
two: int = const 2;
input: int = mul halfinput two;
zero: int = const 0;
one: int = const 1;
vals: ptr<int> = alloc input;
store vals zero;
vals_1: ptr<int> = ptradd vals one;
store vals_1 one;
i: int = const 2;
.loop:
cond: bool = lt i input;
br cond .body .done;
.body:
neg_one: int = const -1;
one: int = const 1;
two: int = const 2;

vals_i: ptr<int> = ptradd vals i;
vals_i_minus_one: ptr<int> = ptradd vals_i neg_one;
tmp: int = load vals_i_minus_one;
tmp: int = add tmp one;
store vals_i tmp;
i: int = add i one;

vals_i: ptr<int> = ptradd vals i;
vals_i_minus_one: ptr<int> = ptradd vals_i neg_one;
tmp: int = load vals_i_minus_one;
tmp: int = add tmp one;
store vals_i tmp;
i: int = add i one;
jmp .loop;
.done:
last_plus_one: ptr<int> = ptradd vals i;
neg_one: int = const -1;
last: ptr<int> = ptradd last_plus_one neg_one;
tmp: int = load last;
free vals;
print tmp;
}
46 changes: 46 additions & 0 deletions tests/passing/small/fib-2-unroll.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# ARGS: 100000

@main(input : int) {
zero: int = const 0;
one: int = const 1;
vals: ptr<int> = alloc input;
store vals zero;
vals_i: ptr<int> = ptradd vals one;
store vals_i one;
i: int = const 2;
.loop:
cond: bool = lt i input;
br cond .body .done;
.body:
neg_one: int = const -1;
neg_two: int = const -2;
vals_i: ptr<int> = ptradd vals i;
vals_i_minus_one: ptr<int> = ptradd vals_i neg_one;
vals_i_minus_two: ptr<int> = ptradd vals_i neg_two;
tmp: int = load vals_i_minus_one;
tmp2: int = load vals_i_minus_two;
tmp: int = add tmp tmp2;
store vals_i tmp;
i: int = add i one;
cond: bool = lt i input;
br cond .body2 .done;
.body2:
neg_one: int = const -1;
neg_two: int = const -2;
vals_i: ptr<int> = ptradd vals i;
vals_i_minus_one: ptr<int> = ptradd vals_i neg_one;
vals_i_minus_two: ptr<int> = ptradd vals_i neg_two;
tmp: int = load vals_i_minus_one;
tmp2: int = load vals_i_minus_two;
tmp: int = add tmp tmp2;
store vals_i tmp;
i: int = add i one;
jmp .loop;
.done:
last_plus_one: ptr<int> = ptradd vals i;
neg_one_: int = const -1;
last: ptr<int> = ptradd last_plus_one neg_one_;
tmp: int = load last;
free vals;
print tmp;
}
33 changes: 33 additions & 0 deletions tests/passing/small/fib-2.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# ARGS: 100000

@main(input : int) {
zero: int = const 0;
one: int = const 1;
vals: ptr<int> = alloc input;
store vals zero;
vals_i: ptr<int> = ptradd vals one;
store vals_i one;
i: int = const 2;
.loop:
cond: bool = lt i input;
br cond .body .done;
.body:
neg_one: int = const -1;
neg_two: int = const -2;
vals_i: ptr<int> = ptradd vals i;
vals_i_minus_one: ptr<int> = ptradd vals_i neg_one;
vals_i_minus_two: ptr<int> = ptradd vals_i neg_two;
tmp: int = load vals_i_minus_one;
tmp2: int = load vals_i_minus_two;
tmp: int = add tmp tmp2;
store vals_i tmp;
i: int = add i one;
jmp .loop;
.done:
last_plus_one: ptr<int> = ptradd vals i;
neg_one_: int = const -1;
last: ptr<int> = ptradd last_plus_one neg_one_;
tmp: int = load last;
free vals;
print tmp;
}
25 changes: 12 additions & 13 deletions tests/snapshots/files__branch_hoisting-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@ expression: visualization.result
v9_: bool = id v3_;
.b10_:
c11_: int = const 1;
v12_: int = add c11_ v5_;
v13_: int = add c11_ v12_;
c12_: int = const 2;
v13_: int = add c12_ v5_;
v14_: int = add c11_ v13_;
c15_: int = const 2;
v16_: int = mul c15_ v14_;
c17_: int = const 3;
v18_: int = mul c17_ v14_;
v19_: int = select v9_ v16_ v18_;
v20_: int = add c11_ v14_;
v21_: bool = lt v20_ v8_;
v4_: int = id v19_;
v5_: int = id v20_;
v15_: int = mul c12_ v14_;
c16_: int = const 3;
v17_: int = mul c16_ v14_;
v18_: int = select v9_ v15_ v17_;
v19_: int = add c12_ v13_;
v20_: bool = lt v19_ v8_;
v4_: int = id v18_;
v5_: int = id v19_;
v6_: int = id v6_;
v7_: int = id v7_;
v8_: int = id v8_;
v9_: bool = id v9_;
br v21_ .b10_ .b22_;
.b22_:
br v20_ .b10_ .b21_;
.b21_:
print v4_;
ret;
}
54 changes: 27 additions & 27 deletions tests/snapshots/files__branch_hoisting-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,36 @@ expression: visualization.result
v6_: int = id v0;
v7_: int = id c2_;
v8_: int = id c3_;
v9_: int = id c2_;
v9_: bool = eq c2_ v0;
v10_: int = id c2_;
v11_: int = id v0;
v12_: int = id c2_;
v13_: int = id c3_;
.b14_:
v15_: bool = eq v11_ v12_;
c16_: int = const 1;
v17_: int = add c16_ v10_;
v18_: int = add c16_ v17_;
v19_: int = add c16_ v18_;
c20_: int = const 2;
v21_: int = mul c20_ v19_;
c22_: int = const 3;
v23_: int = mul c22_ v19_;
v24_: int = select v15_ v21_ v23_;
v25_: int = add c16_ v19_;
v26_: bool = lt v25_ v13_;
v9_: int = id v24_;
v10_: int = id v25_;
v11_: int = id v11_;
v11_: int = id c2_;
v12_: int = id v0;
v13_: int = id c2_;
v14_: int = id c3_;
v15_: bool = id v9_;
.b16_:
c17_: int = const 2;
c18_: int = const 3;
v19_: int = add c18_ v11_;
v20_: int = mul c17_ v19_;
v21_: int = mul c18_ v19_;
v22_: int = select v15_ v20_ v21_;
c23_: int = const 4;
v24_: int = add c23_ v11_;
v25_: bool = lt v24_ v14_;
v10_: int = id v22_;
v11_: int = id v24_;
v12_: int = id v12_;
v13_: int = id v13_;
br v26_ .b14_ .b27_;
.b27_:
v4_: int = id v9_;
v5_: int = id v10_;
v6_: int = id v11_;
v7_: int = id v12_;
v8_: int = id v13_;
v14_: int = id v14_;
v15_: bool = id v15_;
br v25_ .b16_ .b26_;
.b26_:
v4_: int = id v10_;
v5_: int = id v11_;
v6_: int = id v0;
v7_: int = id c2_;
v8_: int = id c3_;
print v4_;
ret;
ret;
Expand Down
Loading

0 comments on commit 85bb784

Please sign in to comment.