diff --git a/src/kirin/dialects/scf/absint.py b/src/kirin/dialects/scf/absint.py index 9987addf9..4447b9b84 100644 --- a/src/kirin/dialects/scf/absint.py +++ b/src/kirin/dialects/scf/absint.py @@ -62,4 +62,10 @@ def _infer_if_else_cond( if isinstance(body_term, func.Return): frame.worklist.append(interp.Successor(body_block, frame.get(stmt.cond))) return - return interp_.frame_call_region(frame, stmt, body, frame.get(stmt.cond)) + + with interp_.new_frame(stmt, has_parent_access=True) as body_frame: + ret = interp_.frame_call_region( + body_frame, stmt, body, frame.get(stmt.cond) + ) + frame.entries.update(body_frame.entries) + return ret diff --git a/test/analysis/dataflow/typeinfer/test_inter_method.py b/test/analysis/dataflow/typeinfer/test_inter_method.py index 8745dd1fa..cfa1a48c8 100644 --- a/test/analysis/dataflow/typeinfer/test_inter_method.py +++ b/test/analysis/dataflow/typeinfer/test_inter_method.py @@ -10,12 +10,12 @@ def foo(x: int): return x - 1.0 -@basic(typeinfer=True) +@basic(typeinfer=True, no_raise=False) def main(x: int): return foo(x) -@basic(typeinfer=True) +@basic(typeinfer=True, no_raise=False) def moo(x): return foo(x) @@ -28,3 +28,18 @@ def test_inter_method_infer(): assert foo.arg_types[0] == types.Int assert foo.inferred is False assert foo.return_type is types.Any + + +def test_infer_if_return(): + from kirin.prelude import structural + + @structural(typeinfer=True, fold=True, no_raise=False) + def test(b: bool): + if b: + return False + else: + b = not b + + return b + + test.print()