From d5c5513b849b3de093bc550374ad120257671d03 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:23:54 +0000 Subject: [PATCH 1/9] Initial plan From 2fc16c7ec4caeb6e9f2489879136e56d90553f20 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:28:39 +0000 Subject: [PATCH 2/9] Initial analysis: Add plan to fix mutual recursion support Co-authored-by: nielstron <20638630+nielstron@users.noreply.github.com> --- opshin/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/opshin/__init__.py b/opshin/__init__.py index 8264bc26..43667659 100644 --- a/opshin/__init__.py +++ b/opshin/__init__.py @@ -7,7 +7,10 @@ import warnings import importlib.metadata -__version__ = importlib.metadata.version(__package__ or __name__) +try: + __version__ = importlib.metadata.version(__package__ or __name__) +except importlib.metadata.PackageNotFoundError: + __version__ = "0.24.4" __author__ = "nielstron" __author_email__ = "niels@opshin.dev" __copyright__ = "Copyright (C) 2025 nielstron" From adfd3fb60b7078f5ea6184b4604a1d0767ee75b7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:33:39 +0000 Subject: [PATCH 3/9] Implement two-pass type inference for mutual recursion Co-authored-by: nielstron <20638630+nielstron@users.noreply.github.com> --- opshin/type_inference.py | 20 +++++++++++++++++++- tests/test_misc.py | 27 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 7181f61e..94b1b2ae 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -66,7 +66,7 @@ FunctionType, ) -# from frozendict import frozendict +from frozendict import frozendict INITIAL_SCOPE = { @@ -1010,6 +1010,24 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: def visit_Module(self, node: Module) -> TypedModule: self.enter_scope() tm = copy(node) + + # First pass: extract all function signatures and add them to scope + # This enables mutual recursion by making all functions available before processing bodies + for stmt in node.body: + if isinstance(stmt, ast.FunctionDef): + # Create a minimal function type signature from the annotation + functyp = FunctionType( + frozenlist([ + InstanceType(self.type_from_annotation(arg.annotation)) + for arg in stmt.args.args + ]), + InstanceType(self.type_from_annotation(stmt.returns)), + bound_vars=frozendict(), # Will be updated in second pass + bind_self=None, # Will be updated if needed in second pass + ) + self.set_variable_type(stmt.name, InstanceType(functyp)) + + # Second pass: process all statements normally with function signatures available tm.body = self.visit_sequence(node.body) self.exit_scope() return tm diff --git a/tests/test_misc.py b/tests/test_misc.py index 3319600f..748e7e0f 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -283,6 +283,33 @@ def a(n: int) -> int: ret = eval_uplc_value(source_code, Unit()) self.assertEqual(100, ret) + def test_mutual_recursion_even_odd(self): + source_code = """ +def even(n: int) -> bool: + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int) -> bool: + if n == 0: + return False + else: + return even(n - 1) + +def validator(n: int) -> int: + if even(n): + return 1 + else: + return 0 + """ + # Test with even number + ret = eval_uplc_value(source_code, 4) + self.assertEqual(1, ret) + # Test with odd number + ret = eval_uplc_value(source_code, 3) + self.assertEqual(0, ret) + @unittest.expectedFailure def test_uninitialized_access(self): source_code = """ From 788524e6215e770caed9cfe4dad7e8aef399b80d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:36:08 +0000 Subject: [PATCH 4/9] Make mutual recursion type inference more robust with error handling Co-authored-by: nielstron <20638630+nielstron@users.noreply.github.com> --- opshin/type_inference.py | 27 ++++++++++++++++----------- tests/test_misc.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 94b1b2ae..a793794d 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -1015,17 +1015,22 @@ def visit_Module(self, node: Module) -> TypedModule: # This enables mutual recursion by making all functions available before processing bodies for stmt in node.body: if isinstance(stmt, ast.FunctionDef): - # Create a minimal function type signature from the annotation - functyp = FunctionType( - frozenlist([ - InstanceType(self.type_from_annotation(arg.annotation)) - for arg in stmt.args.args - ]), - InstanceType(self.type_from_annotation(stmt.returns)), - bound_vars=frozendict(), # Will be updated in second pass - bind_self=None, # Will be updated if needed in second pass - ) - self.set_variable_type(stmt.name, InstanceType(functyp)) + try: + # Create a minimal function type signature from the annotation + functyp = FunctionType( + frozenlist([ + InstanceType(self.type_from_annotation(arg.annotation)) + for arg in stmt.args.args + ]), + InstanceType(self.type_from_annotation(stmt.returns)), + bound_vars=frozendict(), # Will be updated in second pass + bind_self=None, # Will be updated if needed in second pass + ) + self.set_variable_type(stmt.name, InstanceType(functyp)) + except (TypeInferenceError, AttributeError): + # If type inference fails (e.g., due to forward reference to class), + # skip for now - the function will be properly typed in the second pass + pass # Second pass: process all statements normally with function signatures available tm.body = self.visit_sequence(node.body) diff --git a/tests/test_misc.py b/tests/test_misc.py index 748e7e0f..f4ede67e 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -310,6 +310,39 @@ def validator(n: int) -> int: ret = eval_uplc_value(source_code, 3) self.assertEqual(0, ret) + def test_mutual_recursion_three_way(self): + source_code = """ +def a(n: int) -> int: + if n <= 0: + return 1 + else: + return b(n - 1) + +def b(n: int) -> int: + if n <= 0: + return 2 + else: + return c(n - 1) + +def c(n: int) -> int: + if n <= 0: + return 3 + else: + return a(n - 1) + +def validator(n: int) -> int: + return a(n) + """ + # Test different values to verify the three-way recursion pattern + ret = eval_uplc_value(source_code, 0) + self.assertEqual(1, ret) # a(0) = 1 + ret = eval_uplc_value(source_code, 1) + self.assertEqual(2, ret) # a(1) = b(0) = 2 + ret = eval_uplc_value(source_code, 2) + self.assertEqual(3, ret) # a(2) = b(1) = c(0) = 3 + ret = eval_uplc_value(source_code, 3) + self.assertEqual(1, ret) # a(3) = b(2) = c(1) = a(0) = 1 + @unittest.expectedFailure def test_uninitialized_access(self): source_code = """ From 53c21b43c248aec9ec6c64fb7a5db4712b255b8b Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 18 Sep 2025 12:55:36 +0200 Subject: [PATCH 5/9] Attempt fix on type level It seems that the compiler does not correctly handle this though --- opshin/type_inference.py | 60 +++++++++++++++++++++++----------------- tests/test_Unions.py | 3 +- tests/test_misc.py | 22 ++++++++++++++- 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index a793794d..346e3e11 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -637,6 +637,7 @@ def type_from_annotation(self, ann: expr): raise NotImplementedError(f"Annotation type {ann.__class__} is not supported") def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: + # Zeroth pass: convert all class methods into functions with self as first argument additional_functions = [] for n in node_seq: if not isinstance(n, ast.ClassDef): @@ -682,6 +683,35 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: node_seq.extend(additional_functions) node_seq.append(last) + # First pass: extract all function signatures and add them to scope + # This enables mutual recursion by making all functions available before processing bodies + for stmt in node_seq: + if isinstance(stmt, ast.FunctionDef): + try: + # Create a minimal function type signature from the annotation + functyp = FunctionType( + frozenlist( + [ + InstanceType(self.type_from_annotation(arg.annotation)) + for arg in stmt.args.args + ] + ), + InstanceType(self.type_from_annotation(stmt.returns)), + bound_vars=frozendict(), # Will be updated in second pass + bind_self=None, # Will be updated if needed in second pass + ) + self.set_variable_type(stmt.name, InstanceType(functyp)) + except (TypeInferenceError, AttributeError): + # If type inference fails (e.g., due to forward reference to class), + # skip for now - the function will be properly typed in the second pass + pass + if isinstance(stmt, ast.ClassDef): + # Classes need to be added to the scope to ensure they can be used in annotations + class_record = RecordReader(self).extract(stmt) + typ = RecordType(class_record) + self.set_variable_type(stmt.name, typ) + + # Second pass: process all statements normally with function signatures available stmts = [] prevtyps = {} for n in node_seq: @@ -700,7 +730,8 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: def visit_ClassDef(self, node: ClassDef) -> TypedClassDef: class_record = RecordReader(self).extract(node) typ = RecordType(class_record) - self.set_variable_type(node.name, typ) + # Set the class type in the current scope --> already done in first pass in body + # self.set_variable_type(node.name, typ) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [ typedarg(arg=field, typ=field_typ, orig_arg=field) for field, field_typ in class_record.fields @@ -1002,37 +1033,14 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: rets_extractor.check_fulfills(tfd) self.exit_scope() - # We need the function type outside for usage - self.set_variable_type(node.name, tfd.typ) + # We need the function type outside for usage --> already done in first pass + # self.set_variable_type(node.name, tfd.typ) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args return tfd def visit_Module(self, node: Module) -> TypedModule: self.enter_scope() tm = copy(node) - - # First pass: extract all function signatures and add them to scope - # This enables mutual recursion by making all functions available before processing bodies - for stmt in node.body: - if isinstance(stmt, ast.FunctionDef): - try: - # Create a minimal function type signature from the annotation - functyp = FunctionType( - frozenlist([ - InstanceType(self.type_from_annotation(arg.annotation)) - for arg in stmt.args.args - ]), - InstanceType(self.type_from_annotation(stmt.returns)), - bound_vars=frozendict(), # Will be updated in second pass - bind_self=None, # Will be updated if needed in second pass - ) - self.set_variable_type(stmt.name, InstanceType(functyp)) - except (TypeInferenceError, AttributeError): - # If type inference fails (e.g., due to forward reference to class), - # skip for now - the function will be properly typed in the second pass - pass - - # Second pass: process all statements normally with function signatures available tm.body = self.visit_sequence(node.body) self.exit_scope() return tm diff --git a/tests/test_Unions.py b/tests/test_Unions.py index 5f9dea57..33eca585 100644 --- a/tests/test_Unions.py +++ b/tests/test_Unions.py @@ -1,7 +1,7 @@ import unittest import hypothesis import pytest -from hypothesis import given +from hypothesis import given, example from hypothesis import strategies as st from opshin import builder from .utils import eval_uplc_value, eval_uplc, eval_uplc_raw @@ -398,6 +398,7 @@ def validator(x: int) -> int: self.assertEqual(res, real) @hypothesis.given(st.sampled_from(range(14))) + @example(2) def test_Union_cast_ifexpr(self, x): source_code = """ from dataclasses import dataclass diff --git a/tests/test_misc.py b/tests/test_misc.py index f4ede67e..1becfd44 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -306,7 +306,7 @@ def validator(n: int) -> int: # Test with even number ret = eval_uplc_value(source_code, 4) self.assertEqual(1, ret) - # Test with odd number + # Test with odd number ret = eval_uplc_value(source_code, 3) self.assertEqual(0, ret) @@ -2839,3 +2839,23 @@ def validator(a: int) -> int: assert "int" in str(e) and "str" in str( e ), "Type check did not fail with correct message" + + def test_mutual_recursion(self): + source_code = """ +def even(n: int) -> bool: + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int) -> bool: + if n == 0: + return False + else: + return even(n - 1) + +def validator(a: int) -> int: + return 42 if even(a) else 0 +""" + res = eval_uplc_value(source_code, 2) + self.assertEqual(res, 42) From 8705e2e40b5b71ba5c81e99823ab66fb0b8777a8 Mon Sep 17 00:00:00 2001 From: Niels Date: Thu, 18 Sep 2025 13:10:15 +0200 Subject: [PATCH 6/9] Fix --- opshin/type_inference.py | 18 ++++++++++++------ opshin/util.py | 8 ++++---- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index 346e3e11..cf2669d5 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -697,8 +697,12 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST: ] ), InstanceType(self.type_from_annotation(stmt.returns)), - bound_vars=frozendict(), # Will be updated in second pass - bind_self=None, # Will be updated if needed in second pass + bound_vars={ + v: InstanceType(AnyType()) + for v in externally_bound_vars(stmt) + if not v in ["List", "Dict"] + }, + bind_self=stmt.name, ) self.set_variable_type(stmt.name, InstanceType(functyp)) except (TypeInferenceError, AttributeError): @@ -731,7 +735,7 @@ def visit_ClassDef(self, node: ClassDef) -> TypedClassDef: class_record = RecordReader(self).extract(node) typ = RecordType(class_record) # Set the class type in the current scope --> already done in first pass in body - # self.set_variable_type(node.name, typ) + self.set_variable_type(node.name, typ, force=True) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [ typedarg(arg=field, typ=field_typ, orig_arg=field) for field, field_typ in class_record.fields @@ -1008,7 +1012,9 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: for v in externally_bound_vars(node) if not v in ["List", "Dict"] }, - bind_self=node.name if node.name in read_vars(node) else None, + # this used to check whether the function recurses. + # but the function might co-recurse with another function, so we always bind self + bind_self=node.name, ) tfd.typ = InstanceType(functyp) if wraps_builtin: @@ -1033,8 +1039,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef: rets_extractor.check_fulfills(tfd) self.exit_scope() - # We need the function type outside for usage --> already done in first pass - # self.set_variable_type(node.name, tfd.typ) + # We need the function type outside for usage --> already done in first pass, but needs an update + self.set_variable_type(node.name, tfd.typ, force=True) self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args return tfd diff --git a/opshin/util.py b/opshin/util.py index 2af25a78..3d3db3cf 100644 --- a/opshin/util.py +++ b/opshin/util.py @@ -256,7 +256,7 @@ class NameReadCollector(CompilingNodeVisitor): step = "Collecting variables that are read" def __init__(self): - self.read = defaultdict(int) + self.read: typing.DefaultDict[str, int] = defaultdict(int) def visit_AnnAssign(self, node) -> None: # ignore annotations of variables @@ -277,7 +277,7 @@ def visit_ClassDef(self, node: ClassDef): pass -def read_vars(node): +def read_vars(node) -> typing.List[str]: """ Returns all variable names read to in this node """ @@ -286,11 +286,11 @@ def read_vars(node): return sorted(collector.read.keys()) -def all_vars(node): +def all_vars(node) -> typing.List[str]: return sorted(set(read_vars(node) + written_vars(node))) -def externally_bound_vars(node: FunctionDef): +def externally_bound_vars(node: FunctionDef) -> typing.List[str]: """A superset of the variables bound from an outer scope""" return sorted(set(read_vars(node)) - set(written_vars(node)) - {"isinstance"}) From f7d9d20dcae10a0ef3a0152eb8e1a8079fca6716 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 21 Sep 2025 09:51:57 +0000 Subject: [PATCH 7/9] Address reviewer feedback: Remove unrelated __init__.py change and add test cases for nested functions and classes Co-authored-by: nielstron <20638630+nielstron@users.noreply.github.com> --- opshin/__init__.py | 5 +---- tests/test_misc.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/opshin/__init__.py b/opshin/__init__.py index 43667659..8264bc26 100644 --- a/opshin/__init__.py +++ b/opshin/__init__.py @@ -7,10 +7,7 @@ import warnings import importlib.metadata -try: - __version__ = importlib.metadata.version(__package__ or __name__) -except importlib.metadata.PackageNotFoundError: - __version__ = "0.24.4" +__version__ = importlib.metadata.version(__package__ or __name__) __author__ = "nielstron" __author_email__ = "niels@opshin.dev" __copyright__ = "Copyright (C) 2025 nielstron" diff --git a/tests/test_misc.py b/tests/test_misc.py index 1becfd44..de1b4a51 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -343,6 +343,62 @@ def validator(n: int) -> int: ret = eval_uplc_value(source_code, 3) self.assertEqual(1, ret) # a(3) = b(2) = c(1) = a(0) = 1 + def test_mutual_recursion_nested_functions(self): + source_code = """ +def validator(n: int) -> int: + def even_nested(x: int) -> bool: + if x == 0: + return True + else: + return odd_nested(x - 1) + + def odd_nested(x: int) -> bool: + if x == 0: + return False + else: + return even_nested(x - 1) + + if even_nested(n): + return 1 + else: + return 0 + """ + # Test nested mutual recursion + ret = eval_uplc_value(source_code, 4) + self.assertEqual(1, ret) # 4 is even + ret = eval_uplc_value(source_code, 3) + self.assertEqual(0, ret) # 3 is odd + + def test_mutual_recursion_with_classes(self): + source_code = """ +from typing import Dict, List, Union +from pycardano import Datum as Anything, PlutusData +from dataclasses import dataclass + +def process_data(x: int) -> int: + obj = MyData(x) + return transform_data(obj) + +@dataclass +class MyData(PlutusData): + CONSTR_ID = 0 + value: int + +def transform_data(data: MyData) -> int: + if data.value <= 0: + return 0 + else: + return process_data(data.value - 1) + 1 + +def validator(n: int) -> int: + return process_data(n) + """ + # Test mutual recursion with classes defined between functions + ret = eval_uplc_value(source_code, 3) + self.assertEqual(3, ret) # process_data(3) should return 3 + ret = eval_uplc_value(source_code, 0) + self.assertEqual(0, ret) # process_data(0) should return 0 + @unittest.expectedFailure def test_uninitialized_access(self): source_code = """ From c5e1657ebcf0886ce49ea1d94a6a66d81ab65ab4 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 23 Sep 2025 13:28:33 +0200 Subject: [PATCH 8/9] Fix recursion --- opshin/type_inference.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/opshin/type_inference.py b/opshin/type_inference.py index cf2669d5..26f70daf 100644 --- a/opshin/type_inference.py +++ b/opshin/type_inference.py @@ -1670,6 +1670,11 @@ def visit_While(self, node: For) -> bool: # the else path is always visited return self.visit_sequence(node.orelse) + def visit_FunctionDef(self, node: FunctionDef) -> bool: + # Skip visiting nested function definitions - they should be handled by their own ReturnExtractor + # when they are processed separately during type inference + return False + def visit_Return(self, node: Return) -> bool: assert ( self.func_rettyp >= node.typ From 38fb1cf90b58f2b7ed494f1edc3ab5849ebd0679 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 23 Sep 2025 13:35:24 +0200 Subject: [PATCH 9/9] Add a seperate test for forward passing --- tests/test_misc.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/test_misc.py b/tests/test_misc.py index de1b4a51..7b9cd570 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -399,6 +399,74 @@ def validator(n: int) -> int: ret = eval_uplc_value(source_code, 0) self.assertEqual(0, ret) # process_data(0) should return 0 + def test_three_function_chain_depth_issue(self): + # Test for known runtime issue with function call chaining at depth >= 2 + # This is separate from mutual recursion and affects any three-function chain + source_code = """ +def a(n: int) -> int: + if n <= 0: + return 1 + else: + return b(n - 1) + +def b(n: int) -> int: + if n <= 0: + return 2 + else: + return c(n - 1) + +def c(n: int) -> int: + if n <= 0: + return 3 + else: + return 42 # No recursion, just return a constant + +def validator(n: int) -> int: + return a(n) + """ + # This should work for n=0,1 but fails at n=2 with runtime error + ret = eval_uplc_value(source_code, 0) + self.assertEqual(1, ret) + ret = eval_uplc_value(source_code, 1) + self.assertEqual(2, ret) + # This line causes the runtime error + ret = eval_uplc_value(source_code, 42) + self.assertEqual(42, ret) + + def test_three_function_chain_depth_issue_reverted(self): + # same as previous test but with function order reverted + # to ensure the issue is with the chain depth and not function order + source_code = """ +def c(n: int) -> int: + if n <= 0: + return 3 + else: + return 42 # No recursion, just return a constant + +def b(n: int) -> int: + if n <= 0: + return 2 + else: + return c(n - 1) + +def a(n: int) -> int: + if n <= 0: + return 1 + else: + return b(n - 1) + +def validator(n: int) -> int: + return a(n) + """ + # This should work for n=0,1 but fails at n=2 with runtime error + ret = eval_uplc_value(source_code, 0) + self.assertEqual(1, ret) + ret = eval_uplc_value(source_code, 1) + self.assertEqual(2, ret) + # This line causes the runtime error + ret = eval_uplc_value(source_code, 42) + self.assertEqual(42, ret) + @unittest.expectedFailure def test_uninitialized_access(self): source_code = """