Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature Proposal: Validation decorators for pipeline-style code #21512

Open
DeflateAwning opened this issue Feb 27, 2025 · 0 comments
Open

Feature Proposal: Validation decorators for pipeline-style code #21512

DeflateAwning opened this issue Feb 27, 2025 · 0 comments
Labels
enhancement New feature or an improvement of an existing feature

Comments

@DeflateAwning
Copy link
Contributor

Description

I use Polars in data pipelines, which are often composed of chains of functions which take in a df, modify it, and then return a slightly-modified version of that dataframe. My understanding is that this style is pretty common in complex data pipelines because it allows for easy testing and is easy to separate out code chunks.

I propose adding a submodule/namespace which contains decorator functions for performing validations on these types of functions. These validations would not be part of tests, but rather would be a "runtime validation"-type function.

Here is an example of the type of function I'm thinking of, with tests:

"""Decorators for validating pipeline-style transformations on Polars DataFrames."""

# pyright: strict

from collections.abc import Iterable, Callable
from typing import TypeVar
import functools

import polars as pl

from pn_data.helpers.polars.polars_validation_errors import PolarsColumnChangeCheckFailedError


# Type variable for the decorated callable: any callable returning a polars DataFrame.
# Using it as both the input and output type of a decorator tells type checkers that
# the decorated function keeps its original signature.
F = TypeVar("F", bound=Callable[..., pl.DataFrame])


def assert_column_change(
    add: Iterable[str] = (),
    drop: Iterable[str] = (),
) -> Callable[[F], F]:
    """Build a decorator that validates the column changes made by a DataFrame transform.

    The decorated function must accept the input DataFrame as its first argument
    (named ``df``) and return a DataFrame. Each call is checked twice:

    * On entry: none of the ``add`` columns may already exist, and every ``drop``
      column must be present.
    * On exit: the returned DataFrame's column set must equal the input's column
      set plus ``add`` minus ``drop``. Column order is not checked.

    Args:
        add: Names of columns the function is expected to add. A bare string is
            treated as a single column name.
        drop: Names of columns the function is expected to remove. A bare string
            is treated as a single column name.

    Returns:
        A decorator that wraps a DataFrame-returning function with the checks above.

    Raises:
        ValueError: When the decorator is built, if ``add`` and ``drop`` share a
            column name.
        PolarsColumnChangeCheckFailedError: When the decorated function is called
            and an entry or exit check fails.
    """
    # A bare string is itself an Iterable[str]; treat it as one column name rather
    # than letting set() split it into individual characters.
    expect_added = {add} if isinstance(add, str) else set(add)
    expect_removed = {drop} if isinstance(drop, str) else set(drop)

    overlap = expect_added & expect_removed
    if overlap:
        raise ValueError(
            "'add' and 'drop' must not have overlapping column names. "
            f"Overlapping: {sorted(overlap)}"
        )

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(df: pl.DataFrame, *args: object, **kwargs: object) -> pl.DataFrame:
            orig_columns = set(df.columns)

            # Entry checks. Names are sorted so the messages are deterministic.
            already_present = expect_added & orig_columns
            if already_present:
                raise PolarsColumnChangeCheckFailedError(
                    "Unexpected pre-existing columns. Columns in the 'add' argument should "
                    "not already exist in the entry dataframe. Unexpected columns on entry: "
                    f"{sorted(already_present)}"
                )

            missing_on_entry = expect_removed - orig_columns
            if missing_on_entry:
                raise PolarsColumnChangeCheckFailedError(
                    "Missing input column. All columns in the 'drop' argument must be "
                    "present in the entry dataframe. "
                    f"Column(s) missing: {sorted(missing_on_entry)}"
                )

            # Execute function
            result_df = func(df, *args, **kwargs)

            # Exit checks: compare as sets, so column order is deliberately ignored.
            expected_columns = (orig_columns | expect_added) - expect_removed
            actual_columns = set(result_df.columns)

            if actual_columns != expected_columns:
                raise PolarsColumnChangeCheckFailedError(
                    "Unexpected final columns. "
                    f"Extra columns: {sorted(actual_columns - expected_columns)}, "
                    f"Missing columns: {sorted(expected_columns - actual_columns)}"
                )

            return result_df

        # The wrapper's static signature differs from F, so the type checker cannot
        # prove it is an F even though functools.wraps preserves the runtime metadata.
        # FIXME: replace the ignore with a typed approach (e.g. ParamSpec or cast).
        return wrapper  # type: ignore

    return decorator




