Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
ccfd904
Add deep rewrite for mutual recursion support
nielstron Mar 6, 2026
2cab33c
Fix closure binding regression for mutual recursion and methods
nielstron Mar 6, 2026
6233f49
Limit transitive closure propagation to function deps
nielstron Mar 6, 2026
6e584e8
Add test cases for union expansion
nielstron Mar 6, 2026
8c3bfd0
Fix union-expansion call rewriting for mutual recursion
nielstron Mar 6, 2026
f81e6cb
Refactor union expansion call specialization for recursion
nielstron Mar 6, 2026
237ef83
Formatting
nielstron Mar 6, 2026
8c963e2
Fix
nielstron Mar 6, 2026
2b4f6ec
test: add e2e regressions for union expansion mutual recursion
nielstron Mar 6, 2026
c1498a0
Remove non-e2e-test
nielstron Mar 6, 2026
6b5e421
Cleanup
nielstron Mar 6, 2026
d346064
Fix testcase to expose actual compiler bug
nielstron Mar 6, 2026
f7b9ea2
Patch union expansion cleanly
nielstron Mar 6, 2026
eb305aa
Merge branch 'feat/mutual-recursion' of github.com:OpShin/opshin into…
nielstron Mar 6, 2026
6bd6ea5
Fix union expansion properly
nielstron Mar 6, 2026
961e00b
Cleanup
nielstron Mar 6, 2026
2fe6c72
Fix
nielstron Mar 6, 2026
2fedcd8
Add failing test
nielstron Mar 6, 2026
bb9b8ec
Patch
nielstron Mar 7, 2026
d60340c
Patch
nielstron Mar 7, 2026
1a06e3e
Patch
nielstron Mar 7, 2026
371b696
Fix mutual recursion closure binding
nielstron Mar 7, 2026
b14e2ce
Cap UPLC reprs in tests
nielstron Mar 7, 2026
fc45854
Fix retyping for local seq
nielstron Mar 8, 2026
9ce87e8
Split optimization rewrites
nielstron Mar 8, 2026
132abdb
Split out the function dep resolution
nielstron Mar 8, 2026
d51ed19
Leverage specialized function ids
nielstron Mar 8, 2026
6c49053
Formatting
nielstron Mar 8, 2026
7406c16
Merge branch 'dev' into feat/mutual-recursion
nielstron Mar 8, 2026
495d7ca
Extend recursion tests
nielstron Mar 8, 2026
0bd3c8f
Another deep rewrite
nielstron Mar 8, 2026
9e44072
Correctly handle (dunder) methods
nielstron Mar 8, 2026
ea074e1
Add some comments
nielstron Mar 8, 2026
0375be8
Patch the patch
nielstron Mar 9, 2026
4ee6a58
Simplify
nielstron Mar 9, 2026
21ad127
Merge remote-tracking branch 'origin/dev' into feat/mutual-recursion
nielstron Mar 9, 2026
b298a6a
Solidify changes
nielstron Mar 9, 2026
cadf4bc
Fold more features together
nielstron Mar 9, 2026
f4064f4
Remove redundant step
nielstron Mar 9, 2026
3527630
Simplify union expansion
nielstron Mar 9, 2026
8012b07
Increase Plutus VM Hypothesis deadline to 4 seconds
nielstron Mar 9, 2026
7ea5c83
Fix function type resolution for recursive closure rewrites
nielstron Mar 9, 2026
ad97c55
Rewrite function instance types by id after closure transforms
nielstron Mar 9, 2026
e3e0184
Defer function closure metadata and validate rewrite consistency
nielstron Mar 9, 2026
2f12dbc
Rename
nielstron Mar 9, 2026
11a7432
Fix Python 3.9 optional type annotation
nielstron Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@

