In stratum/optimizer/_numeric_rewrites.py, the eliminate_two_op_chain function mishandles the case where the second operation of an eliminable chain (e.g., log -> exp) has multiple outputs (fan-out).
Instead of rewiring every downstream operation to the chain's original input, the multi-output branch simply clears the input's output list. The downstream operations are never pointed at the input, and any other consumers the input already had are dropped from its output list as well. The result is "detached" nodes in the DAG and incorrect optimization results.
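One way to make the "detached" state concrete is to check that every edge in the DAG is recorded on both ends. This is a minimal sketch that assumes nodes expose `inputs` and `outputs` lists, as the snippet in this issue suggests. The `Node` class and the `connect` and `find_inconsistent_edges` helpers are hypothetical illustrations, not part of stratum.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based equality, so `in` compares node objects
class Node:
    """Hypothetical stand-in for a stratum DAG node (not the real class)."""
    name: str
    inputs: list[Node] = field(default_factory=list, repr=False)
    outputs: list[Node] = field(default_factory=list, repr=False)


def connect(producer: Node, consumer: Node) -> None:
    """Hypothetical helper: record the edge producer -> consumer on both ends."""
    producer.outputs.append(consumer)
    consumer.inputs.append(producer)


def find_inconsistent_edges(nodes: list[Node]) -> list[tuple[Node, Node]]:
    """Return (producer, consumer) pairs whose forward and backward links disagree."""
    problems = []
    for node in nodes:
        # Every consumer listed in node.outputs must list node among its inputs.
        for consumer in node.outputs:
            if node not in consumer.inputs:
                problems.append((node, consumer))
        # Every producer listed in node.inputs must list node among its outputs.
        for producer in node.inputs:
            if node not in producer.outputs:
                problems.append((producer, node))
    return problems
```

Simulating the buggy branch on x -> log -> exp -> {a, b} by setting `x.outputs = []` makes the check report the edge (x, log): `log` still reads from `x`, but `x` no longer knows about `log` or anything downstream of it.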
Root Cause
The issue lies in the else branch of eliminate_two_op_chain:
```python
def eliminate_two_op_chain(op1, op2):
    x = op1.inputs[0]
    if len(op2.outputs) == 1:
        # Single consumer: rewire y to read x directly, bypassing op1 and op2.
        y = op2.outputs[0]
        y.replace_input(op2, x)
        x.replace_output(op1, y)
    else:
        x.outputs = []  # <--- BUG: This detaches everything!
```
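A possible fix is to treat the single-output case as a special case of the general one: splice all of op2's consumers into x's output list where op1 used to be, then point each consumer at x. The sketch below assumes, as the original function appears to, that op1 feeds only op2. The `Node` class, its `replace_input` method, and the `connect` helper are hypothetical minimal stand-ins; the real stratum classes may differ.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based equality for node objects
class Node:
    """Hypothetical minimal node; the real stratum class may differ."""
    name: str
    inputs: list[Node] = field(default_factory=list, repr=False)
    outputs: list[Node] = field(default_factory=list, repr=False)

    def replace_input(self, old: Node, new: Node) -> None:
        # Swap every occurrence of `old` among this node's producers for `new`.
        self.inputs = [new if n is old else n for n in self.inputs]


def connect(producer: Node, consumer: Node) -> None:
    """Hypothetical helper: record the edge producer -> consumer on both ends."""
    producer.outputs.append(consumer)
    consumer.inputs.append(producer)


def eliminate_two_op_chain(op1: Node, op2: Node) -> None:
    """Bypass op1 -> op2 (e.g. log -> exp), wiring op1's input to all of op2's consumers."""
    x = op1.inputs[0]
    consumers = list(op2.outputs)  # snapshot: handles 1 or many outputs alike
    # Replace op1 in x.outputs with op2's consumers, in place, so x keeps
    # any other consumers it already had (unlike `x.outputs = []`).
    i = next(k for k, n in enumerate(x.outputs) if n is op1)
    x.outputs[i:i + 1] = consumers
    for y in consumers:
        # Each downstream op now reads x where it used to read op2.
        y.replace_input(op2, x)
```

The slice assignment is used instead of a one-for-one `replace_output` call because a single removed edge (x -> op1) has to become several new edges (x -> y for each consumer). With one consumer it behaves like the existing single-output branch.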