Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions stratum/optimizer/_numeric_rewrites.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from stratum.optimizer.ir._numeric_ops import NumericOp, NumericOpType
from stratum.optimizer._op_utils import rewrite_pass
from stratum.optimizer._op_utils import rewrite_pass, replace_op_in_outputs
from stratum.optimizer.ir._ops import Op

def match_two_op_chain(op_cls, type1, type2):
Expand All @@ -18,12 +18,8 @@ def eliminate_two_op_chain(op1, op2):
Rewires the DAG in-place so that op1's input connects directly to op2's output.
"""
x = op1.inputs[0]
if len(op2.outputs) == 1:
y = op2.outputs[0]
y.replace_input(op2, x)
x.replace_output(op1, y)
else:
x.outputs = []
x.outputs = [out for out in x.outputs if out is not op1]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice finding, i forgot about the util method. can we change the order of line 21 and 22?
so first eliminate op1 as output of x, than update op2 outputs (and implicit add outputs to x), since we should do the copy of the list before we append op2 outputs

replace_op_in_outputs(op2, x)

def eliminate_two_op_chain_root_safe(op1: Op, op2: Op, root: Op) -> Op:
"""Wrapper around eliminate_two_op_chain that handles the case where
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import unittest
import stratum as st
import numpy as np
from stratum.optimizer._optimize import optimize
from stratum.optimizer.ir._numeric_ops import NumericOp, NumericOpType
from stratum.optimizer.ir._ops import ValueOp

class TestNumericComplexFanout(unittest.TestCase):

def test_fanout_before_and_after_chain(self):
"""
Scenario:
a -> [log, d]
log -> [exp]
exp -> [b, c]

Expected after optimization:
a -> [d, b, c]
"""
a = st.as_data_op(1.0)

# Branch 1: log -> exp -> [b, c]
log_a = a.skb.apply_func(np.log)
exp_log_a = log_a.skb.apply_func(np.exp)
b = exp_log_a + 1.0
c = exp_log_a + 2.0

# Branch 2: d
d = a + 10.0

# Root combining all branches
t1 = d + b
final = t1 + c

linearized_dag, *_ = optimize(final)

# 1. Find the original 'a' Op
a_op = linearized_dag[0]
self.assertIsInstance(a_op, ValueOp)
self.assertEqual(a_op.value, 1.0)

# 2. Verify 'a' has 3 outputs: d, b, and c
self.assertEqual(len(a_op.outputs), 3, f"Expected 3 outputs for 'a', found {len(a_op.outputs)}")

# All outputs of 'a' should be NumericOp(ADD) now (BinOp(+) is extracted to NumericOp)
for out in a_op.outputs:
self.assertIsInstance(out, NumericOp)
self.assertEqual(out.type, NumericOpType.ADD)
self.assertIn(a_op, out.inputs)

# 3. Check that no log or exp ops are left
for op in linearized_dag:
if isinstance(op, NumericOp):
self.assertNotIn(op.type, [NumericOpType.LOG, NumericOpType.EXP])

def test_fanout_on_op2_only(self):
"""
Scenario:
a -> log -> exp -> [b, c]

Expected after optimization:
a -> [b, c]
"""
a = st.as_data_op(1.0)
log_a = a.skb.apply_func(np.log)
exp_log_a = log_a.skb.apply_func(np.exp)
b = exp_log_a + 1.0
c = exp_log_a + 2.0
final = b + c

linearized_dag, *_ = optimize(final)

a_op = linearized_dag[0]
self.assertIsInstance(a_op, ValueOp)
self.assertEqual(a_op.value, 1.0)

self.assertEqual(len(a_op.outputs), 2)
for out in a_op.outputs:
self.assertIsInstance(out, NumericOp)
self.assertEqual(out.type, NumericOpType.ADD)

for op in linearized_dag:
if isinstance(op, NumericOp):
self.assertNotIn(op.type, [NumericOpType.LOG, NumericOpType.EXP])

def test_chain_is_root(self):
"""
Scenario:
a -> log -> exp (exp is the root)

Expected after optimization:
root is a
"""
a = st.as_data_op(1.0)
log_a = a.skb.apply_func(np.log)
exp_log_a = log_a.skb.apply_func(np.exp)

linearized_dag, *_ = optimize(exp_log_a)

a_op = linearized_dag[0]
self.assertIsInstance(a_op, ValueOp)
self.assertEqual(a_op.value, 1.0)

self.assertEqual(len(a_op.outputs), 0)
self.assertIs(linearized_dag[-1], a_op)

for op in linearized_dag:
if isinstance(op, NumericOp):
self.assertNotIn(op.type, [NumericOpType.LOG, NumericOpType.EXP])

def test_chain_is_root_with_other_fanout(self):
"""
Scenario:
a -> [log, d]
log -> exp -> BinOp (root)

Expected after optimization:
a -> [d, combined]
d -> combined (combined = a + d is the root)
"""
a = st.as_data_op(1.0)
log_a = a.skb.apply_func(np.log)
exp_log_a = log_a.skb.apply_func(np.exp)

# Add another branch so 'a' has fan-out
d = a + 10.0

combined = exp_log_a + d
linearized_dag, *_ = optimize(combined)

a_op = linearized_dag[0]
self.assertIsInstance(a_op, ValueOp)
self.assertEqual(a_op.value, 1.0)

# 'a' should now connect directly to the root (the BinOp from combined)
# and to 'd'.
self.assertEqual(len(a_op.outputs), 2)

# Verify no NumericOps remain
for op in linearized_dag:
if isinstance(op, NumericOp):
self.assertNotIn(op.type, [NumericOpType.LOG, NumericOpType.EXP])

if __name__ == "__main__":
unittest.main()
Loading