Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions benchmarks/runtime/buffer_pool_input_release.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import sys

from sklearn.model_selection import ShuffleSplit

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import stratum as skrub
import numpy as np
import pandas as pd
from time import sleep, perf_counter
from sklearn.dummy import DummyRegressor
from utils.memory_consumption_tracker import MemoryTracker
import matplotlib.pyplot as plt
import argparse
import logging
import polars as pl
logging.basicConfig(level=logging.INFO)
logging.getLogger("stratum").setLevel(logging.DEBUG)

def dummy_func(x, t: float = 0.1):
    """Select every row of ``x`` by position, sleep ``t`` seconds, return the selection.

    Polars frames/series are indexed directly with the position array;
    anything else is indexed through ``.iloc`` (pandas-style objects are
    presumably what callers pass here -- confirm against the pipeline).

    Parameters
    ----------
    x : polars or pandas-like frame/series
        Object whose rows are re-selected.
    t : float
        Seconds to block before returning.
    """
    row_positions = np.arange(len(x))
    is_polars = isinstance(x, (pl.DataFrame, pl.Series))
    selected = x[row_positions] if is_polars else x.iloc[row_positions]
    # The sleep happens after the selection has been produced.
    sleep(t)
    return selected

def main(use_skrub: bool, polars: bool):
    """Run the input-release benchmark and record the process memory profile.

    Steps:
      1. Create ``input_<n>.csv`` (nine float columns) in the current working
         directory if it does not exist yet.
      2. Build a lazy pipeline that reads the CSV, marks X / y, and chains
         several ``dummy_func`` steps in front of a ``DummyRegressor``.
      3. Run a grid search while a ``MemoryTracker`` samples this process.
      4. Write the samples to a CSV next to this file, print the elapsed
         time and the search results, and plot the memory curve.

    Parameters
    ----------
    use_skrub : bool
        Passed (negated) as ``scheduler`` to ``skrub.config`` and used to
        name the output file (``skrub`` vs ``stratum``).
    polars : bool
        Passed as ``force_polars`` to ``skrub.config``.
    """
    with skrub.config_context(eager_data_ops=False):
        n = 1_000_000
        input_csv = f"input_{n}.csv"
        if not os.path.exists(input_csv):
            cols = ["a", "b", "c", "d", "e", "f", "g", "h", "y"]
            df = pd.DataFrame({col: np.random.random(n) for col in cols}, dtype=np.float64)

            print(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2} MB")

            df.to_csv(input_csv, index=False)
            # Drop the generated frame so it does not inflate the measured RSS.
            del df
            print("CSV written")

        df = skrub.as_data_op(input_csv).skb.apply_func(pd.read_csv)
        X = df.drop("y", axis=1).skb.mark_as_X()
        y = df["y"]
        y = y.skb.apply_func(dummy_func, t=1.0).skb.mark_as_y()

        for _ in range(5):
            X = X.skb.apply_func(dummy_func, t=0.3)
        model = DummyRegressor()

        pred = X.skb.apply(model, y=y)
        cv = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        tracker = MemoryTracker(mode="process", interval_sec=0.02)
        tracker.start()

        t0 = perf_counter()
        try:
            with skrub.config(scheduler=not use_skrub, stats=20, DEBUG=False, force_polars=polars):
                search = pred.skb.make_grid_search(cv=cv, n_jobs=1, scoring="r2", fitted=True, refit=False)
            # Stop the clock before shutting the tracker down, so the
            # tracker's join/sample-copy time is not billed to the workload.
            t1 = perf_counter()
        finally:
            # Always stop the side-car process, even when the search fails.
            tracker.stop()

        csv_path = os.path.join(os.path.dirname(__file__), f"memory_usage_{'skrub' if use_skrub else 'stratum'}.csv")
        tracker.write_csv(csv_path)

        print(f"Time taken: {t1 - t0:.2f}s")
        print(search.results_)
        plot_memory(csv_path)