class PolarsColumnChangeCheckFailedError(GenericPolarsValidationError):
    """Raised when a column-change validation on a DataFrame fails.

    This covers both failure points of such a validation: the input DataFrame's
    columns violate the declared additions/removals, or the output DataFrame's
    columns differ from the expected set.
    """

Example usage/test cases:

# pyright: strict

import polars as pl
import pytest

from polars.polars_validation_errors import PolarsColumnChangeCheckFailedError
from polars import validation_decorators as pl_validation_decorators


@pl_validation_decorators.assert_column_change(add=["a"], drop=["b"])
def _transform_add_a_drop_b(df: pl.DataFrame) -> pl.DataFrame:
    """Well-behaved fixture transform: add column 'a' and drop column 'b'."""
    # 'a' is a copy of whichever column comes first in the input.
    first_column_name = df.columns[0]
    with_a = df.with_columns(a=pl.col(first_column_name))
    return with_a.drop("b")


def test_assert_column_change_WITH_normal_pass() -> None:
    """A transform that adds 'a' and drops 'b' as declared passes validation."""
    source = pl.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6], "b": [7, 8, 9]})

    result: pl.DataFrame = _transform_add_a_drop_b(source)

    # 'b' is gone and 'a' is appended after the untouched columns.
    assert result.columns == ["x", "y", "a"]


def test_assert_column_change_WITH_invalid_construction() -> None:
    """Listing the same column under both ``add`` and ``drop`` is rejected."""

    def passthrough(df: pl.DataFrame) -> pl.DataFrame:
        return df

    with pytest.raises(ValueError, match="overlapping column names"):
        # Apply the decorator by hand instead of with @-syntax; the outcome is the same.
        pl_validation_decorators.assert_column_change(add=["a"], drop=["a"])(passthrough)


def test_assert_column_change_WITH_bad_input_columns() -> None:
    """Inputs that violate the declared entry conditions are rejected."""
    error_type = PolarsColumnChangeCheckFailedError

    # 'a' is declared as an added column, so it must not be present on entry.
    has_a_already = pl.DataFrame({"x": [100, 200, 300], "a": [1, 2, 3]})
    with pytest.raises(error_type, match="Unexpected pre-existing columns"):
        _transform_add_a_drop_b(has_a_already)

    # 'b' is declared as a dropped column, so it must be present on entry.
    lacks_b = pl.DataFrame({"x": [100, 200, 300]})
    with pytest.raises(error_type, match="Missing input column"):
        _transform_add_a_drop_b(lacks_b)


def test_assert_column_change_WITH_misbehaving_function() -> None:
    """A transform whose output columns differ from the declaration is rejected."""

    @pl_validation_decorators.assert_column_change(add=["a"], drop=["b"])
    def _transform_misbehave(df: pl.DataFrame) -> pl.DataFrame:
        # Drops 'b' as declared, but never adds the promised 'a' column.
        without_b = df.drop("b")
        return without_b

    valid_input = pl.DataFrame({"x": [100, 200, 300], "b": [1, 2, 3]})

    with pytest.raises(PolarsColumnChangeCheckFailedError, match="Unexpected final columns"):
        _transform_misbehave(valid_input)
@DeflateAwning DeflateAwning added the enhancement New feature or an improvement of an existing feature label Feb 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or an improvement of an existing feature
Projects
None yet
Development

No branches or pull requests

1 participant