Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 94 additions & 7 deletions stratum/optimizer/ir/_numeric_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from stratum.optimizer.ir._ops import BinOp, CallOp, Op
from stratum.optimizer.ir._ops import BinOp, CallOp, Op, DATA_OP_PLACEHOLDER
import operator
import numpy as np
from enum import Enum
Expand All @@ -10,12 +10,33 @@ class NumericOpType(Enum):
SQRT = "sqrt"
ABS = "abs"
SQUARE = "square"
ADD = "add"
SUBTRACT = "subtract"
MULTIPLY = "multiply"
DIVIDE = "divide"

_ARITH_OP_MAP = {
operator.add: NumericOpType.ADD,
operator.sub: NumericOpType.SUBTRACT,
operator.mul: NumericOpType.MULTIPLY,
operator.truediv: NumericOpType.DIVIDE,
}

_NUMPY_BINARY_MAP = {
np.add: NumericOpType.ADD,
np.subtract: NumericOpType.SUBTRACT,
np.multiply: NumericOpType.MULTIPLY,
np.divide: NumericOpType.DIVIDE,
}

_BINARY_TYPES = frozenset(_ARITH_OP_MAP.values())
_BINARY_NUMPY_FUNCS = frozenset(_NUMPY_BINARY_MAP.keys())

class NumericOp(Op):
fields = ["func", "args", "kwargs", "type"]
fields = ["func", "args", "kwargs", "type", "constant", "reversed"]
func = None