def plot_memory(csv_path: str):
    """Plot the memory samples stored in ``csv_path`` and save them as a PDF.

    The CSV must provide ``time_sec`` and ``rss_mb`` columns. The plot is
    written next to the CSV, with the file extension replaced by ``.pdf``.

    Parameters
    ----------
    csv_path : str
        Path of the CSV file holding the samples.
    """
    data = pd.read_csv(csv_path)
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(data["time_sec"], data["rss_mb"], linewidth=1.5)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("RSS (MB)")
        ax.set_title("Buffer Pool Benchmark - Memory Usage")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        # Swap only the extension: str.replace would also rewrite ".csv"
        # inside directory names, and would leave the path unchanged (and
        # overwrite the data file) when the input has no ".csv" at all.
        plot_path = os.path.splitext(csv_path)[0] + ".pdf"
        fig.savefig(plot_path, dpi=150)
        print(f"Plot saved to {plot_path}")
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: both switches are plain on/off flags that
    # default to False and are forwarded to main() unchanged.
    cli = argparse.ArgumentParser()
    for switch in ("--skrub", "--polars"):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()
    main(use_skrub=options.skrub, polars=options.polars)
158 changes: 158 additions & 0 deletions benchmarks/utils/memory_consumption_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Reusable RSS memory tracker that runs in a separate process.

Usage::

from utils.memory_consumption_tracker import MemoryTracker

