Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v0.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ Enhancements
parameters. A new fitted parameter `categorical_encoder_` is exposed to access the
fitted encoder.
:pr:`1001` by :user:`Guillaume Lemaitre <glemaitre>`.

- :class:`~imblearn.under_sampling.RandomUnderSampler` and
:class:`~imblearn.over_sampling.RandomOverSampler` (when `shrinkage is not
None`) now accept any data types and will not attempt any data conversion.
:pr:`1004` by :user:`Guillaume Lemaitre <glemaitre>`.
3 changes: 1 addition & 2 deletions examples/api/plot_sampling_strategy_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@
# resampling and the number of samples in the minority class, respectively.

# %%
import numpy as np

# select only 2 classes since the ratio makes sense in this case
binary_mask = np.bitwise_or(y == 0, y == 2)
binary_mask = y.isin([0, 1])
binary_y = y[binary_mask]
binary_X = X[binary_mask]

Expand Down
7 changes: 5 additions & 2 deletions imblearn/datasets/tests/test_imbalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ def test_make_imbalance_dict(iris, sampling_strategy, expected_counts):
],
)
def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts):
pytest.importorskip("pandas")
iris = load_iris(as_frame=True)
pd = pytest.importorskip("pandas")
iris = load_iris(as_frame=as_frame)
X, y = iris.data, iris.target
y = iris.target_names[iris.target]
if as_frame:
y = pd.Series(iris.target_names[iris.target], name="target")
X_res, y_res = make_imbalance(X, y, sampling_strategy=sampling_strategy)
if as_frame:
assert hasattr(X_res, "loc")
pd.testing.assert_index_equal(X_res.index, y_res.index)
assert Counter(y_res) == expected_counts
3 changes: 2 additions & 1 deletion imblearn/ensemble/tests/test_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,11 +572,12 @@ def roughly_balanced_bagging(X, y, replace=False):

# Roughly Balanced Bagging
rbb = BalancedBaggingClassifier(
estimator=CountDecisionTreeClassifier(),
estimator=CountDecisionTreeClassifier(random_state=0),
n_estimators=2,
sampler=FunctionSampler(
func=roughly_balanced_bagging, kw_args={"replace": replace}
),
random_state=0,
)
rbb.fit(X, y)

Expand Down
15 changes: 7 additions & 8 deletions imblearn/over_sampling/_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..utils import Substitution, check_target_type
from ..utils._docstring import _random_state_docstring
from ..utils._param_validation import Interval
from ..utils._validation import _check_X
from .base import BaseOverSampler


Expand Down Expand Up @@ -154,14 +155,9 @@ def __init__(

def _check_X_y(self, X, y):
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(
X,
y,
reset=True,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return X, y, binarize_y

def _fit_resample(self, X, y):
Expand Down Expand Up @@ -258,4 +254,7 @@ def _more_tags(self):
"X_types": ["2darray", "string", "sparse", "dataframe"],
"sample_indices": True,
"allow_nan": True,
"_xfail_checks": {
"check_complex_data": "Robust to this type of data.",
},
}
5 changes: 3 additions & 2 deletions imblearn/over_sampling/_smote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ...utils import Substitution, check_neighbors_object, check_target_type
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
from ...utils._param_validation import HasMethods, Interval
from ...utils._validation import _check_X
from ...utils.fixes import _mode
from ..base import BaseOverSampler

Expand Down Expand Up @@ -559,9 +560,9 @@ def _check_X_y(self, X, y):
features.
"""
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
if not (hasattr(X, "__array__") or sparse.issparse(X)):
X = check_array(X, dtype=object)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return X, y, binarize_y

def _validate_estimator(self):
Expand Down
14 changes: 14 additions & 0 deletions imblearn/over_sampling/tests/test_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: MIT

from collections import Counter
from datetime import datetime

import numpy as np
import pytest
Expand Down Expand Up @@ -273,3 +274,16 @@ def test_random_over_sampler_strings(sampling_strategy):
random_state=0,
)
RandomOverSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)


def test_random_over_sampling_datetime():
    """Check that we don't convert input data and only sample from it."""
    pd = pytest.importorskip("pandas")
    # A single timestamp repeated on every row gives the frame a datetime
    # column next to the integer label column.
    timestamp = datetime.now()
    X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [timestamp] * 4})
    y = X["label"]
    X_res, y_res = RandomOverSampler(random_state=0).fit_resample(X, y)

    # Column dtypes must be identical before and after resampling.
    pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
    # The resampled data and target must share the same index.
    pd.testing.assert_index_equal(X_res.index, y_res.index)
    expected_target = np.array([0, 0, 0, 1, 1, 1])
    assert_array_equal(y_res.to_numpy(), expected_target)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ...utils import Substitution, check_target_type
from ...utils._docstring import _random_state_docstring
from ...utils._validation import _check_X
from ..base import BaseUnderSampler


