From 746c95a6cdac65ec7275ad2e6d86a1964bfbe97d Mon Sep 17 00:00:00 2001 From: e-strauss Date: Thu, 2 Apr 2026 16:35:56 +0200 Subject: [PATCH] [Refactor] reorganize logical optimizer into optimizer.ir package - move ops, dataframe ops, and numeric ops into the new stratum.optimizer.ir namespace - update runtime, benchmarks, and tests to import from the optimizer package - refresh README and benchmark paths to reflect the new optimizer layout --- README.md | 30 ++++++++++--------- .../end-to-end/20newsgroups.py | 2 +- .../skrubified_pipelines.py | 2 +- .../end-to-end/california-housing/bar_plot.py | 2 +- .../skrubified_merged_pipelines.py | 2 +- .../end-to-end/plot_20newsgroup_results.py | 2 +- stratum/_api.py | 2 +- .../__init__.py | 0 .../_algebraic_rewrites.py | 6 ++-- .../{logical_optimizer => optimizer}/_cse.py | 0 .../_op_comparison.py | 0 .../_op_utils.py | 2 +- .../_optimize.py | 8 ++--- stratum/optimizer/ir/__init__.py | 0 .../ir}/_dataframe_ops.py | 4 +-- .../ir}/_numeric_ops.py | 6 ++-- .../ir}/_ops.py | 2 +- .../rewrite_ideas.md | 0 stratum/runtime/_scheduler.py | 6 ++-- .../test_multi_level_choice_graph.py | 2 +- .../algebraic_rewrites/test_numeric.py | 8 ++--- .../test_check_for_equivalence.py | 2 +- .../test_check_for_inequality.py | 2 +- .../op_comparisons/test_hash_equivalence.py | 2 +- .../op_comparisons/test_update.py | 2 +- stratum/tests/logical_optimizer/test_cse.py | 6 ++-- .../logical_optimizer/test_dataframe_ops.py | 8 ++--- .../logical_optimizer/test_numeric_ops.py | 2 +- .../tests/logical_optimizer/test_op_utils.py | 4 +-- stratum/tests/logical_optimizer/test_ops.py | 6 ++-- .../tests/logical_optimizer/test_optimize.py | 4 +-- 31 files changed, 62 insertions(+), 62 deletions(-) rename stratum/{logical_optimizer => optimizer}/__init__.py (100%) rename stratum/{logical_optimizer => optimizer}/_algebraic_rewrites.py (89%) rename stratum/{logical_optimizer => optimizer}/_cse.py (100%) rename stratum/{logical_optimizer => optimizer}/_op_comparison.py (100%) rename stratum/{logical_optimizer => optimizer}/_op_utils.py (98%) rename stratum/{logical_optimizer => optimizer}/_optimize.py (96%) create mode 100644 stratum/optimizer/ir/__init__.py rename stratum/{logical_optimizer => optimizer/ir}/_dataframe_ops.py (98%) rename stratum/{logical_optimizer => optimizer/ir}/_numeric_ops.py (90%) rename stratum/{logical_optimizer => optimizer/ir}/_ops.py (99%) rename stratum/{logical_optimizer => optimizer}/rewrite_ideas.md (100%) diff --git a/README.md b/README.md index 040a907e..5921a5fb 100644 --- a/README.md +++ b/README.md @@ -91,24 +91,26 @@ if __name__ == "__main__": ```bash stratum/ -├─ pyproject.toml # Project metadata + Python/Rust build config (maturin) +├─ pyproject.toml # Project metadata + Python/Rust build config (maturin) ├─ README.md ├─ LICENSE -├─ _rust/ # Rust crate (PyO3 extension) +├─ _rust/ # Rust crate (PyO3 extension) │ ├─ Cargo.toml -│ └─ src/lib.rs # Defines #[pymodule] fn _rust_backend_native(...) -└─ stratum/ # Python package - ├─ __init__.py # Façade over skrub + automatic patching - ├─ _config.py # set_config/get_config + runtime/env sync - ├─ _api.py # High-level grid search / evaluate helpers - ├─ _rust_backend.py # Python <-> Rust shim (re-exports native fns) - ├─ adapters/ # Public API (dispatch to Rust or fall back to skrub) - │ ├─ string_encoder.py # RustyStringEncoder +│ └─ src/lib.rs # Defines #[pymodule] fn _rust_backend_native(...) +└─ stratum/ # Python package + ├─ __init__.py # Façade over skrub + automatic patching + ├─ _config.py # set_config/get_config + runtime/env sync + ├─ _api.py # High-level grid search / evaluate helpers + ├─ _rust_backend.py # Python <-> Rust shim (re-exports native fns) + ├─ adapters/ # Public API (dispatch to Rust or fall back to skrub) + │ ├─ string_encoder.py # RustyStringEncoder │ └─ one_hot_encoder.py # RustyOneHotEncoder - ├─ logical_optimizer/ # DAG representation + logical rewrites - ├─ runtime/ # Schedulers and runtime execution - ├─ patching/ # Hooks that patch upstream skrub - └─ tests/ # Test suite + ├─ optimizer/ + │ ├─ ir/ # DAG representation + │ └─ _optimize.py # logical rewrites + ├─ runtime/ # Schedulers and runtime execution + ├─ patching/ # Hooks that patch upstream skrub + └─ tests/ # Test suite ``` --- diff --git a/benchmarks/logical_optimizer/end-to-end/20newsgroups.py b/benchmarks/logical_optimizer/end-to-end/20newsgroups.py index 1b688cc3..1a512a61 100644 --- a/benchmarks/logical_optimizer/end-to-end/20newsgroups.py +++ b/benchmarks/logical_optimizer/end-to-end/20newsgroups.py @@ -5,7 +5,7 @@ from sklearn.linear_model import Ridge, LinearRegression, LogisticRegression from sklearn.svm import LinearSVC -from stratum.logical_optimizer import apply_cse_on_skrub_ir +from stratum.optimizer import apply_cse_on_skrub_ir from stratum.api.gridsearch import grid_search import stratum as skrub diff --git a/benchmarks/logical_optimizer/end-to-end/bike-sharing-demand/skrubified_pipelines.py b/benchmarks/logical_optimizer/end-to-end/bike-sharing-demand/skrubified_pipelines.py index 8e918eef..509ccc0a 100644 --- a/benchmarks/logical_optimizer/end-to-end/bike-sharing-demand/skrubified_pipelines.py +++ b/benchmarks/logical_optimizer/end-to-end/bike-sharing-demand/skrubified_pipelines.py @@ -9,7 +9,7 @@ from sklearn.metrics import mean_squared_log_error, make_scorer import time -from stratum.logical_optimizer import apply_cse_on_skrub_ir +from stratum.optimizer import apply_cse_on_skrub_ir t0 = time.time() diff --git a/benchmarks/logical_optimizer/end-to-end/california-housing/bar_plot.py b/benchmarks/logical_optimizer/end-to-end/california-housing/bar_plot.py index eeabd1b3..07b71aac 100644 --- a/benchmarks/logical_optimizer/end-to-end/california-housing/bar_plot.py +++ b/benchmarks/logical_optimizer/end-to-end/california-housing/bar_plot.py @@ -2,7 +2,7 @@ import pandas as pd import numpy as np -base_path = "benchmarks/logical_optimizer/end-to-end/california-housing/" +base_path = "benchmarks/optimizer/end-to-end/california-housing/" data = pd.read_csv(base_path + "california_housing_pipelines_benchmark.csv", sep=";") data["time"] = data["time"].apply(np.round, decimals=2) diff --git a/benchmarks/logical_optimizer/end-to-end/california-housing/skrubified_merged_pipelines.py b/benchmarks/logical_optimizer/end-to-end/california-housing/skrubified_merged_pipelines.py index 82b4e4b8..2f616407 100644 --- a/benchmarks/logical_optimizer/end-to-end/california-housing/skrubified_merged_pipelines.py +++ b/benchmarks/logical_optimizer/end-to-end/california-housing/skrubified_merged_pipelines.py @@ -5,7 +5,7 @@ from sklearn.preprocessing import StandardScaler from sklearn.linear_model import ElasticNet, Lasso, LinearRegression, Ridge -from stratum.logical_optimizer import apply_cse_on_skrub_ir +from stratum.optimizer import apply_cse_on_skrub_ir from stratum.api.gridsearch import grid_search from time import time diff --git a/benchmarks/logical_optimizer/end-to-end/plot_20newsgroup_results.py b/benchmarks/logical_optimizer/end-to-end/plot_20newsgroup_results.py index de96e41a..fb3907ee 100644 --- a/benchmarks/logical_optimizer/end-to-end/plot_20newsgroup_results.py +++ b/benchmarks/logical_optimizer/end-to-end/plot_20newsgroup_results.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt import numpy as np -base_path = "benchmarks/logical_optimizer/end-to-end/" +base_path = "benchmarks/optimizer/end-to-end/" data = pd.read_csv(base_path + 'bench_cse_tfidf_gridsearch.csv') data["total"] = data["total"].apply(np.round, decimals=2) diff --git a/stratum/_api.py b/stratum/_api.py index 1960a9c2..3c025afa 100644 --- a/stratum/_api.py +++ b/stratum/_api.py @@ -2,7 +2,7 @@ from skrub import DataOp from stratum._config import FLAGS -from stratum.logical_optimizer._optimize import optimize +from stratum.optimizer._optimize import optimize from stratum.runtime._scheduler import SequentialScheduler from time import perf_counter diff --git a/stratum/logical_optimizer/__init__.py b/stratum/optimizer/__init__.py similarity index 100% rename from stratum/logical_optimizer/__init__.py rename to stratum/optimizer/__init__.py diff --git a/stratum/logical_optimizer/_algebraic_rewrites.py b/stratum/optimizer/_algebraic_rewrites.py similarity index 89% rename from stratum/logical_optimizer/_algebraic_rewrites.py rename to stratum/optimizer/_algebraic_rewrites.py index 8340948c..0d57341a 100644 --- a/stratum/logical_optimizer/_algebraic_rewrites.py +++ b/stratum/optimizer/_algebraic_rewrites.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from typing import Any, Callable -from stratum.logical_optimizer._numeric_ops import NumericOp -from stratum.logical_optimizer._op_utils import topological_iterator -from stratum.logical_optimizer._numeric_ops import NumericOpType +from stratum.optimizer.ir._numeric_ops import NumericOp +from stratum.optimizer._op_utils import topological_iterator +from stratum.optimizer.ir._numeric_ops import NumericOpType RewriteFn = Callable[[NumericOp, Any], Any] diff --git a/stratum/logical_optimizer/_cse.py b/stratum/optimizer/_cse.py similarity index 100% rename from stratum/logical_optimizer/_cse.py rename to stratum/optimizer/_cse.py diff --git a/stratum/logical_optimizer/_op_comparison.py b/stratum/optimizer/_op_comparison.py similarity index 100% rename from stratum/logical_optimizer/_op_comparison.py rename to stratum/optimizer/_op_comparison.py diff --git a/stratum/logical_optimizer/_op_utils.py b/stratum/optimizer/_op_utils.py similarity index 98% rename from stratum/logical_optimizer/_op_utils.py rename to stratum/optimizer/_op_utils.py index b172f651..101725a8 100644 --- a/stratum/logical_optimizer/_op_utils.py +++ b/stratum/optimizer/_op_utils.py @@ -2,7 +2,7 @@ from collections import deque from typing import Iterator from graphviz import Digraph -from stratum.logical_optimizer._ops import DATA_OP_PLACEHOLDER, Op, ChoiceOp +from stratum.optimizer.ir._ops import DATA_OP_PLACEHOLDER, Op, ChoiceOp from stratum._config import get_config import os from dataclasses import dataclass diff --git a/stratum/logical_optimizer/_optimize.py b/stratum/optimizer/_optimize.py similarity index 96% rename from stratum/logical_optimizer/_optimize.py rename to stratum/optimizer/_optimize.py index 1570807d..35fd4bcd 100644 --- a/stratum/logical_optimizer/_optimize.py +++ b/stratum/optimizer/_optimize.py @@ -3,16 +3,16 @@ from skrub._data_ops._subsampling import SubsamplePreviews from collections import deque from ._cse import apply_cse -from ._dataframe_ops import rewrite_dataframe_ops, group_dataframe_ops,add_splitting_op -from ._numeric_ops import to_numeric_op -from ._ops import ChoiceOp, ImplOp, Op, SearchEvalOp, as_op +from stratum.optimizer.ir._dataframe_ops import rewrite_dataframe_ops, group_dataframe_ops,add_splitting_op +from stratum.optimizer.ir._numeric_ops import to_numeric_op +from stratum.optimizer.ir._ops import ChoiceOp, ImplOp, Op, SearchEvalOp, as_op from ._op_utils import clone_sub_dag, find_choice_naive, replace_op_in_outputs, show_graph, topological_iterator from ._algebraic_rewrites import algebraic_rewrites from stratum.utils._skrub_graph import build_graph from time import perf_counter import logging from stratum._config import FLAGS -from stratum.logical_optimizer._algebraic_rewrites import AlgebraicRewritesConfig +from stratum.optimizer._algebraic_rewrites import AlgebraicRewritesConfig logger = logging.getLogger(__name__) EVAL_OP_ENABLED = False diff --git a/stratum/optimizer/ir/__init__.py b/stratum/optimizer/ir/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stratum/logical_optimizer/_dataframe_ops.py b/stratum/optimizer/ir/_dataframe_ops.py similarity index 98% rename from stratum/logical_optimizer/_dataframe_ops.py rename to stratum/optimizer/ir/_dataframe_ops.py index db27ebf7..a6f4c063 100644 --- a/stratum/logical_optimizer/_dataframe_ops.py +++ b/stratum/optimizer/ir/_dataframe_ops.py @@ -1,8 +1,8 @@ -from stratum.logical_optimizer._ops import DATA_OP_PLACEHOLDER, BaseEstimatorOp, BinOp, CallOp, GetAttrOp, GetItemOp, MethodCallOp, Op, ValueOp, VariableOp +from stratum.optimizer.ir._ops import DATA_OP_PLACEHOLDER, BaseEstimatorOp, BinOp, CallOp, GetAttrOp, GetItemOp, MethodCallOp, Op, ValueOp, VariableOp from pandas import DataFrame import pandas as pd import polars as pl -from stratum.logical_optimizer._op_utils import topological_iterator +from stratum.optimizer._op_utils import topological_iterator from stratum._config import FLAGS from skrub._data_ops._data_ops import DataOp import logging diff --git a/stratum/logical_optimizer/_numeric_ops.py b/stratum/optimizer/ir/_numeric_ops.py similarity index 90% rename from stratum/logical_optimizer/_numeric_ops.py rename to stratum/optimizer/ir/_numeric_ops.py index 6db8aa72..578f7a45 100644 --- a/stratum/logical_optimizer/_numeric_ops.py +++ b/stratum/optimizer/ir/_numeric_ops.py @@ -1,7 +1,5 @@ -from stratum.logical_optimizer._ops import CallOp, Op, ValueOp -from pandas import DataFrame -from stratum.logical_optimizer._dataframe_ops import DataSourceOp -from stratum.logical_optimizer._op_utils import topological_iterator +from stratum.optimizer.ir._ops import CallOp, Op +from stratum.optimizer._op_utils import topological_iterator import numpy as np from enum import Enum diff --git a/stratum/logical_optimizer/_ops.py b/stratum/optimizer/ir/_ops.py similarity index 99% rename from stratum/logical_optimizer/_ops.py rename to stratum/optimizer/ir/_ops.py index dfe71245..1824fdb6 100644 --- a/stratum/logical_optimizer/_ops.py +++ b/stratum/optimizer/ir/_ops.py @@ -543,7 +543,7 @@ def as_op(data_op: DataOp): elif isinstance(impl, Var): return_op = VariableOp(name=impl.name, value=impl.value) elif isinstance(impl, Concat): - from stratum.logical_optimizer._dataframe_ops import ConcatOp + from stratum.optimizer.ir._dataframe_ops import ConcatOp return_op = ConcatOp(first=impl.first, others=impl.others, axis=impl.axis) else: return_op = ImplOp(skrub_impl=impl, name=data_op.__skrub_short_repr__()) diff --git a/stratum/logical_optimizer/rewrite_ideas.md b/stratum/optimizer/rewrite_ideas.md similarity index 100% rename from stratum/logical_optimizer/rewrite_ideas.md rename to stratum/optimizer/rewrite_ideas.md diff --git a/stratum/runtime/_scheduler.py b/stratum/runtime/_scheduler.py index ada628af..a82d614e 100644 --- a/stratum/runtime/_scheduler.py +++ b/stratum/runtime/_scheduler.py @@ -3,9 +3,9 @@ from sklearn.model_selection import train_test_split, check_cv from sklearn.metrics._scorer import _Scorer, get_scorer from skrub._data_ops._data_ops import EvalMode -from stratum.logical_optimizer._dataframe_ops import SplitOp -from stratum.logical_optimizer._op_utils import topological_iterator -from stratum.logical_optimizer._ops import ImplOp, Op +from stratum.optimizer.ir._dataframe_ops import SplitOp +from stratum.optimizer._op_utils import topological_iterator +from stratum.optimizer.ir._ops import ImplOp, Op import polars as pl import logging diff --git a/stratum/tests/application/test_multi_level_choice_graph.py b/stratum/tests/application/test_multi_level_choice_graph.py index bca02997..d4852a1b 100644 --- a/stratum/tests/application/test_multi_level_choice_graph.py +++ b/stratum/tests/application/test_multi_level_choice_graph.py @@ -13,7 +13,7 @@ from xgboost import XGBRegressor from sklearn.base import BaseEstimator, TransformerMixin from sklearn.metrics import make_scorer, mean_squared_error, r2_score -from stratum.logical_optimizer._optimize import optimize +from stratum.optimizer._optimize import optimize class TargetEncoder(BaseEstimator, TransformerMixin): diff --git a/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py b/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py index d77b958a..d52f8055 100644 --- a/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py +++ b/stratum/tests/logical_optimizer/algebraic_rewrites/test_numeric.py @@ -1,10 +1,10 @@ import unittest import stratum as skrub import numpy as np -from stratum.logical_optimizer._optimize import optimize -from stratum.logical_optimizer._optimize import OptConfig -from stratum.logical_optimizer._algebraic_rewrites import AlgebraicRewritesConfig -from stratum.logical_optimizer._op_utils import topological_iterator +from stratum.optimizer._optimize import optimize +from stratum.optimizer._optimize import OptConfig +from stratum.optimizer._algebraic_rewrites import AlgebraicRewritesConfig +from stratum.optimizer._op_utils import topological_iterator class TestCSE(unittest.TestCase): diff --git a/stratum/tests/logical_optimizer/op_comparisons/test_check_for_equivalence.py b/stratum/tests/logical_optimizer/op_comparisons/test_check_for_equivalence.py index 6dba894e..a42c7a8b 100644 --- a/stratum/tests/logical_optimizer/op_comparisons/test_check_for_equivalence.py +++ b/stratum/tests/logical_optimizer/op_comparisons/test_check_for_equivalence.py @@ -4,7 +4,7 @@ from skrub import TableVectorizer import stratum as skrub -from stratum.logical_optimizer._op_comparison import equals_data_op +from stratum.optimizer._op_comparison import equals_data_op import pandas as pd # dummy function diff --git a/stratum/tests/logical_optimizer/op_comparisons/test_check_for_inequality.py b/stratum/tests/logical_optimizer/op_comparisons/test_check_for_inequality.py index af3d0b0f..2cf28845 100644 --- a/stratum/tests/logical_optimizer/op_comparisons/test_check_for_inequality.py +++ b/stratum/tests/logical_optimizer/op_comparisons/test_check_for_inequality.py @@ -4,7 +4,7 @@ from skrub import TableVectorizer import stratum as skrub -from stratum.logical_optimizer._op_comparison import equals_data_op +from stratum.optimizer._op_comparison import equals_data_op import pandas as pd # dummy function diff --git a/stratum/tests/logical_optimizer/op_comparisons/test_hash_equivalence.py b/stratum/tests/logical_optimizer/op_comparisons/test_hash_equivalence.py index 6016d373..81ffbc4e 100644 --- a/stratum/tests/logical_optimizer/op_comparisons/test_hash_equivalence.py +++ b/stratum/tests/logical_optimizer/op_comparisons/test_hash_equivalence.py @@ -4,7 +4,7 @@ from skrub import TableVectorizer import stratum as skrub -from stratum.logical_optimizer._op_comparison import equals_data_op, hash_data_op +from stratum.optimizer._op_comparison import equals_data_op, hash_data_op import pandas as pd # dummy function diff --git a/stratum/tests/logical_optimizer/op_comparisons/test_update.py b/stratum/tests/logical_optimizer/op_comparisons/test_update.py index 1c4cafc3..7b2c96f1 100644 --- a/stratum/tests/logical_optimizer/op_comparisons/test_update.py +++ b/stratum/tests/logical_optimizer/op_comparisons/test_update.py @@ -4,7 +4,7 @@ from sklearn.preprocessing import StandardScaler import stratum as skrub -from stratum.logical_optimizer._op_comparison import update_data_op +from stratum.optimizer._op_comparison import update_data_op import pandas as pd # dummy function diff --git a/stratum/tests/logical_optimizer/test_cse.py b/stratum/tests/logical_optimizer/test_cse.py index ff859dd4..f9c8a26f 100644 --- a/stratum/tests/logical_optimizer/test_cse.py +++ b/stratum/tests/logical_optimizer/test_cse.py @@ -1,7 +1,7 @@ from skrub._data_ops._evaluation import _Graph -from stratum.logical_optimizer import apply_cse_on_skrub_ir -from stratum.logical_optimizer._cse import CSETable -from stratum.logical_optimizer._optimize import topological_traverse +from stratum.optimizer import apply_cse_on_skrub_ir +from stratum.optimizer._cse import CSETable +from stratum.optimizer._optimize import topological_traverse import unittest import stratum as skrub import pandas as pd diff --git a/stratum/tests/logical_optimizer/test_dataframe_ops.py b/stratum/tests/logical_optimizer/test_dataframe_ops.py index 8ea45838..a4d507b9 100644 --- a/stratum/tests/logical_optimizer/test_dataframe_ops.py +++ b/stratum/tests/logical_optimizer/test_dataframe_ops.py @@ -9,13 +9,13 @@ import stratum as skrub from skrub._data_ops._data_ops import DataOp from stratum._config import FLAGS -from stratum.logical_optimizer._dataframe_ops import ( +from stratum.optimizer.ir._dataframe_ops import ( ApplyUDFOp, AssignOp, ConcatOp, DataSourceOp, DatetimeConversionOp, DropOp, GetAttrProjectionOp, GroupedDataframeOp, MetadataOp, ProjectionOp, SplitOp, rewrite_fuse_get_item_ops,) -from stratum.logical_optimizer._op_utils import topological_iterator -from stratum.logical_optimizer._ops import DATA_OP_PLACEHOLDER, GetItemOp, MethodCallOp, Op -from stratum.logical_optimizer._optimize import OptConfig, optimize as optimize_ +from stratum.optimizer._op_utils import topological_iterator +from stratum.optimizer.ir._ops import DATA_OP_PLACEHOLDER, GetItemOp, MethodCallOp, Op +from stratum.optimizer._optimize import OptConfig, optimize as optimize_ def optimize(dag, conf=None): diff --git a/stratum/tests/logical_optimizer/test_numeric_ops.py b/stratum/tests/logical_optimizer/test_numeric_ops.py index 14c65f3b..d6a73c27 100644 --- a/stratum/tests/logical_optimizer/test_numeric_ops.py +++ b/stratum/tests/logical_optimizer/test_numeric_ops.py @@ -3,7 +3,7 @@ import stratum as skrub import numpy as np from sklearn.dummy import DummyRegressor -from stratum.logical_optimizer._numeric_ops import NumericOp +from stratum.optimizer.ir._numeric_ops import NumericOp class TestNumericOps(unittest.TestCase): def setUp(self): diff --git a/stratum/tests/logical_optimizer/test_op_utils.py b/stratum/tests/logical_optimizer/test_op_utils.py index cc7436af..e5437647 100644 --- a/stratum/tests/logical_optimizer/test_op_utils.py +++ b/stratum/tests/logical_optimizer/test_op_utils.py @@ -1,8 +1,8 @@ #from curses import flash import unittest import stratum as skrub -from stratum.logical_optimizer._optimize import optimize as optimize_, OptConfig, choice_unrolling -from stratum.logical_optimizer._op_utils import show_graph, clone_sub_dag, topological_iterator, FLAGS +from stratum.optimizer._optimize import optimize as optimize_, OptConfig, choice_unrolling +from stratum.optimizer._op_utils import show_graph, clone_sub_dag, topological_iterator, FLAGS from stratum._config import config graph = False diff --git a/stratum/tests/logical_optimizer/test_ops.py b/stratum/tests/logical_optimizer/test_ops.py index 8fbeec84..b8fc9305 100644 --- a/stratum/tests/logical_optimizer/test_ops.py +++ b/stratum/tests/logical_optimizer/test_ops.py @@ -10,15 +10,15 @@ from sklearn.preprocessing import StandardScaler from skrub._data_ops._data_ops import DataOp -from stratum.logical_optimizer._op_utils import topological_iterator -from stratum.logical_optimizer._ops import ( +from stratum.optimizer._op_utils import topological_iterator +from stratum.optimizer.ir._ops import ( DATA_OP_PLACEHOLDER, BinOp, CallOp, DummyConfigManager, GetAttrOp, GetItemOp, ImplOp, MethodCallOp, Op, PlaceHolder, SearchEvalOp, ValueOp, VariableOp, check_estm_inputs, estimator_parallel_config, estm_supports_polars, process_estimator_task, process_transformer_task, remove_datops_from_args, ) -from stratum.logical_optimizer._optimize import optimize as optimize_ +from stratum.optimizer._optimize import optimize as optimize_ def _inp(val): diff --git a/stratum/tests/logical_optimizer/test_optimize.py b/stratum/tests/logical_optimizer/test_optimize.py index 9454a352..d0b6d157 100644 --- a/stratum/tests/logical_optimizer/test_optimize.py +++ b/stratum/tests/logical_optimizer/test_optimize.py @@ -1,5 +1,5 @@ -from stratum.logical_optimizer._op_utils import topological_iterator -from stratum.logical_optimizer._optimize import OptConfig, optimize +from stratum.optimizer._op_utils import topological_iterator +from stratum.optimizer._optimize import OptConfig, optimize import stratum as skrub import pandas as pd import unittest