tracker = MemoryTracker(live_dump_path="mem.csv")
tracker.start()
try:
... # workload
finally:
samples = tracker.stop() # list of (wall_sec, rss_mb)
tracker.write_csv("memory_usage.csv", t0=tracker.t0)
"""

from __future__ import annotations

import os
import signal
from multiprocessing import Event, Manager, Process
from time import perf_counter
from typing import Literal

import psutil


def _flush_samples(samples_list, dump_path: str, flushed_count: int) -> int:
    """Write the samples not yet on disk to ``dump_path`` and return the new count.

    Parameters
    ----------
    samples_list : sequence of (timestamp, rss_mb)
        All samples collected so far; only ``samples_list[flushed_count:]``
        is written.
    dump_path : str
        CSV file to write. It is (re)created with a header on the first
        flush and appended to on later ones.
    flushed_count : int
        Number of samples already written by earlier calls.

    Returns
    -------
    int
        Total number of samples written so far. Unchanged (and the file is
        not touched) when there is nothing new.
    """
    pending = list(samples_list[flushed_count:])
    if len(pending) == 0:
        return flushed_count

    open_mode = "a" if flushed_count > 0 else "w"
    with open(dump_path, open_mode) as out:
        if flushed_count == 0:
            out.write("time_sec,rss_mb\n")
        out.writelines(f"{stamp:.6f},{rss:.2f}\n" for stamp, rss in pending)

    return flushed_count + len(pending)


def _tracker_loop(
    pid: int,
    stop_event,
    samples_list,
    interval_sec: float,
    dump_path: str,
    flush_every: int,
) -> None:
    """Sample memory usage until ``stop_event`` is set.

    Each sample is appended to ``samples_list`` as
    ``(perf_counter(), megabytes)``. Every ``flush_every`` samples, and
    once more when the loop ends, the samples not yet on disk are written
    to ``dump_path``. On SIGTERM the pending samples are flushed and the
    process exits with status 1.

    Parameters
    ----------
    pid : int
        Process whose RSS is sampled, or ``-1`` to sample system-wide used
        memory for as long as the parent of this process is running.
    stop_event
        Event-like object; setting it ends the loop.
    samples_list
        List-like object that receives the samples.
    interval_sec : float
        Wait between two samples, in seconds.
    dump_path : str
        CSV file receiving the incremental flushes.
    flush_every : int
        Number of samples between two incremental flushes.
    """
    bytes_per_mb = 1024 * 1024
    written = 0
    taken = 0

    def _flush_pending():
        nonlocal written
        written = _flush_samples(samples_list, dump_path, written)

    def _on_sigterm(signum, frame):
        # Persist whatever has been collected before dying.
        _flush_pending()
        raise SystemExit(1)

    signal.signal(signal.SIGTERM, _on_sigterm)

    def _record(timestamp, megabytes):
        nonlocal taken
        samples_list.append((timestamp, megabytes))
        taken += 1
        if taken % flush_every == 0:
            _flush_pending()

    if pid == -1:
        # System mode: stop on our own once the parent process is gone.
        parent = psutil.Process(os.getppid())
        while not stop_event.is_set() and parent.is_running():
            _record(perf_counter(), psutil.virtual_memory().used / bytes_per_mb)
            stop_event.wait(interval_sec)
    else:
        try:
            target = psutil.Process(pid)
        except psutil.NoSuchProcess:
            return
        while not stop_event.is_set():
            try:
                rss_bytes = target.memory_info().rss
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # The tracked process vanished or became unreadable.
                break
            _record(perf_counter(), rss_bytes / bytes_per_mb)
            stop_event.wait(interval_sec)

    _flush_pending()


class MemoryTracker:
    """Spawn a side-car process that polls RSS of the current (or system) memory.

    Parameters
    ----------
    mode : "process" or "system"
        "process" tracks the calling process's RSS.
        "system" tracks total system memory used (useful with multi-process workloads).
    interval_sec : float
        Polling interval in seconds.
    live_dump_path : str or None
        CSV file the side-car process incrementally flushes samples to.
        ``None`` (or an empty string) falls back to
        ``"memory_usage_live.csv"``.
    flush_every : int
        Number of samples between live flushes.
    """

    def __init__(
        self,
        *,
        mode: Literal["process", "system"] = "process",
        interval_sec: float = 0.1,
        live_dump_path: str | None = "memory_usage_live.csv",
        flush_every: int = 50,
    ):
        self.mode = mode
        self.interval_sec = interval_sec
        self.live_dump_path = live_dump_path or "memory_usage_live.csv"
        self.flush_every = flush_every

        # Samples live in a manager-backed list so the side-car process
        # can append to it and this process can read it back.
        self._manager = Manager()
        self._samples = self._manager.list()
        self._stop = Event()
        self._process: Process | None = None
        self._t0: float | None = None

    def _snapshot(self) -> list[tuple[float, float]]:
        """Return a plain-list copy of the collected samples.

        A slice is answered by the manager in a single request, whereas
        iterating the proxy fetches the samples one index at a time.
        """
        return self._samples[:]

    def start(self) -> None:
        """Record the start time and launch the side-car sampling process."""
        # -1 tells the tracker loop to sample system-wide memory instead.
        pid = os.getpid() if self.mode == "process" else -1
        self._t0 = perf_counter()
        self._process = Process(
            target=_tracker_loop,
            args=(pid, self._stop, self._samples, self.interval_sec,
                  self.live_dump_path, self.flush_every),
        )
        self._process.start()

    def stop(self, timeout: float = 2.0) -> list[tuple[float, float]]:
        """Signal the tracker to stop and return collected samples.

        Parameters
        ----------
        timeout : float
            Seconds to wait for the side-car process to exit. If it is
            still alive afterwards it is terminated, so no stray process
            is left behind.

        Returns
        -------
        list of (timestamp, rss_mb)
            Samples collected so far, timestamps as recorded by the
            side-car process.
        """
        self._stop.set()
        proc = self._process
        if proc is not None:
            proc.join(timeout=timeout)
            if proc.is_alive():
                proc.terminate()
                proc.join(timeout=timeout)
        return self._snapshot()

    @property
    def t0(self) -> float:
        """Start time recorded by :meth:`start`.

        Raises
        ------
        RuntimeError
            If the tracker has not been started.
        """
        if self._t0 is None:
            raise RuntimeError("Tracker has not been started yet")
        return self._t0

    def write_csv(self, path: str, *, t0: float | None = None) -> None:
        """Write all samples to *path* with wall-clock times relative to *t0*.

        Parameters
        ----------
        path : str
            Destination CSV file (columns ``time_sec``, ``rss_mb``).
        t0 : float or None
            Reference time subtracted from every timestamp; defaults to
            the time recorded by :meth:`start`.

        Raises
        ------
        RuntimeError
            If no *t0* is given and the tracker was never started.

        Notes
        -----
        Nothing is written when no samples have been collected.
        """
        t0 = t0 if t0 is not None else self._t0
        if t0 is None:
            raise RuntimeError("t0 not available; pass it explicitly or call start() first")
        samples = self._snapshot()
        if not samples:
            return
        with open(path, "w") as f:
            f.write("time_sec,rss_mb\n")
            f.writelines(f"{ts - t0:.4f},{rss_mb:.2f}\n" for ts, rss_mb in samples)
8 changes: 4 additions & 4 deletions stratum/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def grid_search(dag: DataOp, cv=None, scoring=None, return_predictions=False, en
env = dag.skb.get_data()
for k, v in env_extra.items():
env[k] = v
dag = optimize(dag)
sched = SequentialScheduler(dag, show_stats, env=env, t0=t0)
linearized_dag, split_pos, flagged_ops = optimize(dag)
sched = SequentialScheduler(linearized_dag, split_pos, flagged_ops, show_stats, env=env, t0=t0)

preds = sched.grid_search(cv, scoring, return_predictions)

Expand All @@ -34,5 +34,5 @@ def grid_search(dag: DataOp, cv=None, scoring=None, return_predictions=False, en

def evaluate(dag: DataOp, seed: int = 42, test_size = 0.2, cse: bool = False):
"""Evaluate a DataOp DAG with train/test split."""
ops_ordered = optimize(dag)
return SequentialScheduler(ops_ordered).evaluate(seed, test_size)
linearized_dag, split_pos, flagged_ops = optimize(dag)
return SequentialScheduler(linearized_dag, split_pos, flagged_ops).evaluate(seed, test_size)
64 changes: 64 additions & 0 deletions stratum/optimizer/_input_release_planning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Input release planning for the linearized Op DAG.

After linearization, each op produces an intermediate buffer stored in the
BufferPool. Release planning decides *when* each buffer can be freed so
that peak memory stays low while correctness is preserved.

How it works:
- Each op starts with a consumer count equal to ``len(op.outputs)`` (min 1).
- Walking the linearized order, every time an op appears as an input of
another op, its remaining count is decremented.
- When the count hits zero the buffer is scheduled for release right after
the current op finishes (set via ``op.release_after``).
- Pinned ops (pre-split ops that feed post-split / re-executed ops) are
never released by the schedule; they persist across CV folds.
- At execution time ``Op.release_inputs()`` calls ``buffers.release()`` for
every entry in its ``release_after`` list.
"""
from __future__ import annotations

