Skip to content

Fix bug that ignored defs in if when missing from else #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions src/kirin/dialects/scf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,18 @@ def lower_If(self, state: lowering.State, node: ast.If) -> lowering.Result:
yield_names: list[str] = []
body_yields: list[ir.SSAValue] = []
else_yields: list[ir.SSAValue] = []
if node.orelse:
for name in body_frame.defs.keys():
if name in else_frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
else_yields.append(else_frame[name])
else:
for name in body_frame.defs.keys():
if name in frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
value = frame.get(name)
if value is None:
raise lowering.BuildError(f"expected value for {name}")
else_yields.append(value)
for name in body_frame.defs.keys():
if name in else_frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
else_yields.append(else_frame[name])
elif name in frame.defs:
yield_names.append(name)
body_yields.append(body_frame[name])
value = frame.get(name)
if value is None:
raise lowering.BuildError(f"expected value for {name}")
else_yields.append(value)

if not (
body_frame.curr_block.last_stmt
Expand Down
82 changes: 64 additions & 18 deletions test/dialects/scf/test_ifelse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kirin import ir
from kirin.passes import Fold
from kirin.prelude import python_basic
from kirin.dialects import scf, func, lowering

Expand All @@ -22,24 +23,69 @@ def run_pass(method):
return run_pass


@kernel
def main(x):
if x > 0:
y = x + 1
z = y + 1
return z
else:
y = x + 2
z = y + 2
def test_basic_if_else():
@kernel
def main(x):
if x > 0:
y = x + 1
z = y + 1
return z
else:
y = x + 2
z = y + 2

if x < 0:
y = y + 3
z = y + 3
else:
y = x + 4
z = y + 4
return y, z
if x < 0:
y = y + 3
z = y + 3
else:
y = x + 4
z = y + 4
return y, z

main.print()
print(main(1))

main.print()
# print(main(1))

def test_if_else_defs():

@kernel
def main(n: int):
x = 0

if x == n:
x = 1
else:
y = 2 # noqa: F841

return x

main.print()

# make sure fold doesn't remove the nested def
main2 = main.similar(kernel)
Fold(main2.dialects)(main2)

main2.print()

@kernel
def main_elif(n: int):
x = 0

if x == n:
x = 3
elif x == n + 1:
x = 4

return x

main_elif.print()

main_elif2 = main_elif.similar(kernel)
Fold(main_elif2.dialects)(main_elif2)

main_elif2.print()

assert main_elif2.code.is_equal(main2.code)


# test_if_else_defs()
Loading