def __init__(self, inputs, outputs, func=None, args=(), kwargs=None, type: NumericOpType = None):
def __init__(self, inputs=None, outputs=None, func=None, args=(), kwargs=None, type: NumericOpType = None, constant=None, reversed=False):
if func is not None:
if func is np.log:
self.type = NumericOpType.LOG
Expand All @@ -32,6 +53,18 @@ def __init__(self, inputs, outputs, func=None, args=(), kwargs=None, type: Numer
elif func is np.square:
self.type = NumericOpType.SQUARE
name = "square"
elif func is np.add:
self.type = NumericOpType.ADD
name = "add"
elif func is np.subtract:
self.type = NumericOpType.SUBTRACT
name = "subtract"
elif func is np.multiply:
self.type = NumericOpType.MULTIPLY
name = "multiply"
elif func is np.divide:
self.type = NumericOpType.DIVIDE
name = "divide"
else:
self.type = NumericOpType.GENERIC
self.func = func
Expand All @@ -47,6 +80,8 @@ def __init__(self, inputs, outputs, func=None, args=(), kwargs=None, type: Numer
super().__init__(name=name, inputs=inputs, outputs=outputs)
self.args = args
self.kwargs = kwargs or {}
self.constant = constant
self.reversed = reversed

def process(self, mode: str, environment: dict, inputs: list):
if self.type == NumericOpType.GENERIC:
Expand All @@ -61,26 +96,78 @@ def process(self, mode: str, environment: dict, inputs: list):
return np.abs(inputs[0])
elif self.type == NumericOpType.SQUARE:
return np.square(inputs[0])
elif self.type in _BINARY_TYPES:
left, right = (self.constant, inputs[0]) if self.reversed else (inputs[0], self.constant)
if self.type == NumericOpType.ADD:
return np.add(left, right)
elif self.type == NumericOpType.SUBTRACT:
return np.subtract(left, right)
elif self.type == NumericOpType.MULTIPLY:
return np.multiply(left, right)
elif self.type == NumericOpType.DIVIDE:
return np.divide(left, right)
else:
raise ValueError(f"Unsupported binary numeric operation type: {self.type}")
else:
raise ValueError(f"Unsupported numeric operation type: {self.type}")


def make_numeric_op(op: CallOp) -> NumericOp:
op.args = op.args[1:]
new_op = NumericOp(func=op.func, args=op.args, kwargs=op.kwargs, inputs=op.inputs, outputs=op.outputs)
return new_op
remaining_args = op.args[1:]
return NumericOp(func=op.func, args=remaining_args, kwargs=op.kwargs, inputs=op.inputs, outputs=op.outputs)

def make_binary_numeric_op(op: CallOp, type: NumericOpType) -> NumericOp:
args = op.args or ()
if len(args) == 2 and args[0] is DATA_OP_PLACEHOLDER:
constant, reversed = args[1], False
elif len(args) == 2 and args[1] is DATA_OP_PLACEHOLDER:
constant, reversed = args[0], True
else:
raise ValueError(
f"make_binary_numeric_op called with args that are not a single-placeholder pair: {args}"
)
return NumericOp(type=type, constant=constant, reversed=reversed, inputs=op.inputs, outputs=op.outputs)


def _is_binary_extractable(op: CallOp) -> bool:
args = op.args or ()
if len(args) != 2:
return False
l_ph = args[0] is DATA_OP_PLACEHOLDER
r_ph = args[1] is DATA_OP_PLACEHOLDER
return l_ph != r_ph

def extract_numeric_op(op: Op, root: Op) -> tuple[Op, bool]:
new_op = None
if isinstance(op, BinOp) and op.op is operator.pow and op.right == 2:
new_op = NumericOp(func=np.square, args=(), kwargs={}, inputs=op.inputs, outputs=op.outputs)
elif isinstance(op, BinOp) and op.op in _ARITH_OP_MAP:
l_ph = op.left is DATA_OP_PLACEHOLDER
r_ph = op.right is DATA_OP_PLACEHOLDER
if l_ph != r_ph: # var op const or const op var, not var op var
constant = op.right if l_ph else op.left
new_op = NumericOp(
type=_ARITH_OP_MAP[op.op],
constant=constant,
reversed=not l_ph, # True when const is on the left
inputs=op.inputs,
outputs=op.outputs,
)
elif isinstance(op, CallOp):
if op.func is np.log:
new_op = make_numeric_op(op)
elif op.func is np.exp:
new_op = make_numeric_op(op)
elif op.func is np.sqrt:
new_op = make_numeric_op(op)
elif op.func is np.abs:
new_op = make_numeric_op(op)
elif op.func is np.square:
new_op = make_numeric_op(op)
elif op.func in _NUMPY_BINARY_MAP and _is_binary_extractable(op):
new_op = make_binary_numeric_op(op, _NUMPY_BINARY_MAP[op.func])
# if op is some other function from np package, make a generic numeric op
elif op.func.__module__ == "numpy":
elif op.func.__module__ == "numpy" and op.func not in _BINARY_NUMPY_FUNCS:
new_op = make_numeric_op(op)

if new_op is None:
Expand Down
4 changes: 3 additions & 1 deletion stratum/optimizer/ir/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from skrub._data_ops._choosing import Choice
from skrub._data_ops._data_ops import DataOp, Apply, Value, CallMethod, Call, GetAttr, GetItem, BinOp as SkrubBinOp, Concat, Var, _wrap_estimator
from pandas import DataFrame
from polars import DataFrame as PlDataFrame
from polars import DataFrame as PlDataFrame, Series as PlSeries
import logging
import os
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -412,6 +412,8 @@ def process(self, mode: str, environment: dict, inputs: list):
_obj = next(input_iter)
_args = _resolve_args(self.args, input_iter)
_kwargs = _resolve_kwargs(self.kwargs, input_iter)
if self.method_name == "apply" and isinstance(_obj, PlSeries):
return _obj.map_elements(*_args, **_kwargs)
return _obj.__getattribute__(self.method_name)(*_args, **_kwargs)

class CallOp(Op):
Expand Down
3 changes: 3 additions & 0 deletions stratum/runtime/_object_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sys import getsizeof
from pandas import DataFrame, Series
from polars import DataFrame as PolarsDataFrame, Series as PolarsSeries
import numpy as np
from numpy import ndarray
from logging import getLogger
logger = getLogger(__name__)
Expand Down Expand Up @@ -47,5 +48,7 @@ def get_size_polars(obj):
def get_size_numpy(obj):
if isinstance(obj, ndarray):
return obj.nbytes
elif isinstance(obj, np.generic):
return obj.itemsize
else:
raise ValueError(f"Unsupported numpy type for memory estimation: {type(obj)}")
180 changes: 178 additions & 2 deletions stratum/tests/logical_optimizer/test_numeric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import stratum as st
import numpy as np
from sklearn.dummy import DummyRegressor
from stratum.optimizer.ir._numeric_ops import NumericOp
from stratum.optimizer.ir._numeric_ops import NumericOp, NumericOpType
from stratum.optimizer._optimize import optimize

class TestNumericOps(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -63,4 +64,179 @@ def test_unsupported_numeric_op(self):
op = NumericOp(inputs=[], outputs=None, func=np.cos)
op.type = "unsupported"
with self.assertRaises(ValueError):
op.process("fit", {}, [])
op.process("fit", {}, [])

def test_process_add_var_const(self):
op = NumericOp([], [], type=NumericOpType.ADD, constant=2.0, reversed=False)
result = op.process("fit", {}, [np.array([1.0, 2.0, 3.0])])
np.testing.assert_array_almost_equal(result, np.array([3.0, 4.0, 5.0]))

def test_process_add_const_var(self):
op = NumericOp([], [], type=NumericOpType.ADD, constant=10.0, reversed=True)
result = op.process("fit", {}, [np.array([1.0, 2.0, 3.0])])
np.testing.assert_array_almost_equal(result, np.array([11.0, 12.0, 13.0]))

def test_process_subtract_var_const(self):
op = NumericOp([], [], type=NumericOpType.SUBTRACT, constant=1.0, reversed=False)
result = op.process("fit", {}, [np.array([4.0, 5.0, 6.0])])
np.testing.assert_array_almost_equal(result, np.array([3.0, 4.0, 5.0]))

def test_process_subtract_const_var(self):
op = NumericOp([], [], type=NumericOpType.SUBTRACT, constant=10.0, reversed=True)
result = op.process("fit", {}, [np.array([1.0, 2.0, 3.0])])
np.testing.assert_array_almost_equal(result, np.array([9.0, 8.0, 7.0]))

def test_process_multiply_var_const(self):
op = NumericOp([], [], type=NumericOpType.MULTIPLY, constant=3.0, reversed=False)
result = op.process("fit", {}, [np.array([1.0, 2.0, 3.0])])
np.testing.assert_array_almost_equal(result, np.array([3.0, 6.0, 9.0]))

def test_process_multiply_const_var(self):
op = NumericOp([], [], type=NumericOpType.MULTIPLY, constant=2.0, reversed=True)
result = op.process("fit", {}, [np.array([1.0, 2.0, 3.0])])
np.testing.assert_array_almost_equal(result, np.array([2.0, 4.0, 6.0]))

def test_process_divide_var_const(self):
op = NumericOp([], [], type=NumericOpType.DIVIDE, constant=2.0, reversed=False)
result = op.process("fit", {}, [np.array([2.0, 4.0, 6.0])])
np.testing.assert_array_almost_equal(result, np.array([1.0, 2.0, 3.0]))

def test_process_divide_const_var(self):
op = NumericOp([], [], type=NumericOpType.DIVIDE, constant=12.0, reversed=True)
result = op.process("fit", {}, [np.array([2.0, 3.0, 4.0])])
np.testing.assert_array_almost_equal(result, np.array([6.0, 4.0, 3.0]))

def test_extract_add_var_const(self):
df = st.as_data_op(5)
t1 = df + 3
out, *_ = optimize(t1)
self.assertEqual(len(out), 2)
self.assertIsInstance(out[1], NumericOp)
self.assertEqual(out[1].type, NumericOpType.ADD)
self.assertEqual(out[1].constant, 3)
self.assertFalse(out[1].reversed)

def test_extract_add_const_var(self):
df = st.as_data_op(5)
t1 = 3 + df
out, *_ = optimize(t1)
self.assertEqual(len(out), 2)
self.assertIsInstance(out[1], NumericOp)
self.assertEqual(out[1].type, NumericOpType.ADD)
self.assertEqual(out[1].constant, 3)
self.assertTrue(out[1].reversed)

def test_extract_subtract_var_const(self):
df = st.as_data_op(5)
t1 = df - 2
out, *_ = optimize(t1)
self.assertEqual(len(out), 2)
self.assertIsInstance(out[1], NumericOp)
self.assertEqual(out[1].type, NumericOpType.SUBTRACT)

def test_extract_multiply_var_const(self):
df = st.as_data_op(5)
t1 = df * 4
out, *_ = optimize(t1)
self.assertEqual(len(out), 2)
self.assertIsInstance(out[1], NumericOp)
self.assertEqual(out[1].type, NumericOpType.MULTIPLY)

def test_extract_divide_var_const(self):
df = st.as_data_op(10)
t1 = df / 2
out, *_ = optimize(t1)
self.assertEqual(len(out), 2)
self.assertIsInstance(out[1], NumericOp)
self.assertEqual(out[1].type, NumericOpType.DIVIDE)

def test_no_extract_var_var(self):
"""BinOp(var + var) must not be converted — keep as BinOp."""
df1 = st.as_data_op(2)
df2 = st.as_data_op(3)
t1 = df1 + df2
out, *_ = optimize(t1)
binary_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.ADD]
self.assertEqual(len(binary_ops), 0)

def test_extract_add_produces_correct_result(self):
df = st.as_data_op(5)
t1 = df + 3
out, *_ = optimize(t1)
add_op = next(op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.ADD)
self.assertEqual(add_op.process("fit", {}, [5]), 8)

def test_extract_np_add_callop(self):
"""CallOp with np.add should be extracted to NumericOp ADD."""
df = st.as_data_op(5)
t1 = df.skb.apply_func(np.add, 3)
out, *_ = optimize(t1)
add_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.ADD]
self.assertEqual(len(add_ops), 1)

def test_extract_np_multiply_callop(self):
"""CallOp with np.multiply should be extracted to NumericOp MULTIPLY."""
df = st.as_data_op(5)
t1 = df.skb.apply_func(np.multiply, 4)
out, *_ = optimize(t1)
mul_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.MULTIPLY]
self.assertEqual(len(mul_ops), 1)

