diff --git a/janitor/functions/filter.py b/janitor/functions/filter.py index e3c007216..eee007dea 100644 --- a/janitor/functions/filter.py +++ b/janitor/functions/filter.py @@ -357,3 +357,58 @@ def filter_column_isin( if complement: return df[~criteria] return df[criteria] + + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Dict, List, Sequence, Tuple, Union + +import pandas as pd +from pandas_flavor import register_dataframe_method + +ColumnsArg = Union[str, List[str], Tuple[str, ...], Dict[str, Iterable]] + + +@register_dataframe_method +def filter_column_isin( + df: pd.DataFrame, + columns: ColumnsArg, + values: Iterable | None = None, + *, + complement: bool = False, +) -> pd.DataFrame: + """ + Supports: + 1) Single column (str) + 2) Multiple columns (list/tuple) + 3) Dictionary mapping (dict) + """ + if isinstance(columns, dict): + if values is not None: + raise ValueError("When `columns` is a dict, do not pass `values`.") + if not columns: + return df + mask = pd.Series(True, index=df.index) + for col, vals in columns.items(): + mask &= df[col].isin(vals) + + elif isinstance(columns, (list, tuple)): + if values is None: + raise ValueError( + "`values` must be provided when `columns` is a list/tuple." + ) + cols_seq: Sequence[str] = list(columns) + combos = [tuple(v) for v in values] + mask = df.set_index(cols_seq).index.isin(combos) + + else: + if values is None: + raise ValueError( + "`values` must be provided when `columns` is a string." + ) + mask = df[columns].isin(values) + + if complement: + mask = ~mask + return df.loc[mask]