Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions stratum/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class _Flags:
scheduler: bool = False
stats: bool = False # TODO if we want to use that flag on other runtimes we need to set envirenment variable as well
stats_top_k: int = 20
debug_graph: bool = False,
open_graph: bool = False,
cse: bool = True,
DEBUG: bool = False
Expand All @@ -45,17 +46,18 @@ class _Flags:
FLAGS = _Flags()

def set_config(rust_backend: bool | None = None,
num_threads: int | None = None,
debug_timing: bool | None = None,
allow_patch: bool | None = None,
stats: bool | None = None,
stats_top_k: int | None = None,
scheduler: bool | None = None,
open_graph: bool | None = None,
DEBUG: bool | None = None,
force_polars: bool | None = None,
cse: bool = True,
fast_dataops_convert: bool = True) -> None:
num_threads: int | None = None,
debug_timing: bool | None = None,
allow_patch: bool | None = None,
stats: bool | None = None,
stats_top_k: int | None = None,
scheduler: bool | None = None,
debug_graph: bool = False,
open_graph: bool | None = None,
DEBUG: bool | None = None,
force_polars: bool | None = None,
cse: bool = True,
fast_dataops_convert: bool = True) -> None:
"""Runtime toggles (synced env for Rust to read).

Parameter:
Expand Down Expand Up @@ -115,6 +117,8 @@ def set_config(rust_backend: bool | None = None,
if not (isinstance(stats_top_k, int) and stats_top_k >= 0):
raise ValueError("stats_top_k must be an int >= 0")
FLAGS.stats_top_k = int(stats_top_k)
if debug_graph is not None:
FLAGS.debug_graph = bool(debug_graph)
if open_graph is not None:
FLAGS.open_graph = bool(open_graph)
if DEBUG is not None:
Expand All @@ -141,6 +145,7 @@ def get_config() -> dict:
"scheduler": FLAGS.scheduler,
"stats": FLAGS.stats,
"stats_top_k": FLAGS.stats_top_k,
"debug_graph": FLAGS.debug_graph,
"open_graph": FLAGS.open_graph,
"DEBUG" : FLAGS.DEBUG,
"force_polars": FLAGS.force_polars,
Expand Down
2 changes: 1 addition & 1 deletion stratum/optimizer/_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.algebraic_rewrite_config = algebraic_rewrite_config

def _debug_show_graph(root: Op, name: str):
if FLAGS.DEBUG:
if FLAGS.debug_graph:
show_graph(root, name)

def optimize(dag_root: DataOp, config: OptConfig = None):
Expand Down
4 changes: 2 additions & 2 deletions stratum/tests/application/test_multi_level_choice_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_application(self):
df.to_csv(os.path.join(tmp_path, "data.csv"), index=False)
preds = define_pipeline(os.path.join(tmp_path, "data.csv"))
scorer = make_scorer(r2_score)
with st.config(DEBUG=True, open_graph=False, scheduler=True, rust_backend=False):
with st.config(DEBUG=True, debug_graph=False, scheduler=True, rust_backend=False):
search = preds.skb.make_grid_search(fitted=True, cv = 2, scoring=scorer)
self.assertIsNotNone(search.results_)
self.assertGreater(len(search.results_), 0)
Expand All @@ -130,7 +130,7 @@ def test_application_polars(self):
df.to_csv(os.path.join(tmp_path, "data.csv"), index=False)
preds = define_pipeline(os.path.join(tmp_path, "data.csv"))
scorer = make_scorer(r2_score)
with st.config(DEBUG=True, open_graph=False, scheduler=True, rust_backend=False, force_polars=True):
with st.config(DEBUG=False, open_graph=False, scheduler=True, rust_backend=False, force_polars=True):
search = preds.skb.make_grid_search(fitted=True, cv = 2, scoring=scorer)
self.assertIsNotNone(search.results_)
self.assertGreater(len(search.results_), 0)
2 changes: 1 addition & 1 deletion stratum/tests/logical_optimizer/test_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_choice_unrolling(self):
out = optimize(t7, OptConfig(cse=True, unroll_choices=False))
root = out[-1]
out = choice_unrolling(root)
with config(open_graph=False):
if graph:
show_graph(out, filename='choice_unrolling')


Expand Down
Loading