Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
FunctionType,
)

# from frozendict import frozendict
from frozendict import frozendict


INITIAL_SCOPE = {
Expand Down Expand Up @@ -637,6 +637,7 @@ def type_from_annotation(self, ann: expr):
raise NotImplementedError(f"Annotation type {ann.__class__} is not supported")

def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST:
# Zeroth pass: convert all class methods into functions with self as first argument
additional_functions = []
for n in node_seq:
if not isinstance(n, ast.ClassDef):
Expand Down Expand Up @@ -682,6 +683,39 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST:
node_seq.extend(additional_functions)
node_seq.append(last)

# First pass: extract all function signatures and add them to scope
# This enables mutual recursion by making all functions available before processing bodies
for stmt in node_seq:
if isinstance(stmt, ast.FunctionDef):
try:
# Create a minimal function type signature from the annotation
functyp = FunctionType(
frozenlist(
[
InstanceType(self.type_from_annotation(arg.annotation))
for arg in stmt.args.args
]
),
InstanceType(self.type_from_annotation(stmt.returns)),
bound_vars={
v: InstanceType(AnyType())
for v in externally_bound_vars(stmt)
if not v in ["List", "Dict"]
},
bind_self=stmt.name,
)
self.set_variable_type(stmt.name, InstanceType(functyp))
except (TypeInferenceError, AttributeError):
# If type inference fails (e.g., due to forward reference to class),
# skip for now - the function will be properly typed in the second pass
pass
if isinstance(stmt, ast.ClassDef):
# Classes need to be added to the scope to ensure they can be used in annotations
class_record = RecordReader(self).extract(stmt)
typ = RecordType(class_record)
self.set_variable_type(stmt.name, typ)

# Second pass: process all statements normally with function signatures available
stmts = []
prevtyps = {}
for n in node_seq:
Expand All @@ -700,7 +734,8 @@ def visit_sequence(self, node_seq: typing.List[stmt]) -> plt.AST:
def visit_ClassDef(self, node: ClassDef) -> TypedClassDef:
class_record = RecordReader(self).extract(node)
typ = RecordType(class_record)
self.set_variable_type(node.name, typ)
# Set the class type in the current scope; the first pass in visit_sequence already registered it, hence force=True
self.set_variable_type(node.name, typ, force=True)
self.FUNCTION_ARGUMENT_REGISTRY[node.name] = [
typedarg(arg=field, typ=field_typ, orig_arg=field)
for field, field_typ in class_record.fields
Expand Down Expand Up @@ -977,7 +1012,9 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
for v in externally_bound_vars(node)
if not v in ["List", "Dict"]
},
bind_self=node.name if node.name in read_vars(node) else None,
# This used to bind self only when the function reads its own name (i.e. recurses directly).
# But the function might co-recurse with another function, so we always bind self.
bind_self=node.name,
)
tfd.typ = InstanceType(functyp)
if wraps_builtin:
Expand All @@ -1002,8 +1039,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> TypedFunctionDef:
rets_extractor.check_fulfills(tfd)

self.exit_scope()
# We need the function type outside for usage
self.set_variable_type(node.name, tfd.typ)
# The function type must be available in the outer scope; the first pass already registered a preliminary signature, which is updated here
self.set_variable_type(node.name, tfd.typ, force=True)
self.FUNCTION_ARGUMENT_REGISTRY[node.name] = node.args.args
return tfd

Expand Down Expand Up @@ -1633,6 +1670,11 @@ def visit_While(self, node: For) -> bool:
# the else path is always visited
return self.visit_sequence(node.orelse)

def visit_FunctionDef(self, node: FunctionDef) -> bool:
# Skip visiting nested function definitions - they should be handled by their own ReturnExtractor
# when they are processed separately during type inference
return False