from stratum.optimizer.ir._ops import Op

import logging
logger = logging.getLogger(__name__)


def compute_pinned_ops(
    linearized_dag: list[Op],
    split_pos: int | None,
    recompute_ops: list[Op],
) -> set[Op]:
    """Collect the pre-split ops whose buffers have to outlive a single fold.

    An op is pinned when it sits before ``split_pos``, is not itself
    re-executed, and at least one of its outputs is re-executed. The
    re-executed ops are everything from ``split_pos`` onwards plus
    ``recompute_ops``.

    Returns an empty set when ``split_pos`` is ``None``.
    """
    if split_pos is None:
        return set()

    rerun = set(recompute_ops).union(linearized_dag[split_pos:])
    return {
        candidate
        for candidate in linearized_dag[:split_pos]
        if candidate not in rerun
        and any(consumer in rerun for consumer in candidate.outputs)
    }


def plan_input_releases(
    linearized_dag: list[Op],
    pinned_ops: set[Op],
) -> None:
    """Set ``op.release_after`` for every op in the linearized DAG.

    Each op starts with ``max(len(op.outputs), 1)`` pending consumers.
    Walking the ops in order, every occurrence of an op as an input
    decrements its counter; once the counter reaches zero (or below) the
    input is listed in the consuming op's ``release_after``, unless it is
    pinned.

    Parameters
    ----------
    linearized_dag : list[Op]
        Ops in execution order. Every input of every op must itself be
        part of this list.
    pinned_ops : set[Op]
        Ops that are never scheduled for release.

    Notes
    -----
    An input appears at most once in a given ``release_after`` list, even
    if the op consumes it several times.
    """
    remaining = {op: max(len(op.outputs), 1) for op in linearized_dag}

    for op in linearized_dag:
        release: list[Op] = []
        scheduled: set[Op] = set()
        for in_op in op.inputs:
            remaining[in_op] -= 1
            if remaining[in_op] > 0 or in_op in pinned_ops:
                continue
            # The same input can show up several times in op.inputs; the
            # counter then drops below zero and would otherwise schedule
            # the buffer for release more than once after this op.
            if in_op not in scheduled:
                scheduled.add(in_op)
                release.append(in_op)
        op.release_after = release

    logger.debug(
        "Release planning done: pinned=%d ops, total=%d ops",
        len(pinned_ops),
        len(linearized_dag),
    )
