ENH make Random*Sampler accept dask array and dataframe #777
base: master
The PR adds a small support module (new file, 13 lines) that registers the dask container types and exposes an `is_dask_container` helper; the under-sampler diff below imports it as `...dask._support`:

```python
_REGISTERED_DASK_CONTAINER = []

try:
    from dask import array, dataframe
    _REGISTERED_DASK_CONTAINER += [
        array.Array, dataframe.Series, dataframe.DataFrame,
    ]
except ImportError:
    pass


def is_dask_container(container):
    return isinstance(container, tuple(_REGISTERED_DASK_CONTAINER))
```
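For illustration (not part of the diff), this is how the helper behaves once the branch is installed, assuming both `dask[array]` and `dask[dataframe]` are available; the `imblearn.dask._support` path is inferred from the relative import in the under-sampler diff below:

```python
import numpy as np
from dask import array

from imblearn.dask._support import is_dask_container

y_np = np.array([0, 1, 0, 1])
y_da = array.from_array(y_np)

print(is_dask_container(y_np))  # False: plain numpy arrays are not registered
print(is_dask_container(y_da))  # True: dask.array.Array is in the registry
```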
A new test module (40 lines) exercises the dask-aware target helpers:

```python
import numpy as np
import pytest

dask = pytest.importorskip("dask")
from dask import array

from imblearn.dask.utils import is_multilabel
from imblearn.dask.utils import type_of_target


@pytest.mark.parametrize(
    "y, expected_result",
    [
        (array.from_array(np.array([0, 1, 0, 1])), False),
        (array.from_array(np.array([[1, 0], [0, 0]])), True),
        (array.from_array(np.array([[1], [0], [0]])), False),
        (array.from_array(np.array([[1, 0, 0]])), True),
    ]
)
def test_is_multilabel(y, expected_result):
    assert is_multilabel(y) is expected_result


@pytest.mark.parametrize(
    "y, expected_type_of_target",
    [
        (array.from_array(np.array([[1, 0], [0, 0]])), "multilabel-indicator"),
        (array.from_array(np.array([[1, 0, 0]])), "multilabel-indicator"),
        (array.from_array(np.array([[[1, 2]]])), "unknown"),
        (array.from_array(np.array([[]])), "unknown"),
        (array.from_array(np.array([.1, .2, 3])), "continuous"),
        (array.from_array(np.array([[.1, .2, 3]])), "continuous-multioutput"),
        (array.from_array(np.array([[1., .2]])), "continuous-multioutput"),
        (array.from_array(np.array([1, 2])), "binary"),
        (array.from_array(np.array(["a", "b"])), "binary"),
    ]
)
def test_type_of_target(y, expected_type_of_target):
    target_type = type_of_target(y)
    assert target_type == expected_type_of_target
```
The helpers themselves live in a new module (78 lines), imported by the tests as `imblearn.dask.utils`, containing dask-aware counterparts of the scikit-learn target utilities:

```python
import warnings

import numpy as np
from sklearn.exceptions import DataConversionWarning
from sklearn.utils.multiclass import _is_integral_float


def is_multilabel(y):
    if not (y.ndim == 2 and y.shape[1] > 1):
        return False

    if hasattr(y, "unique"):
        labels = np.asarray(y.unique())
    else:
        labels = np.unique(y).compute()

    return len(labels) < 3 and (
        y.dtype.kind in 'biu' or _is_integral_float(labels)
    )
```

Review comment on lines +12 to +15 (the `unique` computation):

> I've struggled with this check in dask-ml. Depending on where it's called, it's potentially very expensive (you might be loading a ton of data just to check if it's multi-label, and then loading it again to do the training). Whenever possible, it's helpful to provide an option to skip this check by having the user specify it when creating the estimator, or in a keyword to […]

Reply:

> I thought about it. Do you think that having a context manager outside would make sense:
>
> ```python
> with set_config(avoid_check=True):
>     # some imblearn/scikit-learn/dask code
> ```
>
> Though, we might get into trouble with issues related to scikit-learn/scikit-learn#18736. It might just be easier to have an optional class parameter that applies only for dask arrays.
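Neither option exists on this branch. As a purely hypothetical sketch of the context-manager idea (`set_config`, `avoid_check`, and `_CONFIG` are invented names here, not imblearn API), the opt-out could look like this:

```python
from contextlib import contextmanager

from imblearn.dask.utils import is_multilabel

_CONFIG = {"avoid_check": False}  # hypothetical module-level switch


@contextmanager
def set_config(avoid_check=False):
    """Temporarily disable the expensive target checks."""
    previous = _CONFIG["avoid_check"]
    _CONFIG["avoid_check"] = avoid_check
    try:
        yield
    finally:
        _CONFIG["avoid_check"] = previous


def checked_is_multilabel(y):
    # With the switch on, trust the caller instead of triggering a full
    # pass over the dask collection via unique().
    if _CONFIG["avoid_check"]:
        return False
    return is_multilabel(y)
```

As the reply notes, a global switch like this runs into the same pitfalls as scikit-learn/scikit-learn#18736, which is why a per-estimator parameter that only applies to dask inputs may be the safer route.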
The module continues with a dask-aware `type_of_target` and the remaining helpers:

```python
def type_of_target(y):
    if is_multilabel(y):
        return 'multilabel-indicator'

    if y.ndim > 2:
        return 'unknown'

    if y.ndim == 2 and y.shape[1] == 0:
        return 'unknown'  # [[]]

    if y.ndim == 2 and y.shape[1] > 1:
        # [[1, 2], [1, 2]]
        suffix = "-multioutput"
    else:
        # [1, 2, 3] or [[1], [2], [3]]
        suffix = ""

    # check float and contains non-integer float values
    if y.dtype.kind == 'f' and np.any(y != y.astype(int)):
        # [.1, .2, 3] or [[.1, .2, 3]] or [[1., .2]] and not [1., 2., 3.]
        # NOTE: we don't check for infinite values
        return 'continuous' + suffix

    if hasattr(y, "unique"):
        labels = np.asarray(y.unique())
    else:
        labels = np.unique(y).compute()
    if len(labels) > 2 or (y.ndim >= 2 and len(y[0]) > 1):
        # [1, 2, 3] or [[1., 2., 3]] or [[1, 2]]
        return 'multiclass' + suffix
    # [1, 2] or [["a"], ["b"]]
    return 'binary'


def column_or_1d(y, *, warn=False):
    shape = y.shape
    if len(shape) == 1:
        return y.ravel()
    if len(shape) == 2 and shape[1] == 1:
        if warn:
            warnings.warn(
                "A column-vector y was passed when a 1d array was expected. "
                "Please change the shape of y to (n_samples, ), for example "
                "using ravel().", DataConversionWarning, stacklevel=2
            )
        return y.ravel()

    raise ValueError(
        f"y should be a 1d array. Got an array of shape {shape} instead."
    )


def check_classification_targets(y):
    y_type = type_of_target(y)
    if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
                      'multilabel-indicator', 'multilabel-sequences']:
        raise ValueError("Unknown label type: %r" % y_type)
```
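For reference, a small usage sketch of these helpers on dask arrays, mirroring the parametrized test cases above:

```python
import numpy as np
from dask import array

from imblearn.dask.utils import is_multilabel, type_of_target

y_binary = array.from_array(np.array([0, 1, 0, 1]))
y_indicator = array.from_array(np.array([[1, 0], [0, 0]]))

print(type_of_target(y_binary))     # "binary"
print(type_of_target(y_indicator))  # "multilabel-indicator"
print(is_multilabel(y_binary))      # False: a 1-d target is never multilabel

# Note: the unique()/np.unique(...).compute() calls in the helpers
# materialise the labels, a full pass over the collection, which is
# exactly the cost flagged in the review thread.
```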
Finally, `RandomUnderSampler` gains the dask code path. The first hunk adds the import:

```diff
@@ -10,6 +10,7 @@
 from sklearn.utils import _safe_indexing
 
 from ..base import BaseUnderSampler
+from ...dask._support import is_dask_container
 from ...utils import check_target_type
 from ...utils import Substitution
 from ...utils._docstring import _random_state_docstring
```

The second hunk rewrites `_check_X_y` so that dask containers bypass `_validate_data`:

```diff
@@ -80,44 +81,73 @@ def __init__(
         self.replacement = replacement
 
     def _check_X_y(self, X, y):
-        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        X, y = self._validate_data(
-            X, y, reset=True, accept_sparse=["csr", "csc"], dtype=None,
-            force_all_finite=False,
-        )
+        if is_dask_container(y) and hasattr(y, "to_dask_array"):
+            y = y.to_dask_array()
+            y.compute_chunk_sizes()
+        y, binarize_y, self._uniques = check_target_type(
+            y,
+            indicate_one_vs_all=True,
+            return_unique=True,
+        )
+        if not any([is_dask_container(arr) for arr in (X, y)]):
+            X, y = self._validate_data(
+                X,
+                y,
+                reset=True,
+                accept_sparse=["csr", "csc"],
+                dtype=None,
+                force_all_finite=False,
+            )
+        elif is_dask_container(X) and hasattr(X, "to_dask_array"):
+            X = X.to_dask_array()
+            X.compute_chunk_sizes()
         return X, y, binarize_y
```

Review comment on the `to_dask_array()` / `compute_chunk_sizes()` lines:

> In Dask-ML we (@stsievert I think? maybe me?) prefer to have the user do this: https://github.com/dask/dask-ml/blob/7e11ce1505a485104e02d49a3620c8264e63e12e/dask_ml/utils.py#L166-L173. If you're just fitting the one estimator then this is probably equivalent. If you're doing something like a […]

Reply:

> This is something that I was unsure of, here. If I recall, the issue was that I could not have called […] However, if we assume that the checks are too expensive to be done in a distributed setting, we don't need to call the check below and we can directly pass the Series and handle it during the resampling. So we have fewer safeguards, but at least it is more performant, which is something you probably want in a distributed setting.
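The dask-ml pattern referenced above, sketched under the assumption of a `dask.dataframe` input: the caller converts once, materialising chunk sizes in the same step, so the estimator never calls `compute_chunk_sizes()` itself and repeated fits (e.g. in a search) do not recompute them:

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"feature": range(100), "label": [0, 1] * 50})
ddf = dd.from_pandas(pdf, npartitions=4)

# lengths=True computes the partition lengths during the conversion,
# so the resulting dask arrays have known chunk sizes up front.
X = ddf[["feature"]].to_dask_array(lengths=True)
y = ddf["label"].to_dask_array(lengths=True)
```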
The rest of the hunk reworks `_fit_resample` around a helper that materialises the per-class indices once:

```diff
+    @staticmethod
+    def _find_target_class_indices(y, target_class):
+        target_class_indices = np.flatnonzero(y == target_class)
+        if is_dask_container(y):
+            return target_class_indices.compute()
+        return target_class_indices
+
     def _fit_resample(self, X, y):
         random_state = check_random_state(self.random_state)
 
-        idx_under = np.empty((0,), dtype=int)
+        idx_under = []
 
-        for target_class in np.unique(y):
+        for target_class in self._uniques:
+            target_class_indices = self._find_target_class_indices(
+                y, target_class
+            )
             if target_class in self.sampling_strategy_.keys():
                 n_samples = self.sampling_strategy_[target_class]
                 index_target_class = random_state.choice(
-                    range(np.count_nonzero(y == target_class)),
+                    target_class_indices.size,
                     size=n_samples,
                     replace=self.replacement,
                 )
             else:
                 index_target_class = slice(None)
 
-            idx_under = np.concatenate(
-                (
-                    idx_under,
-                    np.flatnonzero(y == target_class)[index_target_class],
-                ),
-                axis=0,
-            )
+            selected_indices = target_class_indices[index_target_class]
+            idx_under.append(selected_indices)
 
-        self.sample_indices_ = idx_under
+        self.sample_indices_ = np.hstack(idx_under)
+        self.sample_indices_.sort()
 
-        return _safe_indexing(X, idx_under), _safe_indexing(y, idx_under)
+        return (
+            _safe_indexing(X, self.sample_indices_),
+            _safe_indexing(y, self.sample_indices_)
+        )
 
     def _more_tags(self):
         return {
-            "X_types": ["2darray", "string"],
+            "X_types": [
+                "2darray",
+                "string",
+                "dask-array",
+                "dask-dataframe"
+            ],
             "sample_indices": True,
             "allow_nan": True,
         }
```
Review comment on the combined `try`/`except` import in the dask support module:

> People can have just `dask[array]` installed (not dataframe), so it's possible to have the `array` import succeed but the `dataframe` import fail. So if you want to support that case, those would need to be in separate try / except blocks.
>
> Maybe you instead want `from dask import is_dask_collection`? That's a bit broader though (it also covers anything implementing dask's collection interface, like `dask.Bag`, `xarray.DataArray` and `xarray.Dataset`).

Reply:

> That seems to be what I wanted :)
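A sketch of the first suggestion, registering each container family in its own guard so that a `dask[array]`-only environment still gets array support (same names as the support module above):

```python
_REGISTERED_DASK_CONTAINER = []

try:
    from dask import array
    _REGISTERED_DASK_CONTAINER.append(array.Array)
except ImportError:
    pass

try:
    from dask import dataframe
    _REGISTERED_DASK_CONTAINER.extend(
        [dataframe.Series, dataframe.DataFrame]
    )
except ImportError:
    # dask[dataframe] not installed; array support still works
    pass


def is_dask_container(container):
    return isinstance(container, tuple(_REGISTERED_DASK_CONTAINER))
```

The broader alternative, `dask.is_dask_collection`, avoids maintaining a registry but would also return True for `dask.Bag` and xarray objects, which the resampling code here does not handle.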