def dummy_func(x, t: float = 0.1):
    """Return a positionally re-indexed copy of *x* after sleeping *t* seconds.

    Stands in for a pipeline step that produces a new buffer as large as its
    input and takes a fixed amount of time.

    Parameters
    ----------
    x : polars DataFrame/Series, or any object exposing ``.iloc``
        Polars inputs are indexed with ``x[indices]``; everything else goes
        through ``x.iloc[indices]``.
    t : float
        Seconds to sleep before returning.
    """
    indices = np.arange(len(x))
    if isinstance(x, (pl.DataFrame, pl.Series)):
        out = x[indices]
    else:
        out = x.iloc[indices]
    sleep(t)
    return out


def main(use_skrub: bool, polars: bool):
    """Run the input-release benchmark and plot the process memory over time.

    Generates (once) a 1M-row CSV in the current working directory, builds a
    DataOp pipeline of ``dummy_func`` steps ending in a ``DummyRegressor``,
    runs a single-split grid search while a ``MemoryTracker`` samples memory,
    then writes the samples and a plot next to this script.

    Parameters
    ----------
    use_skrub : bool
        Passed inverted as ``scheduler=`` to ``skrub.config`` and used to
        name the output files.
    polars : bool
        Passed as ``force_polars=`` to ``skrub.config``.
    """
    with skrub.config_context(eager_data_ops=False):
        n = 1_000_000
        input_path = f"input_{n}.csv"
        if not os.path.exists(input_path):
            cols = ["a", "b", "c", "d", "e", "f", "g", "h", "y"]
            df = pd.DataFrame({col: np.random.random(n) for col in cols}, dtype=np.float64)

            print(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2} MB")

            df.to_csv(input_path, index=False)
            # Drop the in-memory frame so it does not inflate the measured RSS.
            del df
            print("CSV written")

        df = skrub.as_data_op(input_path).skb.apply_func(pd.read_csv)
        X = df.drop("y", axis=1).skb.mark_as_X()
        y = df["y"]
        y = y.skb.apply_func(dummy_func, t=1.0).skb.mark_as_y()

        for _ in range(5):
            X = X.skb.apply_func(dummy_func, t=0.3)
        model = DummyRegressor()

        pred = X.skb.apply(model, y=y)
        cv = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        tracker = MemoryTracker(mode="process", interval_sec=0.02)
        tracker.start()

        t0 = perf_counter()
        try:
            with skrub.config(scheduler=not use_skrub, stats=20, DEBUG=False, force_polars=polars):
                search = pred.skb.make_grid_search(cv=cv, n_jobs=1, scoring="r2", fitted=True, refit=False)
        finally:
            # Take the end timestamp before stopping the tracker so the
            # reported time does not include the side-car shutdown.
            t1 = perf_counter()
            samples = tracker.stop()

        csv_path = os.path.join(os.path.dirname(__file__), f"memory_usage_{'skrub' if use_skrub else 'stratum'}.csv")
        tracker.write_csv(csv_path)

        print(f"Time taken: {t1 - t0:.2f}s")
        print(search.results_)
        if samples:
            plot_memory(csv_path)
        else:
            # write_csv() writes nothing without samples, so the CSV would be
            # missing or left over from an earlier run.
            print("No memory samples collected; skipping plot")


def plot_memory(csv_path: str):
    """Plot the ``rss_mb`` column over ``time_sec`` and save it as a PDF.

    The plot is written next to *csv_path*, with the file extension replaced
    by ``.pdf``.
    """
    data = pd.read_csv(csv_path)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(data["time_sec"], data["rss_mb"], linewidth=1.5)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("RSS (MB)")
    ax.set_title("Buffer Pool Benchmark - Memory Usage")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    # Swap only the extension; str.replace would also rewrite ".csv"
    # occurring earlier in the path (e.g. in a directory name).
    plot_path = os.path.splitext(csv_path)[0] + ".pdf"
    fig.savefig(plot_path, dpi=150)
    print(f"Plot saved to {plot_path}")
    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--skrub", action="store_true")
    parser.add_argument("--polars", action="store_true")
    args = parser.parse_args()
    main(use_skrub=args.skrub, polars=args.polars)
