diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 7181f61e..26f70daf 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -66,7 +66,7 @@ FunctionType, ) -# from frozendict import frozendict +from frozendict import frozendict INITIAL_SCOPE = { @@ -637,6 +637,7 @@ def type_from_annotation(self, ann: expr): raise NotImplementedError(f"Annotation type {ann.__class__} is not supported") def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: + # Zeroth pass: convert all class methods into functions with self as first argument additional_functions = [] for n in node_seq: if not isinstance(n, ast.ClassDef): @@ -682,6 +683,39 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: node_seq.extend(additional_functions) node_seq.append(last) + # First pass: extract all function signatures and add them to scope + # This enables mutual recursion by making all functions available before processing bodies + for stmt in node_seq: + if isinstance(stmt, ast.FunctionDef): + try: + # Create a minimal function type signature from the annotation + functyp = FunctionType( + frozenlist( + [ + InstanceType(self.type_from_annotation(arg.annotation)) + for arg in stmt.args.args + ] + ), + InstanceType(self.type_from_annotation(stmt.returns)), + bound_vars={ + v: InstanceType(AnyType()) + for v in externally_bound_vars(stmt) + if not v in ["List", "Dict"] + }, + bind_self=stmt.name, + ) + self.set_variable_type(stmt.name, InstanceType(functyp)) + except (TypeInferenceError, AttributeError): + # If type inference fails (e.g., due to forward reference to class), + # skip for now - the function will be properly typed in the second pass + pass + if isinstance(stmt, ast.ClassDef): + # Classes need to be added to the scope to ensure they can be used in annotations + class_record = RecordReader(self).extract(stmt) + typ = RecordType(class_record) + self.set_variable_type(stmt.name, typ) + + # Second pass: process all statements normally with function signatures available stmts = [] prevtyps = {} for n in node_seq: @@ -700,7 +734,8 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: def visit_ClassDef(self, node: ClassDef) -> TypedClassDef: class_record = RecordReader(self).extract(node) typ = RecordType(class_record) - self.set_variable_type(node.name, typ) + # Set the class type in the current scope --> already done in first pass in body + self.set_variable_type(node.name, typ, force=True) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [ typedarg(arg=field, typ=field_typ, orig_arg=field) for field, field_typ in class_record.fields @@ -977,7 +1012,9 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: for v in externally_bound_vars(node) if not v in ["List", "Dict"] }, - bind_self=node.name if node.name in read_vars(node) else None, + # this used to check whether the function recurses. + # but the function might co-recurse with another function, so we always bind self + bind_self=node.name, ) tfd.typ = InstanceType(functyp) if wraps_builtin: @@ -1002,8 +1039,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: rets_extractor.check_fulfills(tfd) self.exit_scope() - # We need the function type outside for usage - self.set_variable_type(node.name, tfd.typ) + # We need the function type outside for usage --> already done in first pass, but needs an update + self.set_variable_type(node.name, tfd.typ, force=True) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args return tfd @@ -1633,6 +1670,11 @@ def visit_While(self, node: For) -> bool: # the else path is always visited return self.visit_sequence(node.orelse) + def visit_FunctionDef(self, node: FunctionDef) -> bool: + # Skip visiting nested function definitions - they should be handled by their own ReturnExtractor + # when they are processed separately during type inference + return False + def visit_Return(self, node: Return) -> bool: assert ( self.func_rettyp >= node.typ diff --git a/opshin/util.py b/opshin/util.py index 2af25a78..3d3db3cf 100644 --- a/opshin/util.py +++ b/opshin/util.py @@ -256,7 +256,7 @@ class NameReadCollector(CompilingNodeVisitor): step = "Collecting variables that are read" def __init__(self): - self.read = defaultdict(int) + self.read: typing.DefaultDict[str, int] = defaultdict(int) def visit_AnnAssign(self, node) -> None: # ignore annotations of variables @@ -277,7 +277,7 @@ def visit_ClassDef(self, node: ClassDef): pass -def read_vars(node): +def read_vars(node) -> typing.List[str]: """ Returns all variable names read to in this node """ @@ -286,11 +286,11 @@ def read_vars(node): return sorted(collector.read.keys()) -def all_vars(node): +def all_vars(node) -> typing.List[str]: return sorted(set(read_vars(node) + written_vars(node))) -def externally_bound_vars(node: FunctionDef): +def externally_bound_vars(node: FunctionDef) -> typing.List[str]: """A superset of the variables bound from an outer scope""" return sorted(set(read_vars(node)) - set(written_vars(node)) - {"isinstance"}) diff --git a/tests/test_Unions.py b/tests/test_Unions.py index 5f9dea57..33eca585 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -1,7 +1,7 @@ import unittest import hypothesis import pytest -from hypothesis import given +from hypothesis import given, example from hypothesis import strategies as st from opshin import builder from .utils import eval_uplc_value, eval_uplc, eval_uplc_raw @@ -398,6 +398,7 @@ def validator(x: int) -> int: self.assertEqual(res, real) @hypothesis.given(st.sampled_from(range(14))) + @example(2) def test_Union_cast_ifexpr(self, x): source_code = """ from dataclasses import dataclass diff --git a/tests/test_misc.py b/tests/test_misc.py index 3319600f..7b9cd570 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -283,6 +283,190 @@ def a(n: int) -> int: ret = eval_uplc_value(source_code, Unit()) self.assertEqual(100, ret) + def test_mutual_recursion_even_odd(self): + source_code = """ +def even(n: int) -> bool: + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int) -> bool: + if n == 0: + return False + else: + return even(n - 1) + +def validator(n: int) -> int: + if even(n): + return 1 + else: + return 0 + """ + # Test with even number + ret = eval_uplc_value(source_code, 4) + self.assertEqual(1, ret) + # Test with odd number + ret = eval_uplc_value(source_code, 3) + self.assertEqual(0, ret) + + def test_mutual_recursion_three_way(self): + source_code = """ +def a(n: int) -> int: + if n <= 0: + return 1 + else: + return b(n - 1) + +def b(n: int) -> int: + if n <= 0: + return 2 + else: + return c(n - 1) + +def c(n: int) -> int: + if n <= 0: + return 3 + else: + return a(n - 1) + +def validator(n: int) -> int: + return a(n) + """ + # Test different values to verify the three-way recursion pattern + ret = eval_uplc_value(source_code, 0) + self.assertEqual(1, ret) # a(0) = 1 + ret = eval_uplc_value(source_code, 1) + self.assertEqual(2, ret) # a(1) = b(0) = 2 + ret = eval_uplc_value(source_code, 2) + self.assertEqual(3, ret) # a(2) = b(1) = c(0) = 3 + ret = eval_uplc_value(source_code, 3) + self.assertEqual(1, ret) # a(3) = b(2) = c(1) = a(0) = 1 + + def test_mutual_recursion_nested_functions(self): + source_code = """ +def validator(n: int) -> int: + def even_nested(x: int) -> bool: + if x == 0: + return True + else: + return odd_nested(x - 1) + + def odd_nested(x: int) -> bool: + if x == 0: + return False + else: + return even_nested(x - 1) + + if even_nested(n): + return 1 + else: + return 0 + """ + # Test nested mutual recursion + ret = eval_uplc_value(source_code, 4) + self.assertEqual(1, ret) # 4 is even + ret = eval_uplc_value(source_code, 3) + self.assertEqual(0, ret) # 3 is odd + + def test_mutual_recursion_with_classes(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def process_data(x: int) -> int: + obj = MyData(x) + return transform_data(obj) + +@dataclass +class MyData(PlutusData): + CONSTR_ID = 0 + value: int + +def transform_data(data: MyData) -> int: + if data.value <= 0: + return 0 + else: + return process_data(data.value - 1) + 1 + +def validator(n: int) -> int: + return process_data(n) + """ + # Test mutual recursion with classes defined between functions + ret = eval_uplc_value(source_code, 3) + self.assertEqual(3, ret) # process_data(3) should return 3 + ret = eval_uplc_value(source_code, 0) + self.assertEqual(0, ret) # process_data(0) should return 0 + + def test_three_function_chain_depth_issue(self): + # Test for known runtime issue with function call chaining at depth >= 2 + # This is separate from mutual recursion and affects any three-function chain + source_code = """ +def a(n: int) -> int: + if n <= 0: + return 1 + else: + return b(n - 1) + +def b(n: int) -> int: + if n <= 0: + return 2 + else: + return c(n - 1) + +def c(n: int) -> int: + if n <= 0: + return 3 + else: + return 42 # No recursion, just return a constant + +def validator(n: int) -> int: + return a(n) + """ + # This should work for n=0,1 but fails at n=2 with runtime error + ret = eval_uplc_value(source_code, 0) + self.assertEqual(1, ret) + ret = eval_uplc_value(source_code, 1) + self.assertEqual(2, ret) + # This line causes the runtime error + ret = eval_uplc_value(source_code, 42) + self.assertEqual(42, ret) + + def test_three_function_chain_depth_issue_reverted(self): + # same as previous test but with function order reverted + # to ensure the issue is with the chain depth and not function order + source_code = """ +def c(n: int) -> int: + if n <= 0: + return 3 + else: + return 42 # No recursion, just return a constant + +def b(n: int) -> int: + if n <= 0: + return 2 + else: + return c(n - 1) + +def a(n: int) -> int: + if n <= 0: + return 1 + else: + return b(n - 1) + +def validator(n: int) -> int: + return a(n) + """ + # This should work for n=0,1 but fails at n=2 with runtime error + ret = eval_uplc_value(source_code, 0) + self.assertEqual(1, ret) + ret = eval_uplc_value(source_code, 1) + self.assertEqual(2, ret) + # This line causes the runtime error + ret = eval_uplc_value(source_code, 42) + self.assertEqual(42, ret) + @unittest.expectedFailure def test_uninitialized_access(self): source_code = """ @@ -2779,3 +2963,23 @@ def validator(a: int) -> int: assert "int" in str(e) and "str" in str( e ), "Type check did not fail with correct message" + + def test_mutual_recursion(self): + source_code = """ +def even(n: int) -> bool: + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int) -> bool: + if n == 0: + return False + else: + return even(n - 1) + +def validator(a: int) -> int: + return 42 if even(a) else 0 +""" + res = eval_uplc_value(source_code, 2) + self.assertEqual(res, 42)