-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow divide_dim to divide to non-literal expressions #628
Changes from 3 commits
4dda8e0
1646ef9
0bd8431
7bea646
894de93
df33a58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -264,6 +264,24 @@ def Check_CompareExprs(proc, stmts, lhs, op, rhs): | |
Check_ExprBound(proc, stmts, expr, op, 0) | ||
|
||
|
||
def Check_IsDivisible(proc, stmts, expr, quot): | ||
failed = False | ||
if not isinstance(expr, LoopIR.Const): | ||
try: | ||
quot = LoopIR.Const(quot, T.int, null_srcinfo()) | ||
expr_mod_quot = LoopIR.BinOp("%", expr, quot, T.index, null_srcinfo()) | ||
zero = LoopIR.Const(0, T.int, null_srcinfo()) | ||
Check_CompareExprs(proc, stmts, expr_mod_quot, "==", zero) | ||
except SchedulingError: | ||
failed = True | ||
else: | ||
# Fast path | ||
failed = expr.val % quot != 0 | ||
|
||
if failed: | ||
raise SchedulingError(f"cannot perfectly divide '{expr}' by {quot}") | ||
|
||
|
||
def extract_env(c: ic.Cursor) -> List[Tuple[Sym, ic.Cursor]]: | ||
""" | ||
Extract the environment of live variables at `c`. | ||
|
@@ -303,6 +321,28 @@ def move_back(c): | |
return c.prev() | ||
|
||
|
||
# --------------------------------------------------------------------------- # | ||
# --------------------------------------------------------------------------- # | ||
# IR Building Helpers | ||
|
||
|
||
def divide_expr(e, quot): | ||
assert isinstance(e, LoopIR.expr) | ||
if isinstance(quot, int): | ||
quot_int = quot | ||
quot_ir = LoopIR.Const(quot, e.type, e.srcinfo) | ||
elif isinstance(quot, LoopIR.expr): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! |
||
quot_int = quot.val | ||
quot_ir = quot | ||
else: | ||
assert False, f"Bad case {type(quot)}" | ||
if isinstance(e, LoopIR.Const) and e.val % quot == 0: | ||
div = LoopIR.Const(e.val // quot_int, e.type, e.srcinfo) | ||
else: | ||
div = LoopIR.BinOp("/", e, quot_ir, e.type, e.srcinfo) | ||
return div | ||
|
||
|
||
# --------------------------------------------------------------------------- # | ||
# --------------------------------------------------------------------------- # | ||
# Scheduling directives | ||
|
@@ -728,25 +768,10 @@ def ceildiv(lhs, rhs): | |
elif tail_strategy in ["cut", "cut_and_guard"]: | ||
outer_hi = szop("/", N, inner_hi) # floor div | ||
elif tail_strategy == "perfect": | ||
if not isinstance(N, LoopIR.Const): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is great! |
||
hi_mod_quot = boolop("%", N, cnst(quot), T.index) | ||
try: | ||
ir = loop_cursor.get_root() | ||
loop = loop_cursor._node | ||
Check_CompareExprs(ir, [loop], hi_mod_quot, "==", cnst(0)) | ||
except SchedulingError: | ||
raise SchedulingError( | ||
f"cannot perfectly split the '{loop.iter}' loop " f"by {quot}" | ||
) | ||
outer_hi = boolop("/", N, cnst(quot), T.index) | ||
else: | ||
if N.val % quot != 0: | ||
raise SchedulingError( | ||
f"cannot perfectly split the '{loop.iter}' loop " | ||
f"because {quot} does not evenly divide " | ||
f"{N.val}" | ||
) | ||
outer_hi = cnst(N.val // quot) | ||
ir = loop_cursor.get_root() | ||
loop = loop_cursor._node | ||
Check_IsDivisible(ir, [loop], N, quot) | ||
outer_hi = divide_expr(N, quot) | ||
else: | ||
assert False, f"bad tail strategy: {tail_strategy}" | ||
|
||
|
@@ -1728,17 +1753,13 @@ def DoDivideDim(alloc_cursor, dim_idx, quotient): | |
old_typ = alloc_s.type | ||
old_shp = old_typ.shape() | ||
dim = old_shp[dim_idx] | ||
if not isinstance(dim, LoopIR.Const): | ||
raise SchedulingError(f"Cannot divide non-literal dimension: {dim}") | ||
if not dim.val % quotient == 0: | ||
raise SchedulingError(f"Cannot divide {dim.val} evenly by {quotient}") | ||
denom = quotient | ||
numer = dim.val // denom | ||
Check_IsDivisible(alloc_cursor.get_root(), [alloc_s], dim, quotient) | ||
numer = divide_expr(dim, quotient) | ||
new_shp = ( | ||
old_shp[:dim_idx] | ||
+ [ | ||
LoopIR.Const(numer, T.int, dim.srcinfo), | ||
LoopIR.Const(denom, T.int, dim.srcinfo), | ||
numer, | ||
LoopIR.Const(quotient, T.int, dim.srcinfo), | ||
] | ||
+ old_shp[dim_idx + 1 :] | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
def foo(n: size, m: size, A: R[n + m + 12] @ DRAM): | ||
x: R[n, 3 * m, 4, m] @ DRAM | ||
for i in seq(0, n): | ||
for j in seq(0, 12): | ||
for k in seq(0, m): | ||
x[i, j / 4, j % 4, k] = A[i + j + k] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
def foo(n: size, m: size): | ||
x: R[n, 1, (7 + m) / 8 * 8 / 8, 8, 1, m, 1] @ DRAM | ||
for i in seq(0, n): | ||
for j in seq(0, m): | ||
for k in seq(0, m): | ||
x[i, 0, j / 8, j % 8, 0, k, 0] = 2.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you change the docstring above "This limited implementation of
divide_dim
requires that the dimension being divided is constant itself." ?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!