def test_no_extract_np_add_var_var(self):
"""apply_func(np.add, var) with two DataOp inputs must not produce a binary NumericOp."""
df1 = st.as_data_op(2)
df2 = st.as_data_op(3)
t1 = df1.skb.apply_func(np.add, df2)
out, *_ = optimize(t1)
binary_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.ADD]
self.assertEqual(len(binary_ops), 0)

def test_no_extract_np_subtract_var_var(self):
"""apply_func(np.subtract, var) with two DataOp inputs must not produce a binary NumericOp."""
df1 = st.as_data_op(5)
df2 = st.as_data_op(3)
t1 = df1.skb.apply_func(np.subtract, df2)
out, *_ = optimize(t1)
binary_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.SUBTRACT]
self.assertEqual(len(binary_ops), 0)

def test_no_extract_np_multiply_var_var(self):
"""apply_func(np.multiply, var) with two DataOp inputs must not produce a binary NumericOp."""
df1 = st.as_data_op(2)
df2 = st.as_data_op(3)
t1 = df1.skb.apply_func(np.multiply, df2)
out, *_ = optimize(t1)
binary_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.MULTIPLY]
self.assertEqual(len(binary_ops), 0)

def test_no_extract_np_divide_var_var(self):
"""apply_func(np.divide, var) with two DataOp inputs must not produce a binary NumericOp."""
df1 = st.as_data_op(6)
df2 = st.as_data_op(2)
t1 = df1.skb.apply_func(np.divide, df2)
out, *_ = optimize(t1)
binary_ops = [op for op in out if isinstance(op, NumericOp) and op.type == NumericOpType.DIVIDE]
self.assertEqual(len(binary_ops), 0)

def test_extract_subtract_const_var_produces_correct_result(self):
df = st.as_data_op(3)
t1 = 10 - df
out, *_ = optimize(t1)
op = next(o for o in out if isinstance(o, NumericOp) and o.type == NumericOpType.SUBTRACT)
self.assertEqual(op.process("fit", {}, [3]), 7)

def test_extract_divide_const_var_produces_correct_result(self):
df = st.as_data_op(4)
t1 = 12 / df
out, *_ = optimize(t1)
op = next(o for o in out if isinstance(o, NumericOp) and o.type == NumericOpType.DIVIDE)
self.assertEqual(op.process("fit", {}, [4]), 3.0)

def test_make_binary_numeric_op_raises_on_invalid_args(self):
"""make_binary_numeric_op must raise ValueError when neither or both args are placeholders."""
from stratum.optimizer.ir._numeric_ops import make_binary_numeric_op
from stratum.optimizer.ir._ops import CallOp
op = CallOp(func=np.add, args=None)
op.args = (1.0, 2.0) # neither arg is DATA_OP_PLACEHOLDER
with self.assertRaises(ValueError):
make_binary_numeric_op(op, NumericOpType.ADD)
Loading
Loading