def visit_Return(self, node: Return) -> bool:
assert (
self.func_rettyp >= node.typ
Expand Down
8 changes: 4 additions & 4 deletions opshin/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ class NameReadCollector(CompilingNodeVisitor):
step = "Collecting variables that are read"

def __init__(self):
self.read = defaultdict(int)
self.read: typing.DefaultDict[str, int] = defaultdict(int)

def visit_AnnAssign(self, node) -> None:
# ignore annotations of variables
Expand All @@ -277,7 +277,7 @@ def visit_ClassDef(self, node: ClassDef):
pass


def read_vars(node):
def read_vars(node) -> typing.List[str]:
"""
Returns all variable names read in this node
"""
Expand All @@ -286,11 +286,11 @@ def read_vars(node):
return sorted(collector.read.keys())


def all_vars(node):
def all_vars(node) -> typing.List[str]:
return sorted(set(read_vars(node) + written_vars(node)))


def externally_bound_vars(node: FunctionDef):
def externally_bound_vars(node: FunctionDef) -> typing.List[str]:
"""A superset of the variables bound from an outer scope"""
return sorted(set(read_vars(node)) - set(written_vars(node)) - {"isinstance"})

Expand Down
3 changes: 2 additions & 1 deletion tests/test_Unions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import hypothesis
import pytest
from hypothesis import given
from hypothesis import given, example
from hypothesis import strategies as st
from opshin import builder
from .utils import eval_uplc_value, eval_uplc, eval_uplc_raw
Expand Down Expand Up @@ -398,6 +398,7 @@ def validator(x: int) -> int:
self.assertEqual(res, real)

@hypothesis.given(st.sampled_from(range(14)))
@example(2)
def test_Union_cast_ifexpr(self, x):
source_code = """
from dataclasses import dataclass
Expand Down
204 changes: 204 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,190 @@ def a(n: int) -> int:
ret = eval_uplc_value(source_code, Unit())
self.assertEqual(100, ret)

def test_mutual_recursion_even_odd(self):
source_code = """
def even(n: int) -> bool:
if n == 0:
return True
else:
return odd(n - 1)

def odd(n: int) -> bool:
if n == 0:
return False
else:
return even(n - 1)

def validator(n: int) -> int:
if even(n):
return 1
else:
return 0
"""
# Test with even number
ret = eval_uplc_value(source_code, 4)
self.assertEqual(1, ret)
# Test with odd number
ret = eval_uplc_value(source_code, 3)
self.assertEqual(0, ret)

def test_mutual_recursion_three_way(self):
source_code = """
def a(n: int) -> int:
if n <= 0:
return 1
else:
return b(n - 1)

def b(n: int) -> int:
if n <= 0:
return 2
else:
return c(n - 1)

def c(n: int) -> int:
if n <= 0:
return 3
else:
return a(n - 1)

def validator(n: int) -> int:
return a(n)
"""
# Test different values to verify the three-way recursion pattern
ret = eval_uplc_value(source_code, 0)
self.assertEqual(1, ret) # a(0) = 1
ret = eval_uplc_value(source_code, 1)
self.assertEqual(2, ret) # a(1) = b(0) = 2
ret = eval_uplc_value(source_code, 2)
self.assertEqual(3, ret) # a(2) = b(1) = c(0) = 3
ret = eval_uplc_value(source_code, 3)
self.assertEqual(1, ret) # a(3) = b(2) = c(1) = a(0) = 1

def test_mutual_recursion_nested_functions(self):
source_code = """
def validator(n: int) -> int:
def even_nested(x: int) -> bool:
if x == 0:
return True
else:
return odd_nested(x - 1)

def odd_nested(x: int) -> bool:
if x == 0:
return False
else:
return even_nested(x - 1)

if even_nested(n):
return 1
else:
return 0
"""
# Test nested mutual recursion
ret = eval_uplc_value(source_code, 4)
self.assertEqual(1, ret) # 4 is even
ret = eval_uplc_value(source_code, 3)
self.assertEqual(0, ret) # 3 is odd

def test_mutual_recursion_with_classes(self):
source_code = """
from typing import Dict, List, Union
from pycardano import Datum as Anything, PlutusData
from dataclasses import dataclass

def process_data(x: int) -> int:
obj = MyData(x)
return transform_data(obj)

@dataclass
class MyData(PlutusData):
CONSTR_ID = 0
value: int

def transform_data(data: MyData) -> int:
if data.value <= 0:
return 0
else:
return process_data(data.value - 1) + 1

def validator(n: int) -> int:
return process_data(n)
"""
# Test mutual recursion with classes defined between functions
ret = eval_uplc_value(source_code, 3)
self.assertEqual(3, ret) # process_data(3) should return 3
ret = eval_uplc_value(source_code, 0)
self.assertEqual(0, ret) # process_data(0) should return 0

def test_three_function_chain_depth_issue(self):
# Test for known runtime issue with function call chaining at depth >= 2
# This is separate from mutual recursion and affects any three-function chain
source_code = """
def a(n: int) -> int:
if n <= 0:
return 1
else:
return b(n - 1)

def b(n: int) -> int:
if n <= 0:
return 2
else:
return c(n - 1)

def c(n: int) -> int:
if n <= 0:
return 3
else:
return 42 # No recursion, just return a constant

def validator(n: int) -> int:
return a(n)
"""
# n=0 and n=1 return from a and b directly and are expected to work; the known runtime error is triggered from n=2 on, when the chain reaches c
ret = eval_uplc_value(source_code, 0)
self.assertEqual(1, ret)
ret = eval_uplc_value(source_code, 1)
self.assertEqual(2, ret)
# n=42 walks the full a -> b -> c chain (the call that triggers the known runtime error); c returns the constant 42
ret = eval_uplc_value(source_code, 42)
self.assertEqual(42, ret)

def test_three_function_chain_depth_issue_reverted(self):
# Same as the previous test but with the function definitions in reversed order,
# to ensure the issue is with the chain depth and not the definition order
source_code = """
def c(n: int) -> int:
if n <= 0:
return 3
else:
return 42 # No recursion, just return a constant

def b(n: int) -> int:
if n <= 0:
return 2
else:
return c(n - 1)

def a(n: int) -> int:
if n <= 0:
return 1
else:
return b(n - 1)

def validator(n: int) -> int:
return a(n)
"""
# n=0 and n=1 return from a and b directly and are expected to work; the known runtime error is triggered from n=2 on, when the chain reaches c
ret = eval_uplc_value(source_code, 0)
self.assertEqual(1, ret)
ret = eval_uplc_value(source_code, 1)
self.assertEqual(2, ret)
# n=42 walks the full a -> b -> c chain (the call that triggers the known runtime error); c returns the constant 42
ret = eval_uplc_value(source_code, 42)
self.assertEqual(42, ret)

@unittest.expectedFailure
def test_uninitialized_access(self):
source_code = """
Expand Down Expand Up @@ -2779,3 +2963,23 @@ def validator(a: int) -> int:
assert "int" in str(e) and "str" in str(
e
), "Type check did not fail with correct message"

def test_mutual_recursion(self):
source_code = """
def even(n: int) -> bool:
if n == 0:
return True
else:
return odd(n - 1)

def odd(n: int) -> bool:
if n == 0:
return False
else:
return even(n - 1)

def validator(a: int) -> int:
return 42 if even(a) else 0
"""
res = eval_uplc_value(source_code, 2)
self.assertEqual(res, 42)
Loading