Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions stratum/optimizer/ir/_dataframe_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from pandas import DataFrame
import pandas as pd
import polars as pl
import numpy as np
from stratum.optimizer._op_utils import topological_iterator
from stratum._config import FLAGS
from stratum.utils._utils import start_time, log_time
from skrub._data_ops._data_ops import DataOp
import logging
from time import perf_counter
from numpy import sin, cos
logger = logging.getLogger(__name__)

Expand All @@ -26,7 +26,7 @@ def __init__(self, data: DataFrame = None, file_path: str = None, _format: str =
self.file_path = file_path
self.read_args = read_args
self.read_kwargs = read_kwargs
self.is_dataframe_op = True
self.is_dataframe_op = format != "npy"

def process(self, mode: str, environment: dict, inputs: list):
if self.data is not None:
Expand All @@ -39,7 +39,12 @@ def process(self, mode: str, environment: dict, inputs: list):
if FLAGS.force_polars:
return pl.read_csv(file_path, *self.read_args, **self.read_kwargs)
else:
return pd.read_csv(file_path, *self.read_args, **self.read_kwargs)
if self.format == "csv":
return pd.read_csv(file_path, *self.read_args, **self.read_kwargs)
elif self.format == "npy":
return np.load(file_path, *self.read_args, **self.read_kwargs)
else:
raise ValueError(f"Unsupported format: {self.format}")

def clone(self):
raise ValueError(f"We should not clone DataSourceOp objects.")
Expand Down Expand Up @@ -292,6 +297,8 @@ def process(self, mode: str, environment: dict, inputs: list):
return (x.iloc[self.indices], y.iloc[self.indices])
elif isinstance(x, pl.DataFrame):
return (x[self.indices], y[self.indices])
elif isinstance(x, np.ndarray):
return (x[self.indices], y[self.indices])
else:
raise ValueError(f"Unsupported dataframe type: {type(x)}")

Expand Down Expand Up @@ -345,6 +352,9 @@ def extract_dataframe_op(op: Op, root: Op) -> tuple[Op, bool]:
if isinstance(op, CallOp):
if op.func is pd.read_csv:
new_op = make_read_op(op)

elif op.func is np.load:
new_op = make_read_op(op, "npy")

# input is a dataframe op
else:
Expand Down Expand Up @@ -401,7 +411,7 @@ def make_datetime_conversion_op(op: CallOp) -> DatetimeConversionOp:
return new_op


def make_read_op(op: CallOp) -> DataSourceOp:
def make_read_op(op: CallOp, format: str = "csv") -> DataSourceOp:
input_iter = iter(op.inputs)
# assume all inputs are ValueOps
assert all(isinstance(arg, ValueOp) or isinstance(arg, VariableOp) for arg in op.inputs), "All inputs must be ValueOps or VariableOps"
Expand All @@ -428,7 +438,7 @@ def make_read_op(op: CallOp) -> DataSourceOp:
kwargs[k] = actual_input_op.value
else:
kwargs[k] = v
new_op = DataSourceOp(file_path=args[0], _format="csv", read_args=args[1:], read_kwargs=kwargs, inputs=inputs, outputs=op.outputs)
new_op = DataSourceOp(file_path=args[0], _format=format, read_args=args[1:], read_kwargs=kwargs, inputs=inputs, outputs=op.outputs)
for in_ in inputs:
in_.replace_output(op, new_op)
return new_op
Expand Down
Loading
Loading