from .compiler_config import DEFAULT_CONFIG
from .optimize.optimize_const_folding import OptimizeConstantFolding
from .rewrite.rewrite_expanded_union_calls import (
RewriteExpandedUnionCalls,
)
from .rewrite.rewrite_function_closures import (
RewriteFunctionClosures,
)
from .optimize.optimize_remove_deadconstants import OptimizeRemoveDeadConstants
from .optimize.optimize_remove_deadconds import OptimizeRemoveDeadConditions
from .optimize.optimize_fold_if_fallthrough import OptimizeFoldIfFallthrough
Expand Down Expand Up @@ -1234,6 +1240,8 @@ def compile(
RewriteAnnotateFallthrough(),
# The type inference needs to be run after complex python operations were rewritten
AggressiveTypeInferencer(config.allow_isinstance_anything),
(RewriteExpandedUnionCalls() if config.expand_union_types else NoOp()),
RewriteFunctionClosures(),
# Rewrites that circumvent the type inference or use its results
OptimizeFoldBoolCast(),
RewriteAssertNone(),
Expand Down
66 changes: 21 additions & 45 deletions opshin/optimize/optimize_fold_if_fallthrough.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
from ast import *
from copy import copy

from ..util import CompilingNodeTransformer
from ..typed_util import (
ScopedSequenceNodeTransformer,
annotate_compound_statement_fallthrough,
)

"""
If exactly one branch of an if-statement can fall through, fold the following
statements in the enclosing sequence into that branch.
"""


def sequence_can_fall_through(statements):
for stmt in statements:
if not getattr(stmt, "can_fall_through", True):
return False
return True


class OptimizeFoldIfFallthrough(CompilingNodeTransformer):
class OptimizeFoldIfFallthrough(ScopedSequenceNodeTransformer):
step = "Folding trailing statements into sole fallthrough if-branches"

def fold_sequence(self, statements):
folded = []
i = 0
while i < len(statements):
if statements[i] is None:
i += 1
continue
stmt = self.visit(statements[i])
if stmt is None:
i += 1
Expand All @@ -38,68 +37,45 @@ def fold_sequence(self, statements):
stmt.body = self.fold_sequence(stmt.body + trailing)
else:
stmt.orelse = self.fold_sequence(stmt.orelse + trailing)
stmt.body_can_fall_through = sequence_can_fall_through(stmt.body)
stmt.orelse_can_fall_through = sequence_can_fall_through(
stmt.orelse
)
stmt.can_fall_through = (
stmt.body_can_fall_through or stmt.orelse_can_fall_through
)
folded.append(stmt)
folded.append(annotate_compound_statement_fallthrough(stmt))
break
folded.append(stmt)
i += 1
return folded

def visit_Module(self, node: Module) -> Module:
node_cp = copy(node)
node_cp.body = self.fold_sequence(node.body)
node_cp.can_fall_through = sequence_can_fall_through(node_cp.body)
return node_cp
node_cp = super().visit_Module(node)
node_cp.body = self.fold_sequence(node_cp.body)
return annotate_compound_statement_fallthrough(node_cp)

def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
node_cp = copy(node)
node_cp.body = self.fold_sequence(node.body)
node_cp.body_can_fall_through = sequence_can_fall_through(node_cp.body)
node_cp.can_fall_through = True
return node_cp
node_cp = super().visit_FunctionDef(node)
node_cp.body = self.fold_sequence(node_cp.body)
return annotate_compound_statement_fallthrough(node_cp)

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
node_cp = copy(node)
node_cp.body = self.fold_sequence(node.body)
node_cp.body_can_fall_through = sequence_can_fall_through(node_cp.body)
node_cp.can_fall_through = True
return node_cp
node_cp = super().visit_ClassDef(node)
node_cp.body = self.fold_sequence(node_cp.body)
return annotate_compound_statement_fallthrough(node_cp)

def visit_If(self, node: If) -> If:
node_cp = copy(node)
node_cp.test = self.visit(node.test)
node_cp.body = self.fold_sequence(node.body)
node_cp.orelse = self.fold_sequence(node.orelse)
node_cp.body_can_fall_through = sequence_can_fall_through(node_cp.body)
node_cp.orelse_can_fall_through = sequence_can_fall_through(node_cp.orelse)
node_cp.can_fall_through = (
node_cp.body_can_fall_through or node_cp.orelse_can_fall_through
)
return node_cp
return annotate_compound_statement_fallthrough(node_cp)

def visit_While(self, node: While) -> While:
node_cp = copy(node)
node_cp.test = self.visit(node.test)
node_cp.body = self.fold_sequence(node.body)
node_cp.orelse = self.fold_sequence(node.orelse)
node_cp.body_can_fall_through = sequence_can_fall_through(node_cp.body)
node_cp.orelse_can_fall_through = sequence_can_fall_through(node_cp.orelse)
node_cp.can_fall_through = node_cp.orelse_can_fall_through
return node_cp
return annotate_compound_statement_fallthrough(node_cp)

def visit_For(self, node: For) -> For:
node_cp = copy(node)
node_cp.target = self.visit(node.target)
node_cp.iter = self.visit(node.iter)
node_cp.body = self.fold_sequence(node.body)
node_cp.orelse = self.fold_sequence(node.orelse)
node_cp.body_can_fall_through = sequence_can_fall_through(node_cp.body)
node_cp.orelse_can_fall_through = sequence_can_fall_through(node_cp.orelse)
node_cp.can_fall_through = node_cp.orelse_can_fall_through
return node_cp
return annotate_compound_statement_fallthrough(node_cp)
84 changes: 25 additions & 59 deletions opshin/optimize/optimize_remove_deadconds.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,53 @@
from ast import *
from copy import deepcopy, copy
from copy import copy
from typing import Any, Union

from ..util import CompilingNodeTransformer
from ..typed_util import FlatteningScopedSequenceNodeTransformer

"""
Removes if/while branches that are never executed
"""


class OptimizeRemoveDeadConditions(CompilingNodeTransformer):
def visit_FunctionDef(self, node: FunctionDef) -> Any:
node = copy(node)
node.body = self.visit_sequence(node.body)
return node
class OptimizeRemoveDeadConditions(FlatteningScopedSequenceNodeTransformer):
def expression_guaranteed_tf(self, expr: expr) -> Union[bool, None]:
"""
Returns True if the expression is guaranteed to be truthy.
Returns False if the expression is guaranteed to be falsy.
Returns None if it cannot be determined.

def visit_sequence(self, stmts):
new_stmts = []
for stmt in stmts:
s = self.visit(stmt)
if s is None:
continue
if isinstance(s, list):
new_stmts.extend(s)
else:
new_stmts.append(s)
return new_stmts
Needs to be run after self.visit has been called on expr.
"""
if isinstance(expr, Constant):
return bool(expr.value)
return None

def visit_If(self, node: If) -> Any:
node = copy(node)
node.test = self.visit(node.test)
node.body = self.visit_sequence(node.body)
node.orelse = self.visit_sequence(node.orelse)
if isinstance(node.test, Constant):
if node.test.value:
return node.body
else:
return node.orelse
test_value = self.expression_guaranteed_tf(node.test)
if test_value is True:
return node.body
if test_value is False:
return node.orelse
return node

def visit_While(self, node: While) -> Any:
node = copy(node)
node.test = self.visit(node.test)
node.body = self.visit_sequence(node.body)
node.orelse = self.visit_sequence(node.orelse)
if isinstance(node.test, Constant):
if node.test.value:
raise ValueError(
"While loop with constant True condition is not allowed (infinite loop)"
)
else:
return node.orelse
return node

def visit_IfExp(self, node: IfExp) -> Any:
node = copy(node)
node.test = self.visit(node.test)
node.body = self.visit(node.body)
node.orelse = self.visit(node.orelse)

# Simplify if the test condition is a constant
if isinstance(node.test, Constant):
if node.test.value:
return node.body
else:
return node.orelse

test_value = self.expression_guaranteed_tf(node.test)
if test_value is True:
raise ValueError(
"While loop with constant True condition is not allowed (infinite loop)"
)
if test_value is False:
return node.orelse
return node

# Expression simplification logic

def expression_guaranteed_tf(self, expr: expr) -> Union[bool, None]:
"""
Returns True if the expression is guaranteed to be truthy
Returns False if the expression is guaranteed to be falsy
Returns None if it cannot be determined

Needs to be run after self.visit has been called on expr
"""
if isinstance(expr, Constant):
return expr.value
return None

def visit_IfExp(self, node: IfExp) -> expr:
ex = copy(node)
ex.test = self.visit(ex.test)
Expand Down
38 changes: 12 additions & 26 deletions opshin/optimize/optimize_remove_unreachable.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,48 @@
from ast import *
from copy import copy

from ..util import CompilingNodeTransformer
from ..typed_util import ScopedSequenceNodeTransformer

"""
Removes statements that are unreachable because a previous statement in the same
sequence is known not to fall through.
"""


class OptimizeRemoveUnreachable(CompilingNodeTransformer):
class OptimizeRemoveUnreachable(ScopedSequenceNodeTransformer):
step = "Removing unreachable statements"

@staticmethod
def visit_sequence(statements, visitor):
def visit_sequence(self, statements):
visited = []
for stmt in statements:
stmt_cp = visitor.visit(stmt)
if stmt is None:
continue
stmt_cp = self.visit(stmt)
if stmt_cp is None:
continue
visited.append(stmt_cp)
if not getattr(stmt_cp, "can_fall_through", True):
break
return visited

def visit_Module(self, node: Module) -> Module:
node_cp = copy(node)
node_cp.body = self.visit_sequence(node.body, self)
return node_cp

def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
node_cp = copy(node)
node_cp.body = self.visit_sequence(node.body, self)
return node_cp

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
node_cp = copy(node)
node_cp.body = self.visit_sequence(node.body, self)
return node_cp

def visit_If(self, node: If) -> If:
node_cp = copy(node)
node_cp.test = self.visit(node.test)
node_cp.body = self.visit_sequence(node.body, self)
node_cp.orelse = self.visit_sequence(node.orelse, self)
node_cp.body = self.visit_sequence(node.body)
node_cp.orelse = self.visit_sequence(node.orelse)
return node_cp

def visit_While(self, node: While) -> While:
node_cp = copy(node)
node_cp.test = self.visit(node.test)
node_cp.body = self.visit_sequence(node.body, self)
node_cp.orelse = self.visit_sequence(node.orelse, self)
node_cp.body = self.visit_sequence(node.body)
node_cp.orelse = self.visit_sequence(node.orelse)
return node_cp

def visit_For(self, node: For) -> For:
node_cp = copy(node)
node_cp.target = self.visit(node.target)
node_cp.iter = self.visit(node.iter)
node_cp.body = self.visit_sequence(node.body, self)
node_cp.orelse = self.visit_sequence(node.orelse, self)
node_cp.body = self.visit_sequence(node.body)
node_cp.orelse = self.visit_sequence(node.orelse)
return node_cp
Loading