# workload + finally: + samples = tracker.stop() # list of (wall_sec, rss_mb) + tracker.write_csv("memory_usage.csv", t0=start_time) +""" + +from __future__ import annotations + +import os +import signal +from multiprocessing import Event, Manager, Process +from time import perf_counter +from typing import Literal + +import psutil + + +def _flush_samples(samples_list, dump_path: str, flushed_count: int) -> int: + new_samples = list(samples_list[flushed_count:]) + if not new_samples: + return flushed_count + mode = "a" if flushed_count > 0 else "w" + with open(dump_path, mode) as f: + if flushed_count == 0: + f.write("time_sec,rss_mb\n") + for ts, rss_mb in new_samples: + f.write(f"{ts:.6f},{rss_mb:.2f}\n") + return flushed_count + len(new_samples) + + +def _tracker_loop( + pid: int, + stop_event, + samples_list, + interval_sec: float, + dump_path: str, + flush_every: int, +) -> None: + flushed = 0 + + def _sigterm_handler(signum, frame): + nonlocal flushed + flushed = _flush_samples(samples_list, dump_path, flushed) + raise SystemExit(1) + + signal.signal(signal.SIGTERM, _sigterm_handler) + + sample_count = 0 + + def _sample(t, rss_mb): + nonlocal sample_count, flushed + samples_list.append((t, rss_mb)) + sample_count += 1 + if sample_count % flush_every == 0: + flushed = _flush_samples(samples_list, dump_path, flushed) + + if pid != -1: + try: + proc = psutil.Process(pid) + except psutil.NoSuchProcess: + return + while not stop_event.is_set(): + try: + rss_bytes = proc.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + break + _sample(perf_counter(), rss_bytes / (1024 * 1024)) + stop_event.wait(interval_sec) + else: + parent = psutil.Process(os.getppid()) + while not stop_event.is_set(): + if not parent.is_running(): + break + _sample(perf_counter(), psutil.virtual_memory().used / (1024 * 1024)) + stop_event.wait(interval_sec) + + _flush_samples(samples_list, dump_path, flushed) + + +class MemoryTracker: + """Spawn a side-car process that 
polls RSS of the current (or system) memory. + + Parameters + ---------- + mode : "process" or "system" + "process" tracks the calling process's RSS. + "system" tracks total system memory used (useful with multi-process workloads). + interval_sec : float + Polling interval in seconds. + live_dump_path : str or None + If set, samples are incrementally flushed to this CSV while tracking. + flush_every : int + Number of samples between live flushes. + """ + + def __init__( + self, + *, + mode: Literal["process", "system"] = "process", + interval_sec: float = 0.1, + live_dump_path: str | None = "memory_usage_live.csv", + flush_every: int = 50, + ): + self.mode = mode + self.interval_sec = interval_sec + self.live_dump_path = live_dump_path or "memory_usage_live.csv" + self.flush_every = flush_every + + self._manager = Manager() + self._samples = self._manager.list() + self._stop = Event() + self._process: Process | None = None + self._t0: float | None = None + + def start(self) -> None: + pid = os.getpid() if self.mode == "process" else -1 + self._t0 = perf_counter() + self._process = Process( + target=_tracker_loop, + args=(pid, self._stop, self._samples, self.interval_sec, + self.live_dump_path, self.flush_every), + ) + self._process.start() + + def stop(self, timeout: float = 2.0) -> list[tuple[float, float]]: + """Signal the tracker to stop and return collected samples.""" + self._stop.set() + if self._process is not None: + self._process.join(timeout=timeout) + return list(self._samples) + + @property + def t0(self) -> float: + if self._t0 is None: + raise RuntimeError("Tracker has not been started yet") + return self._t0 + + def write_csv(self, path: str, *, t0: float | None = None) -> None: + """Write all samples to *path* with wall-clock times relative to *t0*.""" + t0 = t0 if t0 is not None else self._t0 + if t0 is None: + raise RuntimeError("t0 not available; pass it explicitly or call start() first") + samples = list(self._samples) + if not samples: + return 
+ with open(path, "w") as f: + f.write("time_sec,rss_mb\n") + for ts, rss_mb in samples: + f.write(f"{ts - t0:.4f},{rss_mb:.2f}\n") diff --git a/stratum/_api.py b/stratum/_api.py index da7a1f8f..3c29bbb0 100644 --- a/stratum/_api.py +++ b/stratum/_api.py @@ -14,8 +14,8 @@ def grid_search(dag: DataOp, cv=None, scoring=None, return_predictions=False, en env = dag.skb.get_data() for k, v in env_extra.items(): env[k] = v - dag = optimize(dag) - sched = SequentialScheduler(dag, show_stats, env=env, t0=t0) + linearized_dag, split_pos, flagged_ops = optimize(dag) + sched = SequentialScheduler(linearized_dag, split_pos, flagged_ops, show_stats, env=env, t0=t0) preds = sched.grid_search(cv, scoring, return_predictions) @@ -34,5 +34,5 @@ def grid_search(dag: DataOp, cv=None, scoring=None, return_predictions=False, en def evaluate(dag: DataOp, seed: int = 42, test_size = 0.2, cse: bool = False): """Evaluate a DataOp DAG with train/test split.""" - ops_ordered = optimize(dag) - return SequentialScheduler(ops_ordered).evaluate(seed, test_size) \ No newline at end of file + linearized_dag, split_pos, flagged_ops = optimize(dag) + return SequentialScheduler(linearized_dag, split_pos, flagged_ops).evaluate(seed, test_size) \ No newline at end of file diff --git a/stratum/optimizer/_input_release_planning.py b/stratum/optimizer/_input_release_planning.py new file mode 100644 index 00000000..4642d23b --- /dev/null +++ b/stratum/optimizer/_input_release_planning.py @@ -0,0 +1,64 @@ +"""Input release planning for the linearized Op DAG. + +After linearization, each op produces an intermediate buffer stored in the +BufferPool. Release planning decides *when* each buffer can be freed so +that peak memory stays low while correctness is preserved. + +How it works: +- Each op starts with a consumer count equal to ``len(op.outputs)`` (min 1). +- Walking the linearized order, every time an op appears as an input of + another op, its remaining count is decremented. 
def compute_pinned_ops(
    linearized_dag: list[Op],
    split_pos: int | None,
    recompute_ops: list[Op],
) -> set[Op]:
    """Return pre-split ops whose buffers must persist across CV folds.

    An op is pinned when it sits before ``split_pos``, is not itself
    re-executed, and has at least one output that is re-executed. The
    re-executed set is everything from ``split_pos`` onwards plus
    ``recompute_ops``. Without a split (``split_pos is None``) nothing is
    pinned.
    """
    if split_pos is None:
        return set()

    re_executed = set(linearized_dag[split_pos:]) | set(recompute_ops)
    return {
        op
        for op in linearized_dag[:split_pos]
        if op not in re_executed
        and any(out_op in re_executed for out_op in op.outputs)
    }


def plan_input_releases(
    linearized_dag: list[Op],
    pinned_ops: set[Op],
) -> None:
    """Set ``op.release_after`` for every op in the linearized DAG.

    Each op starts with one pending use per *distinct* consumer in
    ``op.outputs`` (minimum 1, so an op without consumers is never
    released). Walking the linearized order, each consumer decrements the
    count of each of its distinct inputs once; the input is listed in the
    ``release_after`` of the consumer that brings the count to zero, unless
    it is pinned.

    Counting distinct consumers on both sides keeps the plan correct when an
    op uses the same input more than once: the buffer is neither freed while
    another consumer still needs it nor scheduled for release twice.

    Parameters
    ----------
    linearized_dag : list[Op]
        Ops in execution order; ``release_after`` is overwritten on each.
    pinned_ops : set[Op]
        Ops that must never be scheduled for release.
    """
    remaining = {op: max(len(set(op.outputs)), 1) for op in linearized_dag}

    for op in linearized_dag:
        release = []
        # dict.fromkeys de-duplicates while keeping the input order.
        for in_op in dict.fromkeys(op.inputs):
            remaining[in_op] -= 1
            # "== 0" rather than "<= 0": an input is scheduled exactly once.
            if remaining[in_op] == 0 and in_op not in pinned_ops:
                release.append(in_op)
        op.release_after = release

    logger.debug(
        "Release planning done: pinned=%d ops, total=%d ops",
        len(pinned_ops),
        len(linearized_dag),
    )
def linearize_dag(dag_sink: Op) -> tuple[list[Op], int | None, list[Op]]:
    """Topologically sort a DAG and enforce the split invariant.

    Single-pass DFS that defers the split op on the stack: whenever the
    split op is ready but other ops are too, prefer the non-split ops.
    The split op is therefore only emitted once it is the sole ready op.

    Guarantees that every op at index >= split_pos is a descendant of the
    split op (i.e. in its subtree). This makes downstream scheduling
    trivial: ops[:split_pos] are pre-split, ops[split_pos:] are post-split.

    Returns:
        linearized_dag: Linearized list of ops with the split invariant.
        split_pos: Index of the SplitOp in linearized_dag, or None if absent.
        flagged_ops: Ops flagged for recomputation (ImplOps with EvalMode).

    Raises:
        RuntimeError: If an op's output is unknown to the indegree map, or if
            some ops never became ready (inconsistent inputs/outputs or a
            cycle), which would otherwise yield a silently truncated plan.
    """
    indegree, sources = compute_graph_node_indegree(dag_sink)

    linearized_ops: list[Op] = []
    flagged_ops: list[Op] = []
    stack = list(sources)
    split_pos = None
    while stack:
        if stack[-1].is_split_op and len(stack) > 1:
            # Defer split op: pop something else instead
            op = stack.pop(-2)
        else:
            op = stack.pop()

        if op.is_split_op:
            # The op is about to be appended, so its index is the current length.
            split_pos = len(linearized_ops)
        if isinstance(op, ImplOp) and isinstance(op.skrub_impl, EvalMode):
            flagged_ops.append(op)

        linearized_ops.append(op)
        for out_op in op.outputs:
            if out_op not in indegree:
                raise RuntimeError(
                    f"Encountered op {out_op} which should not exist in the DAG. "
                    f"Probably due to a buggy rewrite, which did not update its inputs / outputs correctly."
                )
            indegree[out_op] -= 1
            if indegree[out_op] == 0:
                stack.append(out_op)

    if len(linearized_ops) != len(indegree):
        # Every op in the indegree map must be emitted; anything left over
        # never reached indegree zero.
        missing = len(indegree) - len(linearized_ops)
        raise RuntimeError(
            f"Linearization emitted {len(linearized_ops)} of {len(indegree)} ops; "
            f"{missing} op(s) never became ready. "
            f"Probably due to inconsistent inputs / outputs or a cycle in the DAG."
        )

    return linearized_ops, split_pos, flagged_ops
+174,7 @@ def show_graph(root: Op, filename: str = 'plan'): # make sure folder exists os.makedirs(os.path.dirname(filename), exist_ok=True) dot.render(filename, view=True,cleanup=True) - + def rewrite_pass(match_fn, action_fn): """Create a rewrite that does one full DAG pass. diff --git a/stratum/optimizer/_optimize.py b/stratum/optimizer/_optimize.py index 9ff11e64..44fe6551 100644 --- a/stratum/optimizer/_optimize.py +++ b/stratum/optimizer/_optimize.py @@ -8,6 +8,8 @@ from .ir._ops import ChoiceOp, ImplOp, Op, SearchEvalOp, as_op from ._op_utils import clone_sub_dag, find_choice_naive, replace_op_in_outputs, show_graph, topological_iterator from ._algebraic_rewrites import algebraic_rewrites, AlgebraicRewritesConfig +from ._linearization import linearize_dag +from ._input_release_planning import compute_pinned_ops, plan_input_releases from stratum.utils._skrub_graph import build_graph from time import perf_counter import logging @@ -106,9 +108,18 @@ def optimize(dag_root: DataOp, config: OptConfig = None): if config.algebraic_rewrites: root = time_pass("algebraic_rewrite", lambda x: algebraic_rewrites(x, config.algebraic_rewrite_config), root) + # Final passes: linearization and buffer release planning + linearized_dag, split_pos, flagged_ops = time_pass("linearization",linearize_dag,root) + + def release_planning(linearized_dag_, split_pos_, flagged_ops_): + pinned_ops = compute_pinned_ops(linearized_dag_, split_pos_, flagged_ops_) + plan_input_releases(linearized_dag_, pinned_ops) + + time_pass("release planning", release_planning,(linearized_dag, split_pos, flagged_ops)) + t1 = perf_counter() logger.info(f"Optimization took in total {t1 - t0:.2f} seconds") - return root + return linearized_dag, split_pos, flagged_ops def run_cse_pass(dag_root: DataOp, nodes: dict, order: list, parents: dict): diff --git a/stratum/optimizer/ir/_dataframe_ops.py b/stratum/optimizer/ir/_dataframe_ops.py index f0a22cc1..e78b3543 100644 --- a/stratum/optimizer/ir/_dataframe_ops.py 
+++ b/stratum/optimizer/ir/_dataframe_ops.py @@ -1,4 +1,5 @@ -from stratum.optimizer.ir._ops import DATA_OP_PLACEHOLDER, BaseEstimatorOp, BinOp, CallOp, GetAttrOp, GetItemOp, MethodCallOp, Op, ValueOp, VariableOp +from stratum.optimizer.ir._ops import (DATA_OP_PLACEHOLDER, BaseEstimatorOp, BinOp, CallOp, GetAttrOp, GetItemOp, + MethodCallOp, Op, ValueOp, VariableOp,_resolve_args, _resolve_kwargs) from pandas import DataFrame import pandas as pd import polars as pl @@ -24,19 +25,18 @@ def __init__(self, data: DataFrame = None, file_path: str = None, _format: str = self.read_kwargs = read_kwargs self.is_dataframe_op = True - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): if self.data is not None: if FLAGS.force_polars: - self.intermediate = pl.DataFrame(self.data) + return pl.DataFrame(self.data) else: - self.intermediate = self.data - + return self.data else: - file_path = self.inputs[0].intermediate if self.file_path is DATA_OP_PLACEHOLDER else self.file_path + file_path = inputs[0] if self.file_path is DATA_OP_PLACEHOLDER else self.file_path if FLAGS.force_polars: - self.intermediate = pl.read_csv(file_path, *self.read_args, **self.read_kwargs) + return pl.read_csv(file_path, *self.read_args, **self.read_kwargs) else: - self.intermediate = pd.read_csv(file_path, *self.read_args, **self.read_kwargs) + return pd.read_csv(file_path, *self.read_args, **self.read_kwargs) def clone(self): raise ValueError(f"We should not clone DataSourceOp objects.") @@ -53,17 +53,17 @@ def __init__(self, func: str, args: tuple | list = None, kwargs: dict = None, in self.kwargs = kwargs self.is_dataframe_op = True - def process(self, mode: str, environment: dict): - iter_ins = iter(self.inputs) - _obj = next(iter_ins).intermediate - _args = [next(iter_ins).intermediate if arg is DATA_OP_PLACEHOLDER else arg for arg in self.args] - _kwargs = {k: next(iter_ins).intermediate if v is DATA_OP_PLACEHOLDER else v for k, v in 
self.kwargs.items()} + def process(self, mode: str, environment: dict, inputs: list): + input_iter = iter(inputs) + _obj = next(input_iter) + _args = _resolve_args(self.args, input_iter) + _kwargs = _resolve_kwargs(self.kwargs, input_iter) if FLAGS.force_polars: if "columns" in _kwargs: _args.append(_kwargs["columns"]) - self.intermediate = getattr(_obj, self.func)(*_args) + return getattr(_obj, self.func)(*_args) else: - self.intermediate = getattr(_obj, self.func)(*_args, **_kwargs) + return getattr(_obj, self.func)(*_args, **_kwargs) class ProjectionOp(Op): fields = ["func", "is_method", "args", "kwargs", "columns"] @@ -80,25 +80,25 @@ def __init__(self, func="", is_method: bool = True, args: tuple | list = None, k self.kwargs = kwargs self.is_dataframe_op = True - def _extract_args_and_kwargs(self): + def _extract_args_and_kwargs(self, inputs: list): """Extract and process arguments and kwargs from inputs.""" - iter_ins, args_iter = iter(self.inputs), iter(self.args) - _obj = next(iter_ins).intermediate + input_iter, args_iter = iter(inputs), iter(self.args) + _obj = next(input_iter) if not self.is_method: next(args_iter) - _args = [next(iter_ins).intermediate if arg is DATA_OP_PLACEHOLDER else arg for arg in args_iter] - _kwargs = {k: next(iter_ins).intermediate if v is DATA_OP_PLACEHOLDER else v for k, v in self.kwargs.items()} + _args = _resolve_args(args_iter, input_iter) + _kwargs = _resolve_kwargs(self.kwargs, input_iter) return _obj, _args, _kwargs - def process(self, mode: str, environment: dict): - _obj, _args, _kwargs = self._extract_args_and_kwargs() + def process(self, mode: str, environment: dict, inputs: list): + _obj, _args, _kwargs = self._extract_args_and_kwargs(inputs) if self.is_method: if FLAGS.force_polars: raise ValueError(f"Unsupported method: {self.func}") else: - self.intermediate = getattr(_obj, self.func)(*_args, **_kwargs) + return getattr(_obj, self.func)(*_args, **_kwargs) else: - self.intermediate = self.func(_obj, *_args, 
**_kwargs) + return self.func(_obj, *_args, **_kwargs) class DropOp(ProjectionOp): fields = ["args", "kwargs", "columns"] @@ -106,17 +106,17 @@ def __init__(self, args: tuple | list = (), kwargs: dict = {}, inputs: list[Op] = None, outputs: list[Op] = None, columns: list[str] = None): super().__init__(args=args, kwargs=kwargs, inputs=inputs, outputs=outputs, columns=columns) - def process(self, mode: str, environment: dict): - _obj, _args, _kwargs = self._extract_args_and_kwargs() + def process(self, mode: str, environment: dict, inputs: list): + _obj, _args, _kwargs = self._extract_args_and_kwargs(inputs) if FLAGS.force_polars: if "columns" in _kwargs: _args.append(_kwargs["columns"]) if "ignore_errors" in _kwargs: _args.append(_kwargs["ignore_errors"] == "raise") - self.intermediate = _obj.drop(*_args) + return _obj.drop(*_args) else: - self.intermediate = _obj.drop(*_args, **_kwargs) + return _obj.drop(*_args, **_kwargs) class ApplyUDFOp(ProjectionOp): fields = ["args", "kwargs", "columns"] @@ -124,8 +124,8 @@ def __init__(self, args: tuple | list = (), kwargs: dict = {}, inputs: list[Op] = None, outputs: list[Op] = None, columns: list[str] = None): super().__init__(args=args, kwargs=kwargs, inputs=inputs, outputs=outputs, columns=columns) - def process(self, mode: str, environment: dict): - _obj, _args, _kwargs = self._extract_args_and_kwargs() + def process(self, mode: str, environment: dict, inputs: list): + _obj, _args, _kwargs = self._extract_args_and_kwargs(inputs) n_cols = None if self.columns: @@ -141,24 +141,24 @@ def process(self, mode: str, environment: dict): if n_cols == 1: if _args[0] == sin: logger.debug("Rewrite UDF sin to polars sin") - self.intermediate = _obj.sin() + return _obj.sin() elif _args[0] == cos: logger.debug("Rewrite UDF cos to polars cos") - self.intermediate = _obj.cos() + return _obj.cos() else: - self.intermediate = _obj.map_elements(*_args, **_kwargs) + return _obj.map_elements(*_args, **_kwargs) else: - self.intermediate = 
_obj.map_rows(*_args, **_kwargs) + return _obj.map_rows(*_args, **_kwargs) else: - self.intermediate = _obj.apply(*_args, **_kwargs) + return _obj.apply(*_args, **_kwargs) class AssignOp(ProjectionOp): def __init__(self, args: tuple | list = (), kwargs: dict = {}, inputs: list[Op] = None, outputs: list[Op] = None, columns: list[str] = None): super().__init__(args=args, kwargs=kwargs, inputs=inputs, outputs=outputs, columns=columns) - def process(self, mode: str, environment: dict): - _obj, _args, _kwargs = self._extract_args_and_kwargs() + def process(self, mode: str, environment: dict, inputs: list): + _obj, _args, _kwargs = self._extract_args_and_kwargs(inputs) if FLAGS.force_polars: checked_kwargs = {} for k, v in _kwargs.items(): @@ -169,9 +169,9 @@ def process(self, mode: str, environment: dict): checked_kwargs[k] = pl.from_pandas(v) else: checked_kwargs[k] = v - self.intermediate = _obj.with_columns(*_args, **checked_kwargs) + return _obj.with_columns(*_args, **checked_kwargs) else: - self.intermediate = _obj.assign(*_args, **_kwargs) + return _obj.assign(*_args, **_kwargs) class DatetimeConversionOp(ProjectionOp): def __init__(self, args: tuple | list = (), kwargs: dict = {}, @@ -179,11 +179,11 @@ def __init__(self, args: tuple | list = (), kwargs: dict = {}, super().__init__(args=args, inputs=inputs, outputs=outputs, columns=columns) self.strict = kwargs.get("errors", "raise") == "raise" - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): if FLAGS.force_polars: - self.intermediate = self.inputs[0].intermediate.str.to_datetime(*self.args, strict=self.strict) + return inputs[0].str.to_datetime(*self.args, strict=self.strict) else: - self.intermediate = pd.to_datetime(self.inputs[0].intermediate, *self.args, errors="raise" if self.strict else "coerce") + return pd.to_datetime(inputs[0], *self.args, errors="raise" if self.strict else "coerce") class GetAttrProjectionOp(Op): fields = ["attr_name"] 
@@ -210,37 +210,35 @@ def __str__(self): attr_name = ".".join(self.attr_name) return f"GetAttrProjectionOp({attr_name}) [df]" - def process(self, mode: str, environment: dict): - self.intermediate = self.inputs[0].intermediate - tmp = self.intermediate + def process(self, mode: str, environment: dict, inputs: list): + result = inputs[0] + tmp = result if FLAGS.force_polars: for attr in self.attr_name: attr = self.POLARS_ATTR_NAME_MAP.get(attr, attr) # TODO find better way to handle this if attr == "is_month_end": - self.intermediate = (self.intermediate.dt.month_end() == self.intermediate) - return + return result.dt.month_end() == result # polars implements dt.day as a method, not an attribute # use getattr to handle both attributes and methods tmp = getattr(tmp, attr) - self.intermediate = tmp() + return tmp() else: for attr in self.attr_name: tmp = getattr(tmp, attr) - self.intermediate = tmp - + return tmp class GroupedDataframeOp(Op): def __init__(self, ops: list[Op]): super().__init__(name="GROUPED_DATAFRAME", is_X=False, is_y=False) self.ops = ops self.is_dataframe_op = True - def process(self, mode: str, environment: dict): - for op in self.ops: - op.process(mode, environment) - self.intermediate = self.ops[-1].intermediate + def process(self, mode: str, environment: dict, inputs: list): # pragma: no cover + # TODO: GroupedDataframeOp is experimental and not integrated yet. + # Needs proper refactoring to collect sub-op inputs from the pool. 
+ raise NotImplementedError("GroupedDataframeOp is not integrated yet.") class ConcatOp(Op): fields = ["first", "others", "axis"] # Add more if needed @@ -256,15 +254,15 @@ def __init__(self, first: Op, others: list[Op], axis: int): self.axis = DATA_OP_PLACEHOLDER if isinstance(axis, DataOp) else axis self.is_dataframe_op = True - def process(self, mode: str, environment: dict): - input_iter = iter(self.inputs) - first = next(input_iter).intermediate if self.first is DATA_OP_PLACEHOLDER else self.first - others = [next(input_iter).intermediate if other is DATA_OP_PLACEHOLDER else other for other in self.others] - axis = next(input_iter).intermediate if self.axis is DATA_OP_PLACEHOLDER else self.axis + def process(self, mode: str, environment: dict, inputs: list): + input_iter = iter(inputs) + first = next(input_iter) if self.first is DATA_OP_PLACEHOLDER else self.first + others = [next(input_iter) if other is DATA_OP_PLACEHOLDER else other for other in self.others] + axis = next(input_iter) if self.axis is DATA_OP_PLACEHOLDER else self.axis if FLAGS.force_polars: - self.intermediate = pl.concat([first, *others], how=self.axis_map[axis]) + return pl.concat([first, *others], how=self.axis_map[axis]) else: - self.intermediate = pd.concat([first, *others], axis=axis) + return pd.concat([first, *others], axis=axis) def rewrite_fuse_get_item_ops(op: Op) -> Op: @@ -294,14 +292,14 @@ def __init__(self, inputs: list[Op]=None, outputs: list[Op]=None): self.is_dataframe_op = True self.indices = None - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): # we need to handle both pandas and polars dfs - x = self.inputs[0].intermediate - y = self.inputs[1].intermediate + x = inputs[0] + y = inputs[1] if isinstance(x, pd.DataFrame): - self.intermediate = (x.iloc[self.indices], y.iloc[self.indices]) + return (x.iloc[self.indices], y.iloc[self.indices]) elif isinstance(x, pl.DataFrame): - self.intermediate = 
(x[self.indices], y[self.indices]) + return (x[self.indices], y[self.indices]) else: raise ValueError(f"Unsupported dataframe type: {type(x)}") @@ -312,11 +310,11 @@ def __init__(self, inputs: list[Op]=None, outputs: list[Op]=None, is_x = True, ) self.is_x = is_x self.is_dataframe_op = True - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): if self.is_x: - self.intermediate = self.inputs[0].intermediate[0] + return inputs[0][0] else: - self.intermediate = self.inputs[0].intermediate[1] + return inputs[0][1] def add_splitting_op(root: Op) -> Op: x_op = None diff --git a/stratum/optimizer/ir/_numeric_ops.py b/stratum/optimizer/ir/_numeric_ops.py index e1649725..6b6c0e51 100644 --- a/stratum/optimizer/ir/_numeric_ops.py +++ b/stratum/optimizer/ir/_numeric_ops.py @@ -26,13 +26,13 @@ def __init__(self, func, args, kwargs, inputs, outputs): self.args = args self.kwargs = kwargs - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): if self.type == NumericOpType.GENERIC: - self.intermediate = self.func(self.inputs[0].intermediate, *self.args, **self.kwargs) + return self.func(inputs[0], *self.args, **self.kwargs) elif self.type == NumericOpType.LOG: - self.intermediate = np.log(self.inputs[0].intermediate) + return np.log(inputs[0]) elif self.type == NumericOpType.EXP: - self.intermediate = np.exp(self.inputs[0].intermediate) + return np.exp(inputs[0]) else: raise ValueError(f"Unsupported numeric operation type: {self.type}") @@ -60,4 +60,4 @@ def extract_numeric_op(op: Op, root: Op) -> tuple[Op, bool]: op.replace_output_of_inputs(new_op) if op is root: root = new_op - return root, True + return root, True \ No newline at end of file diff --git a/stratum/optimizer/ir/_ops.py b/stratum/optimizer/ir/_ops.py index 1824fdb6..90843429 100644 --- a/stratum/optimizer/ir/_ops.py +++ b/stratum/optimizer/ir/_ops.py @@ -25,17 +25,27 @@ def __repr__(self): # 
def _resolve_args(args, input_iter):
    """Return *args* as a list, filling each DATA_OP_PLACEHOLDER from *input_iter*.

    Placeholders are filled in order, one ``next(input_iter)`` per
    placeholder; all other items are passed through unchanged.
    """
    resolved = []
    for arg in args:
        if arg is DATA_OP_PLACEHOLDER:
            arg = next(input_iter)
        resolved.append(arg)
    return resolved


def _resolve_kwargs(kwargs, input_iter):
    """Return a copy of *kwargs* with DATA_OP_PLACEHOLDER values filled from *input_iter*.

    Values are visited in the dict's iteration order; each placeholder
    consumes one item of *input_iter*, other values are kept as they are.
    """
    resolved = {}
    for key, value in kwargs.items():
        if value is DATA_OP_PLACEHOLDER:
            resolved[key] = next(input_iter)
        else:
            resolved[key] = value
    return resolved
Please implement it.") def check_kwargs(self, kwargs): @@ -131,14 +148,14 @@ def clone(self): new_op.was_cloned = True return new_op - def replace_fields_with_values(self): + def replace_fields_with_values(self, inputs): """Replace DataOp fields in implementation with their computed values.""" - parent_iter = iter(self.inputs) + input_iter = iter(inputs) def replace_dataop(value): """Recursively replace DataOp instances with their actual values.""" if isinstance(value, DataOp): - return next(parent_iter).intermediate + return next(input_iter) elif isinstance(value, (list, tuple)): new_seq = [replace_dataop(item) for item in value] return type(value)(new_seq) @@ -149,24 +166,23 @@ def replace_dataop(value): return SimpleNamespace(**{field: replace_dataop(getattr(self.skrub_impl, field)) for field in self.skrub_impl._fields}) - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): if hasattr(self.skrub_impl, "eval"): # DataOp with eval method have a fused implementation of the generator and the compute method # we need to iterate over the generator and replace the requested fields with correct inputs last_yield = None gen = self.skrub_impl.eval(mode=mode, environment=environment) - parent_iter = iter(self.inputs) + input_iter = iter(inputs) while True: try: last_yield = gen.send(last_yield) except StopIteration as e: - self.intermediate = e.value - break + return e.value if isinstance(last_yield, DataOp): - last_yield = next(parent_iter).intermediate + last_yield = next(input_iter) else: - ns = self.replace_fields_with_values() - self.intermediate = self.skrub_impl.compute(ns, mode, environment) + ns = self.replace_fields_with_values(inputs) + return self.skrub_impl.compute(ns, mode, environment) class VariableOp(Op): def __init__(self, name: str, value = None): @@ -180,8 +196,8 @@ def __init__(self, name: str, value = None): def clone(self): return VariableOp(name=self.name) - def process(self, mode: str, 
environment: dict): - self.intermediate = environment[self.name] + def process(self, mode: str, environment: dict, inputs: list): + return environment[self.name] class BaseEstimatorOp(Op): fields = ["estimator", "y", "cols", "how", "allow_reject", "unsupervised", "kwargs"] @@ -219,20 +235,20 @@ def clone(self): new_op.was_cloned = True return new_op - def extract_args_from_inputs(self, mode: str): + def extract_args_from_inputs(self, mode: str, inputs: list): """ Extract all necessary data from an EstimatorOp to make it picklable for multiprocessing. Returns a tuple of picklable data that can be sent to worker processes. """ - input_iter = iter(self.inputs) - x = next(input_iter).intermediate + input_iter = iter(inputs) + x = next(input_iter) assert x is not None, f"X is None for {self}" - y = None if mode == 'predict' else next(input_iter).intermediate if self.y == DATA_OP_PLACEHOLDER else self.y + y = None if mode == 'predict' else next(input_iter) if self.y == DATA_OP_PLACEHOLDER else self.y estm = self.estimator if mode == "predict" else self.original_estimator - place_holders = {k: next(input_iter).intermediate for k, v in estm.get_params().items() if isinstance(v, DataOp)} + place_holders = {k: next(input_iter) for k, v in estm.get_params().items() if isinstance(v, DataOp)} estm.set_params(**place_holders) - cols = next(input_iter).intermediate if self.cols == DATA_OP_PLACEHOLDER else self.cols + cols = next(input_iter) if self.cols == DATA_OP_PLACEHOLDER else self.cols return ( estm, x, @@ -246,11 +262,12 @@ def extract_args_from_inputs(self, mode: str): self.parallelism ) - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): # we use a separate function to process the estimator to allow reuse for multiprocessing - task_data = self.extract_args_from_inputs(mode) + task_data = self.extract_args_from_inputs(mode, inputs) process_task = self.get_process_task() - self.intermediate, self.estimator = 
process_task(task_data) + result, self.estimator = process_task(task_data) + return result def get_process_task(self): raise NotImplementedError(f"get_process_task must be implemented in {self.__class__.__name__}") @@ -369,9 +386,9 @@ def clone(self): new_op.was_cloned = True return new_op - def process(self, mode: str, environment: dict): - results = [{"id" : name, "vals" : self.inputs[i].intermediate} for i, name in enumerate(self.make_outcome_names())] - self.intermediate = results[0] if len(results) == 1 else results + def process(self, mode: str, environment: dict, inputs: list): + results = [{"id" : name, "vals" : inputs[i]} for i, name in enumerate(self.make_outcome_names())] + return results[0] if len(results) == 1 else results class ValueOp(Op): fields = ["value"] @@ -383,8 +400,8 @@ def __init__(self, value): def clone(self): raise ValueError(f"We should not clone ValueOp objects.") - def process(self, mode: str, environment: dict): - self.intermediate = self.value + def process(self, mode: str, environment: dict, inputs: list): + return self.value class MethodCallOp(Op): fields = ["method_name", "args", "kwargs"] @@ -397,12 +414,12 @@ def __init__(self, method_name: str, args = None, kwargs = None): self.args = remove_datops_from_args(args) if args is not None else args self.kwargs = remove_datops_from_args(kwargs) if kwargs is not None else kwargs - def process(self, mode: str, environment: dict): - iter_ins = iter(self.inputs) - _obj = next(iter_ins).intermediate - _args = [next(iter_ins).intermediate if arg is DATA_OP_PLACEHOLDER else arg for arg in self.args] - _kwargs = {k: next(iter_ins).intermediate if v is DATA_OP_PLACEHOLDER else v for k, v in self.kwargs.items()} - self.intermediate = _obj.__getattribute__(self.method_name)(*_args, **_kwargs) + def process(self, mode: str, environment: dict, inputs: list): + input_iter = iter(inputs) + _obj = next(input_iter) + _args = _resolve_args(self.args, input_iter) + _kwargs = 
_resolve_kwargs(self.kwargs, input_iter) + return _obj.__getattribute__(self.method_name)(*_args, **_kwargs) class CallOp(Op): fields = ["func", "args", "kwargs"] @@ -417,11 +434,11 @@ def __init__(self, name=None, func=None, args=None, kwargs=None): self.args = remove_datops_from_args(args) if args is not None else args self.kwargs = remove_datops_from_args(kwargs) if kwargs is not None else kwargs - def process(self, mode: str, environment: dict): - iter_ins = iter(self.inputs) - _args = [next(iter_ins).intermediate if arg is DATA_OP_PLACEHOLDER else arg for arg in self.args] - _kwargs = {k: next(iter_ins).intermediate if v is DATA_OP_PLACEHOLDER else v for k, v in self.kwargs.items()} - self.intermediate = self.func(*_args, **_kwargs) + def process(self, mode: str, environment: dict, inputs: list): + input_iter = iter(inputs) + _args = _resolve_args(self.args, input_iter) + _kwargs = _resolve_kwargs(self.kwargs, input_iter) + return self.func(*_args, **_kwargs) class GetAttrOp(Op): fields = ["attr_name"] @@ -430,13 +447,14 @@ def __init__(self, attr_name: str=None): super().__init__(name=attr_name if attr_name else '?') self.attr_name = attr_name - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): if self.is_dataframe_op: - self.intermediate = self.inputs[0].intermediate + result = inputs[0] for attr in self.attr_name: - self.intermediate = getattr(self.intermediate, attr) + result = getattr(result, attr) + return result else: - self.intermediate = getattr(self.inputs[0].intermediate, self.attr_name) + return getattr(inputs[0], self.attr_name) class GetItemOp(Op): fields = ["key"] @@ -447,11 +465,11 @@ def __init__(self, key=None): super().__init__(name=name) - def process(self, mode: str, environment: dict): + def process(self, mode: str, environment: dict, inputs: list): key = self.key if key is DATA_OP_PLACEHOLDER: - key = self.inputs[1].intermediate - self.intermediate = 
self.inputs[0].intermediate[key] + key = inputs[1] + return inputs[0][key] class BinOp(Op): fields = ["op", "left", "right"] @@ -463,19 +481,19 @@ def __init__(self, op: Callable, left, right): self.right = DATA_OP_PLACEHOLDER if isinstance(right, DataOp) else right - def process(self, mode: str, environment: dict, cv_id = None): + def process(self, mode: str, environment: dict, inputs: list): i = 0 if self.left is DATA_OP_PLACEHOLDER: - left = self.inputs[i].intermediate + left = inputs[i] i += 1 else: left = self.left if self.right is DATA_OP_PLACEHOLDER: - right = self.inputs[i].intermediate + right = inputs[i] i += 1 else: right = self.right - self.intermediate = self.op(left, right) + return self.op(left, right) class SearchEvalOp(Op): def __init__(self, outcome_names: list[str], parent: Op = None): diff --git a/stratum/runtime/_buffer_pool.py b/stratum/runtime/_buffer_pool.py new file mode 100644 index 00000000..0157f179 --- /dev/null +++ b/stratum/runtime/_buffer_pool.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import logging +from typing import Any, Hashable +logger = logging.getLogger(__name__) + +class BufferPool: + """Simple cache for intermediate buffers, will be replaced by a proper buffer in future. """ + + def __init__(self): + self._entries: dict[Hashable, Any] = {} # key -> data + self._released_count: int = 0 + + def put(self, key: Hashable, data: Any): + """Store data for a key. Overwrites any existing entry.""" + self._entries[key] = data + + def get(self, key: Hashable) -> Any: + """Retrieve stored data for a key, or None if not present.""" + return self._entries.get(key) + + def release(self, key: Hashable) -> bool: + """Release a single buffer, dropping its data.""" + entry = self._entries.pop(key, None) + if entry is not None: + logger.debug(f"Releasing buffer for {key}") + self._released_count += 1 + return True + return False + + def release_all(self) -> list: + """Release everything, including pinned. 
Used at end of execution. + + Returns list of released keys. + """ + released = list(self._entries.keys()) + for key in released: + self.release(key) + self._entries.clear() + return released + + @property + def active_count(self) -> int: + return len(self._entries) + + @property + def total_released(self) -> int: + return self._released_count diff --git a/stratum/runtime/_scheduler.py b/stratum/runtime/_scheduler.py index a82d614e..a8b93b03 100644 --- a/stratum/runtime/_scheduler.py +++ b/stratum/runtime/_scheduler.py @@ -1,11 +1,11 @@ +from __future__ import annotations from time import perf_counter from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split, check_cv from sklearn.metrics._scorer import _Scorer, get_scorer -from skrub._data_ops._data_ops import EvalMode from stratum.optimizer.ir._dataframe_ops import SplitOp -from stratum.optimizer._op_utils import topological_iterator -from stratum.optimizer.ir._ops import ImplOp, Op +from stratum.optimizer.ir._ops import Op +from stratum.runtime._buffer_pool import BufferPool import polars as pl import logging @@ -24,64 +24,72 @@ def get_scoring_func(scoring): scoring_func = mean_squared_error return scoring_func, greater_is_better + class Scheduler: - """Scheduler for executing DataOpDAGs in topological order.""" - - def __init__(self, print_heavy_hitters=False, env=None, t0 = None): - """Initialize scheduler with a data operations DAG.""" + """Scheduler for executing pre-planned Op DAGs in linearized order.""" + + def __init__(self, print_heavy_hitters=False, env=None, t0=None): self.mode = "fit_transform" self.env = env if env else {} - self.flagged_for_recomputation = [] - self.pos_split_op = None + self.linearized_dag = None + self.recompute_ops: list[Op] = [] + self.pos_split_op: int | None = None self.timings = [] if print_heavy_hitters else None self.results_ = None self.cv_id = -1 + self.pool = BufferPool() self.t0 = t0 if t0 is not None else perf_counter() + 
self._pinned_ops: set[Op] = set() + + def _finish(self): + """End of execution. Release all buffers.""" + self.pool.release_all() + logger.debug(f"Scheduler finished: {self.pool.total_released} buffers released total") - def evaluate(self, seed: int = 42, test_size = 0.2): + + def evaluate(self, seed: int = 42, test_size=0.2): """Evaluate the pipeline with a train/test split and return predictions.""" try: split_op = self.compute_xy() except RuntimeError as e: if "X and y nodes not found in the DAG" in str(e): logger.warning("X and y nodes not found in the DAG, returning the last node") - return self.ops_ordered[-1].intermediate + return self.pool.get(self.linearized_dag[-1]) else: raise e - train_index, test_index = train_test_split(range(len(split_op.inputs[0].intermediate)), test_size=test_size, random_state=seed) + x_data = self.pool.get(split_op.inputs[0]) + train_index, test_index = train_test_split(range(len(x_data)), test_size=test_size, random_state=seed) split_op.indices = train_index self.compute(self.pos_split_op) split_op.indices = test_index pred = self.compute(self.pos_split_op, mode="predict") return pred["vals"][0] - def grid_search(self, cv=None, scoring=None, return_predictions=False): """Perform grid search with cross-validation on the logical DAG.""" - # default to scikit-learn's CV cv = check_cv(cv) - # start with computing till we reach the split op logger.debug("\n" + "="*100 + "\n" + "Starting grid search" + "\n" + "="*100 + "\n") split_op = self.compute_xy() + results, predictions = [], [] logger.debug("\n" + "="*100 + "\n" + "XY computed" + "\n" + "="*100 + "\n") results = self.cross_validate(split_op, cv, scoring, predictions, results, return_predictions) self.results_ = results + self._finish() return predictions if return_predictions else None def cross_validate(self, split_op, cv, scoring, predictions: list, results: list, return_predictions: bool): """Perform cross-validation on the logical DAG.""" scoring_func, greater_is_better = 
get_scoring_func(scoring) - # TODO we can parallelize over the folds - for i, (train_index, test_index) in enumerate(cv.split(split_op.inputs[0].intermediate)): + x_data = self.pool.get(split_op.inputs[0]) + for i, (train_index, test_index) in enumerate(cv.split(x_data)): self.cv_id = i logger.debug(f"CV Fold Nr. {i + 1}") - # fit and predict the pipeline split_op.indices = train_index self.compute(self.pos_split_op) logger.debug("\n" + "="*100 + "\n" + "Training done for fold " + str(i+1) + "\n" + "="*100 + "\n") @@ -91,7 +99,6 @@ def cross_validate(self, split_op, cv, scoring, predictions: list, results: list if return_predictions: predictions.append(df) - # scoring df = df.with_columns(df["vals"].map_elements(lambda pred: scoring_func(y_test, pl.Series(pred))).alias("scores")) df = df.drop("vals") results.append(df) @@ -106,7 +113,9 @@ def process_op(self, op: Op): try: t0 = perf_counter() if self.timings is not None else 0 - op.process(mode=self.mode, environment=self.env) + inputs = op.resolve_inputs(self.pool) + result = op.process(mode=self.mode, environment=self.env, inputs=inputs) + op.release_inputs(self.pool) if self.timings is not None: duration = perf_counter() - t0 self.timings.append((str(op), duration)) @@ -114,6 +123,9 @@ def process_op(self, op: Op): except Exception as e: raise RuntimeError(f"[{self.mode}] Error processing '{op}': {e}") + self.pool.put(op, result) + logger.debug(f"[{perf_counter() - self.t0:.2f}s] Pool size: {self.pool.active_count}") + return op def _format_predict_result(self, pred): @@ -125,59 +137,58 @@ def _format_predict_result(self, pred): else: return pl.DataFrame({"vals": [pred], "id": ["default"]}) - def _flag_op_for_recomputation_if_needed(self, op: Op): - """Helper method to flag an op for recomputation if it's an ImplOp with EvalMode.""" - if isinstance(op, ImplOp) and isinstance(op.skrub_impl, EvalMode): - self.flagged_for_recomputation.append(op) class SequentialScheduler(Scheduler): - def __init__(self, dag_sink: 
Op, print_heavy_hitters=False, env=None, t0 = None): + def __init__(self, linearized_dag, split_pos, recompute_ops, + print_heavy_hitters=False, env=None, t0=None): super().__init__(print_heavy_hitters, env=env, t0=t0) - self.ops_ordered = [op for op in topological_iterator(dag_sink)] + self.linearized_dag = linearized_dag + self.pos_split_op = split_pos + self.recompute_ops = recompute_ops - def evaluate(self, seed: int = 42, test_size = 0.2): + def evaluate(self, seed: int = 42, test_size=0.2): """Evaluate the pipeline with a train/test split and return predictions.""" + try: split_op = self.compute_xy() except RuntimeError as e: if "X and y nodes not found in the DAG" in str(e): logger.warning("X and y nodes not found in the DAG, returning the last node") - return self.ops_ordered[-1].intermediate + return self.pool.get(self.linearized_dag[-1]) else: raise e - train_index, test_index = train_test_split(range(len(split_op.inputs[0].intermediate)), test_size=test_size, random_state=seed) + x_data = self.pool.get(split_op.inputs[0]) + train_index, test_index = train_test_split(range(len(x_data)), test_size=test_size, random_state=seed) split_op.indices = train_index self.compute(self.pos_split_op) split_op.indices = test_index pred, _ = self.compute(self.pos_split_op, mode="predict") + self._finish() return pred["vals"][0] - def compute(self, start_pos: int, mode="fit_transform"): """Compute the pipeline from start_pos onwards with given inputs.""" - ops_to_compute = self.ops_ordered[start_pos:] - if len(self.flagged_for_recomputation) != 0: - ops_to_compute = self.flagged_for_recomputation + ops_to_compute + ops_to_compute = self.linearized_dag[start_pos:] + if len(self.recompute_ops) != 0: + ops_to_compute = self.recompute_ops + ops_to_compute self.mode = mode y_true = None for node in ops_to_compute: self.process_op(node) if mode == "predict" and isinstance(node, SplitOp): - y_true = node.intermediate[1] + y_true = self.pool.get(node)[1] if mode == "predict": - 
pred = self.ops_ordered[-1].intermediate + pred = self.pool.get(self.linearized_dag[-1]) return self._format_predict_result(pred), y_true return None def compute_xy(self) -> SplitOp: - """Compute nodes until X and y nodes are found and store them.""" - for i, op in enumerate(self.ops_ordered): + """Compute nodes until the split op is reached.""" + for i, op in enumerate(self.linearized_dag): if op.is_split_op: - self.pos_split_op = i return op self.process_op(op) - self._flag_op_for_recomputation_if_needed(op) - raise RuntimeError("X and y nodes not found in the DAG") \ No newline at end of file + raise RuntimeError("X and y nodes not found in the DAG") diff --git a/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py b/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py index bb539cb6..cddc25d7 100644 --- a/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py +++ b/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py @@ -4,7 +4,6 @@ from stratum.optimizer._optimize import optimize from stratum.optimizer._optimize import OptConfig from stratum.optimizer._algebraic_rewrites import AlgebraicRewritesConfig -from stratum.optimizer._op_utils import topological_iterator class TestCSE(unittest.TestCase): @@ -13,8 +12,7 @@ def test_log_exp1(self): t1 = df.skb.apply_func(np.log) t2 = t1.skb.apply_func(np.exp) - out = optimize(t2) - out = list(topological_iterator(out)) + out, *_ = optimize(t2) self.assertEqual(len(out), 1) self.assertEqual(out[0].value, 1) @@ -24,8 +22,7 @@ def test_log_exp2(self): t2 = t1.skb.apply_func(np.exp) t3 = t2.skb.apply_func(np.log1p) - out = optimize(t3) - out = list(topological_iterator(out)) + out, *_ = optimize(t3) self.assertEqual(len(out), 2) self.assertEqual(out[0].value, 1) @@ -34,8 +31,7 @@ def test_exp_log1(self): t1 = df.skb.apply_func(np.exp) t2 = t1.skb.apply_func(np.log) - out = optimize(t2) - out = list(topological_iterator(out)) + out, *_ = optimize(t2) 
self.assertEqual(len(out), 1) self.assertEqual(out[0].value, 1) @@ -45,8 +41,7 @@ def test_exp_log2(self): t2 = t1.skb.apply_func(np.log) t3 = t2.skb.apply_func(np.log1p) - out = optimize(t3) - out = list(topological_iterator(out)) + out, *_ = optimize(t3) self.assertEqual(len(out), 2) self.assertEqual(out[0].value, 1) @@ -56,8 +51,7 @@ def test_log_log1p(self): t1 = df.skb.apply_func(np.log) t2 = t1.skb.apply_func(np.log1p) - out = optimize(t2) - out = list(topological_iterator(out)) + out, *_ = optimize(t2) self.assertEqual(len(out), 3) def test_log_log1p_exp(self): @@ -66,8 +60,7 @@ def test_log_log1p_exp(self): t1 = df.skb.apply_func(np.log) t2 = t1.skb.apply_func(np.log1p) t3 = t2.skb.apply_func(np.exp) - out = optimize(t3) - out = list(topological_iterator(out)) + out, *_ = optimize(t3) self.assertEqual(len(out), 4) def test_log1p_log1p_exp(self): @@ -76,8 +69,7 @@ def test_log1p_log1p_exp(self): t1 = df.skb.apply_func(np.log1p) t2 = t1.skb.apply_func(np.log1p) t3 = t2.skb.apply_func(np.exp) - out = optimize(t3) - out = list(topological_iterator(out)) + out, *_ = optimize(t3) self.assertEqual(len(out), 4) def test_disable_log_exp_rewrite1(self): @@ -89,8 +81,7 @@ def test_disable_log_exp_rewrite1(self): algebraic_rewrites=True, algebraic_rewrite_config=AlgebraicRewritesConfig(log_exp=False), ) - out = optimize(t2, config=config) - out = list(topological_iterator(out)) + out, *_ = optimize(t2, config=config) self.assertEqual(len(out), 3) def test_disable_log_exp_rewrite2(self): @@ -102,7 +93,6 @@ def test_disable_log_exp_rewrite2(self): algebraic_rewrites=True, algebraic_rewrite_config=AlgebraicRewritesConfig(exp_log=False), ) - out = optimize(t2, config=config) - out = list(topological_iterator(out)) + out, *_ = optimize(t2, config=config) self.assertEqual(len(out), 1) diff --git a/stratum/tests/logical_optimizer/test_dataframe_ops.py b/stratum/tests/logical_optimizer/test_dataframe_ops.py index 3a38c41b..87e74382 100644 --- 
a/stratum/tests/logical_optimizer/test_dataframe_ops.py +++ b/stratum/tests/logical_optimizer/test_dataframe_ops.py @@ -13,13 +13,14 @@ ApplyUDFOp, AssignOp, ConcatOp, DataSourceOp, DatetimeConversionOp, DropOp, GetAttrProjectionOp, GroupedDataframeOp, MetadataOp, ProjectionOp, SplitOp, rewrite_fuse_get_item_ops,) -from stratum.optimizer._op_utils import topological_iterator from stratum.optimizer.ir._ops import DATA_OP_PLACEHOLDER, GetItemOp, MethodCallOp, Op from stratum.optimizer._optimize import OptConfig, optimize as optimize_ +from stratum.runtime._buffer_pool import BufferPool def optimize(dag, conf=None): - return list(topological_iterator(optimize_(dag, conf))) + linearized_dag, *_ = optimize_(dag, conf) + return linearized_dag def _inp(val): @@ -29,6 +30,11 @@ def _inp(val): return op +def _inputs_for(op): + """Extract intermediate values from op.inputs.""" + return [in_op.intermediate for in_op in op.inputs] + + class TestDataframeOps(unittest.TestCase): def setUp(self): self.df = pd.DataFrame({ @@ -104,8 +110,8 @@ def tearDown(self): def test_process_data_polars(self): df = pd.DataFrame({"a": [1, 2]}) op = DataSourceOp(data=df) - op.process("fit_transform", {}) - self.assertIsInstance(op.intermediate, pl.DataFrame) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertIsInstance(result, pl.DataFrame) def test_process_read_csv_polars(self): tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") @@ -113,8 +119,8 @@ def test_process_read_csv_polars(self): tmp.close() try: op = DataSourceOp(file_path=tmp.name, _format="csv", read_args=(), read_kwargs={}) - op.process("fit_transform", {}) - self.assertIsInstance(op.intermediate, pl.DataFrame) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertIsInstance(result, pl.DataFrame) finally: os.remove(tmp.name) @@ -131,16 +137,16 @@ def test_process_rename_polars(self): df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) op = MetadataOp(func="rename", args=(), 
kwargs={"columns": {"a": "x"}}) op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertIn("x", op.intermediate.columns) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertIn("x", result.columns) class TestProjectionOp(unittest.TestCase): def test_process_non_method(self): op = ProjectionOp(func=lambda df, v: df * v, is_method=False, args=(DATA_OP_PLACEHOLDER, 2), kwargs={}) op.inputs = [_inp(pd.DataFrame({"a": [1, 2]}))] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate["a"].tolist(), [2, 4]) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result["a"].tolist(), [2, 4]) def test_process_polars_method_raises(self): orig = FLAGS.force_polars @@ -149,7 +155,7 @@ def test_process_polars_method_raises(self): op = ProjectionOp(func="drop", is_method=True, args=(), kwargs={}) op.inputs = [_inp(pl.DataFrame({"a": [1]}))] with self.assertRaises(ValueError): - op.process("fit_transform", {}) + op.process("fit_transform", {}, _inputs_for(op)) finally: FLAGS.force_polars = orig @@ -166,8 +172,8 @@ def test_drop_with_columns_kwarg(self): df = pl.DataFrame({"a": [1], "b": [2], "c": [3]}) op = DropOp(args=(), kwargs={"columns": ["b"]}) op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertNotIn("b", op.intermediate.columns) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertNotIn("b", result.columns) class TestApplyUDFOp(unittest.TestCase): @@ -175,15 +181,15 @@ def test_single_column_str(self): df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) op = ApplyUDFOp(args=(lambda x: x * 10,), kwargs={}, columns="a") op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate.tolist(), [10, 20]) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result.tolist(), [10, 20]) def test_multi_column(self): df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) op = ApplyUDFOp(args=(lambda x: x * 2,), kwargs={}, columns=["a", 
"b"]) op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate["a"].tolist(), [2, 4]) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result["a"].tolist(), [2, 4]) def test_polars_sin_rewrite(self): orig = FLAGS.force_polars @@ -192,8 +198,8 @@ def test_polars_sin_rewrite(self): series = pl.Series("a", [0.0, np.pi / 2]) op = ApplyUDFOp(args=(np.sin,), kwargs={}) op.inputs = [_inp(series)] - op.process("fit_transform", {}) - self.assertAlmostEqual(op.intermediate[1], 1.0, places=5) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertAlmostEqual(result[1], 1.0, places=5) finally: FLAGS.force_polars = orig @@ -204,8 +210,8 @@ def test_polars_cos_rewrite(self): series = pl.Series("a", [0.0]) op = ApplyUDFOp(args=(np.cos,), kwargs={}) op.inputs = [_inp(series)] - op.process("fit_transform", {}) - self.assertAlmostEqual(op.intermediate[0], 1.0, places=5) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertAlmostEqual(result[0], 1.0, places=5) finally: FLAGS.force_polars = orig @@ -216,8 +222,8 @@ def test_polars_multi_col_map_rows(self): df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) op = ApplyUDFOp(args=(lambda row: (row[0] + row[1],),), kwargs={}, columns=["a", "b"]) op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertIsNotNone(op.intermediate) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertIsNotNone(result) finally: FLAGS.force_polars = orig @@ -234,22 +240,22 @@ def test_assign_polars(self): df = pl.DataFrame({"a": [1, 2]}) op = AssignOp(args=(), kwargs={"b": pl.Series([10, 20])}) op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertIn("b", op.intermediate.columns) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertIn("b", result.columns) def test_assign_polars_pandas_conversion(self): df = pl.DataFrame({"a": [1, 2]}) op = AssignOp(args=(), kwargs={"b": pd.Series([10, 20])}) 
op.inputs = [_inp(df)] - op.process("fit_transform", {}) - self.assertIn("b", op.intermediate.columns) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertIn("b", result.columns) def test_assign_polars_placeholder_raises(self): df = pl.DataFrame({"a": [1, 2]}) op = AssignOp(args=(), kwargs={"b": DATA_OP_PLACEHOLDER}) op.inputs = [_inp(df), _inp(DATA_OP_PLACEHOLDER)] with self.assertRaises(NotImplementedError): - op.process("fit_transform", {}) + op.process("fit_transform", {}, _inputs_for(op)) class TestDatetimeConversionOpPolars(unittest.TestCase): @@ -260,8 +266,8 @@ def test_polars_path(self): s = pl.Series("dt", ["2025-01-01", "2025-06-15"]) op = DatetimeConversionOp(args=(), kwargs={}) op.inputs = [_inp(s)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate.dtype, pl.Datetime) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result.dtype, pl.Datetime) finally: FLAGS.force_polars = orig @@ -281,8 +287,8 @@ def test_polars_process(self): try: s = pl.Series("dt", pd.to_datetime(["2025-01-15", "2025-06-20"])) op = GetAttrProjectionOp(attr_name=["dt", "year"], inputs=[_inp(s)], outputs=[]) - op.process("fit_transform", {}) - self.assertEqual(op.intermediate.to_list(), [2025, 2025]) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result.to_list(), [2025, 2025]) finally: FLAGS.force_polars = orig @@ -292,8 +298,8 @@ def test_polars_dayofweek(self): try: s = pl.Series("dt", pd.to_datetime(["2025-01-06"])) # Monday op = GetAttrProjectionOp(attr_name=["dt", "dayofweek"], inputs=[_inp(s)], outputs=[]) - op.process("fit_transform", {}) - self.assertEqual(op.intermediate.to_list(), [1]) # polars: Monday=1 + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result.to_list(), [1]) # polars: Monday=1 finally: FLAGS.force_polars = orig @@ -303,21 +309,17 @@ def test_polars_is_month_end(self): try: s = pl.Series("dt", pd.to_datetime(["2025-01-31", 
"2025-01-15"])) op = GetAttrProjectionOp(attr_name=["dt", "is_month_end"], inputs=[_inp(s)], outputs=[]) - op.process("fit_transform", {}) - self.assertEqual(op.intermediate.to_list(), [True, False]) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result.to_list(), [True, False]) finally: FLAGS.force_polars = orig class TestGroupedDataframeOp(unittest.TestCase): def test_process(self): - inner1 = Op() - inner1.process = lambda m, e: setattr(inner1, 'intermediate', 10) - inner2 = Op() - inner2.process = lambda m, e: setattr(inner2, 'intermediate', 20) - op = GroupedDataframeOp(ops=[inner1, inner2]) - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 20) + op = GroupedDataframeOp(ops=[Op(), Op()]) + with self.assertRaises(NotImplementedError): + op.process("fit_transform", {}, _inputs_for(op)) class TestConcatOpPolars(unittest.TestCase): @@ -331,8 +333,8 @@ def test_polars_concat(self): mock_dataop2 = MagicMock(spec=DataOp) op = ConcatOp(first=mock_dataop1, others=[mock_dataop2], axis=0) op.inputs = [_inp(df1), _inp(df2)] - op.process("fit_transform", {}) - self.assertEqual(len(op.intermediate), 4) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(len(result), 4) finally: FLAGS.force_polars = orig @@ -343,14 +345,14 @@ def test_polars(self): y = pl.DataFrame({"b": [1, 2, 3]}) op = SplitOp(inputs=[_inp(x), _inp(y)]) op.indices = [0, 2] - op.process("fit_transform", {}) - self.assertEqual(len(op.intermediate[0]), 2) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(len(result[0]), 2) def test_unsupported_type(self): op = SplitOp(inputs=[_inp("not_a_df"), _inp("not_a_df")]) op.indices = [0] with self.assertRaises(ValueError): - op.process("fit_transform", {}) + op.process("fit_transform", {}, _inputs_for(op)) class TestRewriteFuseGetItemOps(unittest.TestCase): @@ -369,10 +371,14 @@ def test_read_op_with_variable_input(self): with 
skrub.config(fast_dataops_convert=True): ops = optimize(data, OptConfig(dataframe_ops=True)) self.assertIsInstance(ops[-1], DataSourceOp) - # Verify it can actually process - ops[0].process("fit_transform", {"path": tmp.name}) - ops[1].process("fit_transform", {}) - self.assertIsInstance(ops[1].intermediate, pd.DataFrame) + # Verify it can actually process using resolve_inputs + pool = BufferPool() + inputs0 = ops[0].resolve_inputs(pool) + result0 = ops[0].process("fit_transform", {"path": tmp.name}, inputs0) + pool.put(ops[0], result0) + inputs1 = ops[1].resolve_inputs(pool) + result1 = ops[1].process("fit_transform", {}, inputs1) + self.assertIsInstance(result1, pd.DataFrame) finally: os.remove(tmp.name) diff --git a/stratum/tests/logical_optimizer/test_numeric_ops.py b/stratum/tests/logical_optimizer/test_numeric_ops.py index d6a73c27..f19fe3ea 100644 --- a/stratum/tests/logical_optimizer/test_numeric_ops.py +++ b/stratum/tests/logical_optimizer/test_numeric_ops.py @@ -3,7 +3,7 @@ import stratum as skrub import numpy as np from sklearn.dummy import DummyRegressor -from stratum.optimizer.ir._numeric_ops import NumericOp +from stratum.optimizer.ir._numeric_ops import NumericOp class TestNumericOps(unittest.TestCase): def setUp(self): @@ -28,4 +28,4 @@ def test_unsupported_numeric_op(self): op = NumericOp(np.cos, None, None, [], []) op.type = "unsupported" with self.assertRaises(ValueError): - op.process("fit", {}) \ No newline at end of file + op.process("fit", {}, []) \ No newline at end of file diff --git a/stratum/tests/logical_optimizer/test_op_utils.py b/stratum/tests/logical_optimizer/test_op_utils.py index 8d18b36c..cb2fcae7 100644 --- a/stratum/tests/logical_optimizer/test_op_utils.py +++ b/stratum/tests/logical_optimizer/test_op_utils.py @@ -1,13 +1,14 @@ #from curses import flash import unittest import stratum as skrub -from stratum.optimizer._optimize import optimize as optimize_, OptConfig, choice_unrolling +from stratum.optimizer._optimize import 
optimize as optimize_, OptConfig, choice_unrolling, convert_to_ops from stratum.optimizer._op_utils import show_graph, clone_sub_dag, topological_iterator, FLAGS from stratum._config import config graph = False def optimize(dag, conf=None): - return list(topological_iterator(optimize_(dag, conf))) + linearized_dag, *_ = optimize_(dag, conf) + return linearized_dag class TestOpUtils(unittest.TestCase): def setUp(self): @@ -21,7 +22,8 @@ def setUp(self): def test_iterator_bfs(self): FLAGS.bfs = True try: - ops = optimize(self.dag) + root = convert_to_ops(self.dag) + ops = list(topological_iterator(root)) finally: FLAGS.bfs = False self.assertEqual(ops[0].value, 1) @@ -152,5 +154,6 @@ def test_choice_unrolling(self): with config(open_graph=False): show_graph(out, filename='choice_unrolling') + diff --git a/stratum/tests/logical_optimizer/test_ops.py b/stratum/tests/logical_optimizer/test_ops.py index b8fc9305..134d8355 100644 --- a/stratum/tests/logical_optimizer/test_ops.py +++ b/stratum/tests/logical_optimizer/test_ops.py @@ -10,7 +10,6 @@ from sklearn.preprocessing import StandardScaler from skrub._data_ops._data_ops import DataOp -from stratum.optimizer._op_utils import topological_iterator from stratum.optimizer.ir._ops import ( DATA_OP_PLACEHOLDER, BinOp, CallOp, DummyConfigManager, GetAttrOp, GetItemOp, ImplOp, MethodCallOp, Op, PlaceHolder, SearchEvalOp, ValueOp, @@ -28,6 +27,11 @@ def _inp(val): return op +def _inputs_for(op): + """Extract intermediate values from op.inputs.""" + return [in_op.intermediate for in_op in op.inputs] + + class TestOpCloning(unittest.TestCase): def setUp(self): self.df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) @@ -53,7 +57,7 @@ def test_clone_ops(self): pred = pred.skb.apply_func(lambda x, a, b: x, 1, b=1) choice = skrub.choose_from([pred], name="choice").as_data_op() with skrub.config(fast_dataops_convert=True): - ops = list(topological_iterator(optimize_(choice.empty))) + ops, *_ = optimize_(choice.empty) with 
self.assertRaises(ValueError): ops[0].clone() @@ -113,8 +117,8 @@ def test_basics(self): cloned = op.clone() self.assertIsNot(op, cloned) self.assertEqual(cloned.name, "x") - op.process("fit_transform", {"x": 123}) - self.assertEqual(op.intermediate, 123) + result = op.process("fit_transform", {"x": 123}, _inputs_for(op)) + self.assertEqual(result, 123) class TestImplOp(unittest.TestCase): @@ -133,7 +137,8 @@ def test_replace_fields_with_values(self): cls = type("Impl", (), {"_fields": ["x", "y", "z"], "x": mock_dataop, "y": [mock_dataop, 5], "z": {"k": mock_dataop}}) op = ImplOp(name="test", skrub_impl=cls()) op.inputs = [_inp("vx"), _inp("vy"), _inp("vz")] - ns = op.replace_fields_with_values() + inputs = [in_op.intermediate for in_op in op.inputs] + ns = op.replace_fields_with_values(inputs) self.assertEqual(ns.x, "vx") self.assertEqual(ns.y[1], 5) self.assertEqual(ns.z["k"], "vz") @@ -145,15 +150,15 @@ def fake_eval(mode, environment): return val * 2 op = ImplOp(name="test", skrub_impl=SimpleNamespace(eval=fake_eval)) op.inputs = [_inp(10)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 20) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 20) def test_process_without_eval(self): cls = type("Impl", (), {"_fields": ["a"], "a": 42, "compute": lambda self, ns, mode, env: ns.a + 1}) op = ImplOp(name="test", skrub_impl=cls()) op.inputs = [] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 43) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 43) class TestUtilFunctions(unittest.TestCase): @@ -206,49 +211,49 @@ class TestOpProcess(unittest.TestCase): def test_method_call(self): op = MethodCallOp("upper", args=(), kwargs={}) op.inputs = [_inp("hello")] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, "HELLO") + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, "HELLO") def 
test_method_call_with_placeholders(self): op = MethodCallOp("format", args=(DATA_OP_PLACEHOLDER,), kwargs={"end": DATA_OP_PLACEHOLDER}) op.inputs = [_inp("{0} {end}"), _inp("hello"), _inp("world")] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, "hello world") + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, "hello world") def test_call_op(self): op = CallOp(func=lambda a, b: a + b, args=(DATA_OP_PLACEHOLDER, DATA_OP_PLACEHOLDER), kwargs={}) op.inputs = [_inp(3), _inp(7)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 10) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 10) def test_getattr_dataframe_op(self): op = GetAttrOp(attr_name=["real", "imag"]) op.is_dataframe_op = True op.inputs = [_inp(1 + 2j)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 0.0) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 0.0) def test_getattr_normal(self): op = GetAttrOp(attr_name="real") op.inputs = [_inp(3 + 4j)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 3.0) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 3.0) def test_getitem_with_placeholder(self): op = GetItemOp(key="dummy") op.key = DATA_OP_PLACEHOLDER op.inputs = [_inp({"x": 42}), _inp("x")] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 42) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 42) def test_binop_both_placeholders(self): op = BinOp(op=operator.add, left=DATA_OP_PLACEHOLDER, right=DATA_OP_PLACEHOLDER) op.inputs = [_inp(10), _inp(20)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 30) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 30) def test_binop_left_literal(self): op = BinOp(op=operator.mul, left=5, right=DATA_OP_PLACEHOLDER) op.inputs = 
[_inp(3)] - op.process("fit_transform", {}) - self.assertEqual(op.intermediate, 15) + result = op.process("fit_transform", {}, _inputs_for(op)) + self.assertEqual(result, 15) diff --git a/stratum/tests/logical_optimizer/test_optimize.py b/stratum/tests/logical_optimizer/test_optimize.py index d0b6d157..731f2724 100644 --- a/stratum/tests/logical_optimizer/test_optimize.py +++ b/stratum/tests/logical_optimizer/test_optimize.py @@ -28,7 +28,7 @@ def test_optimize1(self): X2 = X1.assign( year=X1["datetime"].dt.year, month=X1["datetime"].dt.month) - out = list(topological_iterator(optimize(X2, OptConfig(cse=True)))) + out, *_ = optimize(X2, OptConfig(cse=True)) self.assertTrue(out[0].outputs[0] is out[1]) self.assertTrue(len(out[0].inputs) == 0) @@ -41,7 +41,7 @@ def test_optimize2(self): year=X1["datetime"].dt.year, month=X1["datetime"].dt.month) config = OptConfig(cse=False, algebraic_rewrites=False, numeric_ops=False, dataframe_ops=False, unroll_choices=False) - out = list(topological_iterator(optimize(X2, config))) + out, *_ = optimize(X2, config) self.assertEqual(len(out), 10) def test_more_ops(self): @@ -51,7 +51,7 @@ def test_more_ops(self): X2 = X1.assign( year=X1["datetime"].dt.year, month=X1["datetime"].dt.month) - out = optimize(X2, OptConfig(cse=True)) + out, *_ = optimize(X2, OptConfig(cse=True)) diff --git a/stratum/tests/runtime/runtime_test_utils.py b/stratum/tests/runtime/runtime_test_utils.py index 43df77e4..caaae46b 100644 --- a/stratum/tests/runtime/runtime_test_utils.py +++ b/stratum/tests/runtime/runtime_test_utils.py @@ -6,6 +6,7 @@ from stratum._api import evaluate import stratum as skrub from sklearn.dummy import DummyRegressor +from stratum.optimizer.ir._ops import Op def datetime_pipeline1(x: DataOp, y: DataOp) -> DataOp: @@ -72,3 +73,35 @@ def compare_evaluate(self, pred_opt: DataOp): preds_skrub = learner.predict(splits["test"]) np.testing.assert_array_equal(preds_skrub, preds) + +def _make_op(name="op"): + """Create a minimal Op (no 
intermediate attribute — data lives in the pool).""" + return Op(name=name) + + +def _make_linear_dag(): + """A -> B -> C (linear chain).""" + a = _make_op("A") + b = _make_op("B") + c = _make_op("C") + a.outputs = [b] + b.inputs = [a] + b.outputs = [c] + c.inputs = [b] + return [a, b, c] + + +def _make_diamond_dag(): + """A -> B, A -> C, B -> D, C -> D (diamond).""" + a = _make_op("A") + b = _make_op("B") + c = _make_op("C") + d = _make_op("D") + a.outputs = [b, c] + b.inputs = [a] + b.outputs = [d] + c.inputs = [a] + c.outputs = [d] + d.inputs = [b, c] + return [a, b, c, d] + diff --git a/stratum/tests/runtime/test_buffer_pool.py b/stratum/tests/runtime/test_buffer_pool.py new file mode 100644 index 00000000..ead786fa --- /dev/null +++ b/stratum/tests/runtime/test_buffer_pool.py @@ -0,0 +1,74 @@ +import unittest +from stratum.runtime._buffer_pool import BufferPool +from stratum.tests.runtime.runtime_test_utils import RuntimeTest, simple_pipeline, _make_op +from stratum._api import grid_search + +class TestBufferPool(unittest.TestCase): + """Tests for BufferPool as a pure cache.""" + + def test_put_and_get(self): + pool = BufferPool() + op = _make_op("x") + pool.put(op, "data_x") + self.assertEqual(pool.get(op), "data_x") + self.assertEqual(pool.active_count, 1) + + def test_get_missing_returns_none(self): + pool = BufferPool() + self.assertIsNone(pool.get(_make_op("missing"))) + + def test_release_drops_data(self): + pool = BufferPool() + op = _make_op("x") + pool.put(op, "data_x") + released = pool.release(op) + self.assertTrue(released) + self.assertIsNone(pool.get(op)) + self.assertEqual(pool.active_count, 0) + self.assertEqual(pool.total_released, 1) + + def test_release_missing_returns_false(self): + pool = BufferPool() + self.assertFalse(pool.release(_make_op("missing"))) + + def test_release_all(self): + pool = BufferPool() + ops = [_make_op(f"op{i}") for i in range(3)] + for i, op in enumerate(ops): + pool.put(op, f"data_{i}") + + released = 
pool.release_all() + self.assertEqual(set(released), set(ops)) + self.assertEqual(pool.active_count, 0) + self.assertEqual(pool.total_released, 3) + + def test_put_overwrites_existing(self): + pool = BufferPool() + op = _make_op("x") + pool.put(op, "old_data") + pool.put(op, "new_data") + self.assertEqual(pool.get(op), "new_data") + self.assertEqual(pool.active_count, 1) + + + +# --------------------------------------------------------------------------- +# Integration tests +# --------------------------------------------------------------------------- + +class TestBufferPoolIntegration(RuntimeTest): + + def test_evaluate_matches_baseline(self): + """Buffer-managed evaluate produces same results as skrub baseline.""" + pred_opt = simple_pipeline() + self.compare_evaluate(pred_opt) + + def test_grid_search_runs(self): + """Grid search with buffer manager completes without error.""" + pred_opt = simple_pipeline() + results = grid_search(pred_opt, cv=2) + self.assertIsNotNone(results) + + +if __name__ == "__main__": + unittest.main() diff --git a/stratum/tests/runtime/test_scheduler.py b/stratum/tests/runtime/test_scheduler.py new file mode 100644 index 00000000..a238dd35 --- /dev/null +++ b/stratum/tests/runtime/test_scheduler.py @@ -0,0 +1,124 @@ +import unittest + +from stratum.tests.runtime.runtime_test_utils import _make_linear_dag, _make_diamond_dag, _make_op +from stratum.optimizer._input_release_planning import plan_input_releases + + +def _plan_release_schedule(ops, split_pos, flagged_ops, pinned_ops=None): + """Run plan_input_releases and return the computed release_after dict.""" + pinned = set(pinned_ops) if pinned_ops else set() + plan_input_releases(ops, pinned) + return {op: op.release_after for op in ops} + + +# --------------------------------------------------------------------------- +# Release schedule tests (exercises plan_input_releases) +# --------------------------------------------------------------------------- + +class 
TestReleaseSchedule(unittest.TestCase): + + def test_linear_no_split(self): + ops = _make_linear_dag() # A -> B -> C + schedule = _plan_release_schedule(ops, split_pos=None, flagged_ops=[]) + # After B processes: A's single consumer (B) is done -> release A + self.assertEqual(schedule[ops[1]], [ops[0]]) + # After C processes: B's single consumer (C) is done -> release B + self.assertEqual(schedule[ops[2]], [ops[1]]) + # A has no inputs, nothing to release + self.assertEqual(schedule[ops[0]], []) + + def test_linear_with_split_pinned(self): + ops = _make_linear_dag() # A -> B -> C + # Split at B (pos=1). A is pinned (feeds B which is post-split). + schedule = _plan_release_schedule(ops, split_pos=1, flagged_ops=[], pinned_ops=[ops[0]]) + # After B: A is pinned, so NOT released + self.assertEqual(schedule[ops[1]], []) + # After C: B released (not pinned, single consumer done) + self.assertEqual(schedule[ops[2]], [ops[1]]) + + def test_diamond_no_split(self): + a, b, c, d = _make_diamond_dag() # A -> {B, C} -> D + ops = [a, b, c, d] + schedule = _plan_release_schedule(ops, split_pos=None, flagged_ops=[]) + # A has 2 outputs (B, C). After B: remaining=1, no release. + self.assertEqual(schedule[b], []) + # After C: A's remaining hits 0 -> release A + self.assertEqual(schedule[c], [a]) + # After D: B and C both have 1 consumer (D), both released + self.assertCountEqual(schedule[d], [b, c]) + + def test_diamond_with_split(self): + a, b, c, d = _make_diamond_dag() + # Linearized: [A, C, B, D], split_pos=2 (B is split). + # A feeds B (post-split) and C (pre-split) -> A is pinned. + # C feeds D (post-split) -> C is pinned. 
+ ops = [a, c, b, d] + schedule = _plan_release_schedule(ops, split_pos=2, flagged_ops=[], pinned_ops=[a, c]) + # After C: A is pinned, not released + self.assertEqual(schedule[c], []) + # After B: A is pinned, not released + self.assertEqual(schedule[b], []) + # After D: B released (not pinned), C pinned (not released) + self.assertEqual(schedule[d], [b]) + + def test_flagged_ops_not_pinned(self): + ops = _make_linear_dag() # A, B, C + # Split at C (pos=2), A is flagged (re-executed, not pinned). + # B feeds C (post-split) -> B is pinned. + schedule = _plan_release_schedule(ops, split_pos=2, flagged_ops=[ops[0]], pinned_ops=[ops[1]]) + # After B: A released (not pinned, single consumer done) + self.assertEqual(schedule[ops[1]], [ops[0]]) + # After C: B is pinned, not released + self.assertEqual(schedule[ops[2]], []) + + def test_complex_dag_with_non_descendant_branches(self): + """DAG: A -> B (split), A -> C, C -> E, E -> D, B -> D + Linearized: [A, C, E, B, D], split_pos=3. + """ + a = _make_op("A") + b = _make_op("B") + c = _make_op("C") + e = _make_op("E") + d = _make_op("D") + + a.outputs = [b, c] + b.inputs = [a] + b.outputs = [d] + c.inputs = [a] + c.outputs = [e] + e.inputs = [c] + e.outputs = [d] + d.inputs = [b, e] + + ops = [a, c, e, b, d] + # A feeds B (post-split) -> pinned. E feeds D (post-split) -> pinned. + # C only feeds E (pre-split) -> not pinned. 
+ schedule = _plan_release_schedule(ops, split_pos=3, flagged_ops=[], pinned_ops=[a, e]) + # After C: A is pinned, not released + self.assertEqual(schedule[c], []) + # After E: C released (single consumer done, not pinned) + self.assertEqual(schedule[e], [c]) + # After B: A is pinned, not released + self.assertEqual(schedule[b], []) + # After D: B released (not pinned), E pinned (not released) + self.assertEqual(schedule[d], [b]) + + def test_mixed_pre_post_consumers(self): + """Pre-split op with both pre-split and post-split consumers.""" + a = _make_op("A") + b = _make_op("B") + c = _make_op("C") + a.outputs = [b, c] + b.inputs = [a] + c.inputs = [a] + ops = [a, b, c] + # Split at pos=2 (C is post-split). A feeds C -> pinned. + schedule = _plan_release_schedule(ops, split_pos=2, flagged_ops=[], pinned_ops=[a]) + # After B: A has 2 consumers, remaining=1, not released (also pinned) + self.assertEqual(schedule[b], []) + # After C: A's remaining hits 0 but pinned -> not released + self.assertEqual(schedule[c], []) + + +if __name__ == "__main__": + unittest.main()