Expand Down Expand Up @@ -97,14 +98,9 @@ def __init__(

def _check_X_y(self, X, y):
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = self._validate_data(
X,
y,
reset=True,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
)
X = _check_X(X)
self._check_n_features(X, reset=True)
self._check_feature_names(X, reset=True)
return X, y, binarize_y

def _fit_resample(self, X, y):
Expand Down Expand Up @@ -140,4 +136,7 @@ def _more_tags(self):
"X_types": ["2darray", "string", "sparse", "dataframe"],
"sample_indices": True,
"allow_nan": True,
"_xfail_checks": {
"check_complex_data": "Robust to this type of data.",
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: MIT

from collections import Counter
from datetime import datetime

import numpy as np
import pytest
Expand Down Expand Up @@ -148,3 +149,16 @@ def test_random_under_sampler_strings(sampling_strategy):
random_state=0,
)
RandomUnderSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)


def test_random_under_sampling_datetime():
    """Check that we don't convert input data and only sample from it."""
    pd = pytest.importorskip("pandas")
    # A single timestamp repeated on every row gives the frame a datetime
    # column next to the integer label column.
    timestamp = datetime.now()
    X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [timestamp] * 4})
    y = X["label"]
    X_res, y_res = RandomUnderSampler(random_state=0).fit_resample(X, y)

    # Column dtypes must be identical before and after resampling.
    pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
    # The resampled data and target must share the same index.
    pd.testing.assert_index_equal(X_res.index, y_res.index)
    expected_target = np.array([0, 1])
    assert_array_equal(y_res.to_numpy(), expected_target)
26 changes: 25 additions & 1 deletion imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import numpy as np
from sklearn.base import clone
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import column_or_1d
from sklearn.utils import check_array, column_or_1d
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _num_samples

from .fixes import _is_pandas_df

SAMPLING_KIND = (
"over-sampling",
Expand All @@ -35,6 +38,12 @@ def __init__(self, X, y):
def transform(self, X, y):
    """Convert ``X`` and ``y`` using the stored properties and align indexes.

    Both arrays are passed through ``self._transfrom_one`` together with
    ``self.x_props`` and ``self.y_props`` respectively. When the recorded
    type of ``X`` is a dataframe and the recorded type of ``y`` is a series
    or a dataframe, the index of ``y`` is overwritten with the index of
    ``X``.

    Returns
    -------
    X, y : the converted inputs, in that order.
    """
    X = self._transfrom_one(X, self.x_props)
    y = self._transfrom_one(y, self.y_props)

    x_is_frame = self.x_props["type"].lower() == "dataframe"
    # The type of y is only inspected when X is a dataframe.
    if x_is_frame and self.y_props["type"].lower() in {"series", "dataframe"}:
        # We lost the y.index during resampling. We can safely use X.index to
        # align them.
        y.index = X.index
    return X, y

def _gets_props(self, array):
Expand Down Expand Up @@ -607,3 +616,18 @@ def inner_f(*args, **kwargs):
return f(**kwargs)

return inner_f


def _check_X(X):
    """Check that X is not empty and convert it unless it is a DataFrame.

    A pandas DataFrame is returned untouched. Any other input is passed to
    ``check_array`` with ``dtype=None``, CSR/CSC sparse formats accepted and
    ``force_all_finite=False``.

    Raises
    ------
    ValueError
        If X contains fewer than one sample.
    """
    n_samples = _num_samples(X)
    if n_samples < 1:
        raise ValueError(
            f"Found array with {n_samples} sample(s) while a minimum of 1 is "
            "required."
        )
    if not _is_pandas_df(X):
        X = check_array(
            X, dtype=None, accept_sparse=["csr", "csc"], force_all_finite=False
        )
    return X
16 changes: 16 additions & 0 deletions imblearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
which the fix is no longer needed.
"""
import functools
import sys

import numpy as np
import scipy
Expand Down Expand Up @@ -132,3 +133,18 @@ def _is_fitted(estimator, attributes=None, all_or_any=all):

else:
from sklearn.utils.validation import _is_fitted # type: ignore[no-redef]

try:
    from sklearn.utils.validation import _is_pandas_df
except ImportError:

    def _is_pandas_df(X):
        """Return True if X is a pandas DataFrame.

        Fallback defined when ``_is_pandas_df`` cannot be imported from
        ``sklearn.utils.validation``. pandas is never imported here: it is
        only looked up in ``sys.modules``, so the check returns False when
        pandas has not been loaded.
        """
        if not (hasattr(X, "columns") and hasattr(X, "iloc")):
            return False
        # Likely a pandas DataFrame, we explicitly check the type to confirm.
        # ``sys.modules`` can map a name to ``None`` when its import has been
        # blocked; treat that the same as pandas not being available instead
        # of failing on ``None.DataFrame``.
        pd = sys.modules.get("pandas")
        if pd is None:
            return False
        return isinstance(X, pd.DataFrame)