59 changes: 59 additions & 0 deletions stratum/optimizer/_linearization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from skrub._data_ops._data_ops import EvalMode
from stratum.optimizer.ir._dataframe_ops import SplitOp
from stratum.optimizer._op_utils import compute_graph_node_indegree
from stratum.optimizer.ir._ops import ImplOp, Op

import logging
logger = logging.getLogger(__name__)


def linearize_dag(dag_sink: Op) -> tuple[list[Op], int | None, list[Op]]:
    """Produce a topological order of the DAG in which the split op comes late.

    Ops become ready once all their inputs have been emitted. Ready ops
    are kept on a stack; whenever the split op is on top of the stack and
    another op is ready too, that other op is emitted first. The intent
    is that every op at index >= ``split_pos`` is a descendant of the
    split op, so ``ops[:split_pos]`` are pre-split and ``ops[split_pos:]``
    are post-split.

    Returns:
        linearized_dag: Ops in the chosen order.
        split_pos: Index of the split op in that order, or None if absent.
        flagged_ops: ImplOps whose ``skrub_impl`` is an ``EvalMode``
            (flagged for recomputation).

    Raises:
        RuntimeError: If an op lists an output that is unknown to the
            indegree map computed from ``dag_sink``.
    """
    indegree, sources = compute_graph_node_indegree(dag_sink)

    ordered: list[Op] = []
    flagged: list[Op] = []
    split_index = None
    ready = list(sources)

    while ready:
        # Hold the split op back while anything else is ready: take the
        # entry right below it instead of the top of the stack.
        hold_split = ready[-1].is_split_op and len(ready) > 1
        op = ready.pop(-2 if hold_split else -1)

        if op.is_split_op:
            # Position the op is about to take in the output list.
            split_index = len(ordered)
        if isinstance(op, ImplOp) and isinstance(op.skrub_impl, EvalMode):
            flagged.append(op)
        ordered.append(op)

        for out_op in op.outputs:
            if out_op not in indegree:
                raise RuntimeError(
                    f"Encountered op {out_op} which should not exist in the DAG. "
                    f"Probably due to a buggy rewrite, which did not update its inputs / outputs correctly."
                )
            indegree[out_op] -= 1
            if indegree[out_op] == 0:
                ready.append(out_op)

    return ordered, split_index, flagged
Loading
Loading