Merged
Changes from all commits
68 commits
53c7a8d
Created empty test units
Aug 21, 2020
183c03f
added ROSE empty class, modified __init__.py
Aug 21, 2020
0a3307b
implemented ROSE, still some failed test
Aug 21, 2020
886694f
PEP8 cleaning
Aug 21, 2020
07731c4
PEP8 linting
Aug 22, 2020
c0d7473
fixed linting errors.
Aug 26, 2020
013f7cc
updated documentation and bibliography
Aug 27, 2020
2b34f47
cleaned ROSE test
Aug 31, 2020
8d0e99e
added an exception for non binary datasets
Sep 10, 2020
8bbbd2e
multiclass oversampling
Sep 15, 2020
9ac2797
removed non-binary exception
Sep 15, 2020
b41b06a
removed unused import
Sep 15, 2020
d2dd6f4
minor fixes
Sep 15, 2020
b31a5c3
linting
Sep 15, 2020
d5ca24c
linting
Sep 15, 2020
b6e95aa
linting
Sep 15, 2020
6f7f8e1
linting
Sep 15, 2020
c391ec3
removed explicit pandas dataframe management
Sep 15, 2020
93ac868
added check_X_y() parsing
Sep 16, 2020
bdffda3
removed check_X_y test
Sep 16, 2020
cac7f0f
local test 1: shrink factors
Sep 16, 2020
e59c4d3
test
Sep 16, 2020
c24f29e
1
Sep 16, 2020
4c83cfe
1
Sep 16, 2020
f3fb23b
1
Sep 16, 2020
9653709
1
Sep 16, 2020
93f7f8d
1
Sep 16, 2020
97269f5
1
Sep 16, 2020
36658bd
1
Sep 16, 2020
f2fd72b
1
Sep 16, 2020
4fc8476
1
Sep 16, 2020
7399233
1
Sep 16, 2020
5962627
1
Sep 16, 2020
4a202f7
1
Sep 16, 2020
ace2785
1
Sep 16, 2020
ecae868
1
Sep 16, 2020
335cd04
1
Sep 16, 2020
90c7082
1
Sep 16, 2020
f2c7dc0
1
Sep 16, 2020
216672c
1
Sep 16, 2020
c466b3d
1
Sep 16, 2020
2b154bc
fixed sparse
Sep 17, 2020
a020e10
fixed all tests
Sep 17, 2020
d710f9e
test added
Sep 17, 2020
354cb47
linting, submitted version
Sep 17, 2020
2d6b12a
fixed test tolerance
Sep 17, 2020
94ecd4d
pep8 fix
Sep 17, 2020
048090c
tolerance adjustment
Sep 17, 2020
60777ad
documentation
Sep 17, 2020
56949ea
documentation
Sep 17, 2020
172496f
documentation
Sep 17, 2020
a049e38
fixed bug in ROSE sampling strategy parsing
Sep 17, 2020
517c79b
linting
Sep 17, 2020
17c9678
restored original docs
Sep 18, 2020
949a3be
updated docstrings
Sep 18, 2020
72cfe44
removed whitespaces
Sep 18, 2020
8a30871
added more documentation
Sep 18, 2020
c8434c0
added math formulations to docs
Sep 18, 2020
5a9768c
fixed revision issues
Sep 20, 2020
7e8031e
added missing math directives
Sep 20, 2020
9f7823e
dropped "see also" sections
Sep 20, 2020
9334fb2
fixed ROSE short description
Sep 20, 2020
b1e3b26
linting
Sep 20, 2020
c68d85c
removed double newline in rose.py docstring
Sep 20, 2020
b5d05c8
trying to fix See Also missing
Sep 20, 2020
f45710c
testing a fix for missing See Also section
Sep 20, 2020
9d33aa9
added just SMOTE to See Also section
Sep 20, 2020
73d0266
last minor typo fixes
Sep 20, 2020
5 changes: 4 additions & 1 deletion README.rst
@@ -155,6 +155,7 @@ Below is a list of the methods currently implemented in this module.
5. SVM SMOTE - Support Vectors SMOTE [10]_
6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_
7. KMeans-SMOTE [17]_
8. ROSE - Random OverSampling Examples [19]_

* Over-sampling followed by under-sampling
1. SMOTE + Tomek links [12]_
@@ -210,4 +211,6 @@ References:

.. [17] : Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for Imbalanced Learning Based on K-Means and SMOTE"

.. [18] : Seiffert, C., Khoshgoftaar, T. M., Van Hulse, J., & Napolitano, A. "RUSBoost: A hybrid approach to alleviating class imbalance." IEEE Transactions on Systems, Man, and Cybernetics-Part A: Systems and Humans 40.1 (2010): 185-197.

.. [19] : Menardi, G., & Torelli, N. "Training and assessing classification rules with unbalanced data." Data Mining and Knowledge Discovery 28.1 (2014): 92-122.
1 change: 1 addition & 0 deletions doc/api.rst
@@ -76,6 +76,7 @@ Prototype selection
over_sampling.SMOTE
over_sampling.SMOTENC
over_sampling.SVMSMOTE
over_sampling.ROSE


.. _combine_ref:
15 changes: 15 additions & 0 deletions doc/bibtex/refs.bib
@@ -193,3 +193,18 @@ @article{smith2014instance
year={2014},
publisher={Springer}
}

@article{torelli2014rose,
  author={Menardi, Giovanna and Torelli, Nicola},
  title={Training and assessing classification rules with unbalanced data},
  journal={Data Mining and Knowledge Discovery},
  volume={28},
  number={1},
  pages={92-122},
  year={2014},
  publisher={Springer},
  issn={1573-756X},
  url={https://doi.org/10.1007/s10618-012-0295-5},
  doi={10.1007/s10618-012-0295-5}
}
17 changes: 17 additions & 0 deletions doc/over_sampling.rst
@@ -198,6 +198,23 @@ Therefore, it can be seen that the samples generated in the first and last
columns belong to the same categories originally presented without any
other extra interpolation.

.. _rose:

ROSE (Random Over-Sampling Examples)
------------------------------------

ROSE uses a smoothed bootstrap to generate artificial samples in the
feature-space neighborhood of the classes to be over-sampled. First,
samples are drawn at random, with replacement, from the target class.
A multivariate Gaussian kernel with smoothing matrix :math:`H_j` is then
centered on each drawn sample, which yields the kernel estimate of the
class-conditional density :math:`\hat{f}(x|y=Y_j) = \sum_{i=1}^{n_j}
p_i Pr(x|x_i) = \sum_{i=1}^{n_j} \frac{1}{n_j} Pr(x|x_i) =
\sum_{i=1}^{n_j} \frac{1}{n_j} K_{H_j}(x - x_i)`.

The new samples are then drawn from this estimated distribution.
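
A minimal usage sketch, mirroring the example from the ``ROSE`` docstring
added in this pull request; the shrink factors of 0.5 and 0.7 compress the
kernels of classes 1 and 2::

    >>> from collections import Counter
    >>> from sklearn.datasets import make_classification
    >>> from imblearn.over_sampling import ROSE
    >>> X, y = make_classification(n_classes=3, class_sep=2,
    ...                            weights=[0.1, 0.7, 0.2], n_informative=3,
    ...                            n_redundant=1, flip_y=0, n_features=20,
    ...                            n_clusters_per_class=1, n_samples=2000,
    ...                            random_state=10)
    >>> r = ROSE(shrink_factors={0: 1, 1: 0.5, 2: 0.7})
    >>> X_res, y_res = r.fit_resample(X, y)
    >>> print('Resampled dataset shape %s' % Counter(y_res))
    Resampled dataset shape Counter({2: 1400, 1: 1400, 0: 1400})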

Mathematical formulation
========================

3 changes: 3 additions & 0 deletions doc/whats_new/v0.7.rst
@@ -63,6 +63,9 @@ Enhancements
- Lazy import `keras` module when importing `imblearn.keras`
:pr:`719` by :user:`Guillaume Lemaitre <glemaitre>`.

- Added Random Over-Sampling Examples (ROSE) class.
:pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>`.

Deprecation
...........

2 changes: 2 additions & 0 deletions imblearn/over_sampling/__init__.py
@@ -10,6 +10,7 @@
from ._smote import KMeansSMOTE
from ._smote import SVMSMOTE
from ._smote import SMOTENC
from ._rose import ROSE

__all__ = [
"ADASYN",
@@ -19,4 +20,5 @@
"BorderlineSMOTE",
"SVMSMOTE",
"SMOTENC",
"ROSE"
]
202 changes: 202 additions & 0 deletions imblearn/over_sampling/_rose.py
@@ -0,0 +1,202 @@
"""Class to perform over-sampling using ROSE."""

import numpy as np
from scipy import sparse
from sklearn.utils import check_random_state
from .base import BaseOverSampler
from ..utils._validation import _deprecate_positional_args


class ROSE(BaseOverSampler):
"""Random Over-Sampling Examples (ROSE).

    This object implements the ROSE algorithm. It generates new samples
    by a smoothed bootstrap approach: a random subsample of the original
    data is drawn with replacement, a multivariate Gaussian kernel
    density estimate :math:`f(x|Y_j)` with smoothing matrix :math:`H_j`
    is centered on the drawn samples, and the new samples are drawn from
    this distribution. Per-class shrink factors can be provided to set
    the bandwidth of the Gaussian kernels.

Read more in the :ref:`User Guide <rose>`.

Parameters
----------
sampling_strategy : float, str, dict or callable, default='auto'
Sampling information to resample the data set.

- When ``float``, it corresponds to the desired ratio of the number of
samples in the minority class over the number of samples in the
majority class after resampling. Therefore, the ratio is expressed as
:math:`\\alpha_{os} = N_{rm} / N_{M}` where :math:`N_{rm}` is the
number of samples in the minority class after resampling and
:math:`N_{M}` is the number of samples in the majority class.

.. warning::
``float`` is only available for **binary** classification. An
error is raised for multi-class classification.

- When ``str``, specify the class targeted by the resampling. The
number of samples in the different classes will be equalized.
Possible choices are:

``'minority'``: resample only the minority class;

``'not minority'``: resample all classes but the minority class;

``'not majority'``: resample all classes but the majority class;

``'all'``: resample all classes;

``'auto'``: equivalent to ``'not majority'``.

- When ``dict``, the keys correspond to the targeted classes. The
values correspond to the desired number of samples for each targeted
class.

- When callable, function taking ``y`` and returns a ``dict``. The keys
correspond to the targeted classes. The values correspond to the
desired number of samples for each class.

    shrink_factors : dict, default=None
        Dictionary of ``{class_label: shrink_factor}`` items, applied to
        the Gaussian kernels. It can be used to compress/dilate each
        kernel; ``None`` is equivalent to a factor of 1 for every class.

random_state : int, RandomState instance, default=None
Control the randomization of the algorithm.

- If int, ``random_state`` is the seed used by the random number
generator;
- If ``RandomState`` instance, random_state is the random number
generator;
- If ``None``, the random number generator is the ``RandomState``
instance used by ``np.random``.

n_jobs : int, default=None
Number of CPU cores used during the cross-validation loop.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See
`Glossary <https://scikit-learn.org/stable/glossary.html#term-n-jobs>`_
for more details.

See Also
--------
SMOTE : Over-sample using SMOTE.

References
----------
    .. [1] N. Lunardon, G. Menardi, N. Torelli, "ROSE: A Package for Binary
       Imbalanced Learning," R Journal, 6(1), 2014.

    .. [2] G. Menardi, N. Torelli, "Training and assessing classification
       rules with unbalanced data," Data Mining and Knowledge
       Discovery, 28(1), pp. 92-122, 2014.

Examples
--------

>>> from imblearn.over_sampling import ROSE
>>> from sklearn.datasets import make_classification
>>> from collections import Counter
>>> r = ROSE(shrink_factors={0:1, 1:0.5, 2:0.7})
>>> X, y = make_classification(n_classes=3, class_sep=2,
... weights=[0.1, 0.7, 0.2], n_informative=3, n_redundant=1, flip_y=0,
... n_features=20, n_clusters_per_class=1, n_samples=2000, random_state=10)
>>> print('Original dataset shape %s' % Counter(y))
Original dataset shape Counter({1: 1400, 2: 400, 0: 200})
>>> X_res, y_res = r.fit_resample(X, y)
>>> print('Resampled dataset shape %s' % Counter(y_res))
Resampled dataset shape Counter({2: 1400, 1: 1400, 0: 1400})
"""

@_deprecate_positional_args
def __init__(self, *, sampling_strategy="auto", shrink_factors=None,
random_state=None, n_jobs=None):
super().__init__(sampling_strategy=sampling_strategy)
self.random_state = random_state
self.shrink_factors = shrink_factors
self.n_jobs = n_jobs

def _make_samples(self,
X,
class_indices,
n_class_samples,
h_shrink):
""" A support function that returns artificial samples constructed
from a random subsample of the data, by adding a multiviariate
gaussian kernel and sampling from this distribution. An optional
shrink factor can be included, to compress/dilate the kernel.

Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Observations from which the samples will be created.

class_indices : ndarray, shape (n_class_samples,)
The target class indices

n_class_samples : int
The total number of samples per class to generate

h_shrink : int
the shrink factor

Returns
-------
X_new : {ndarray, sparse matrix}, shape (n_samples, n_features)
Synthetically generated samples.

y_new : ndarray, shape (n_samples,)
Target values for synthetic samples.

"""

number_of_features = X.shape[1]
random_state = check_random_state(self.random_state)
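        # Smoothed bootstrap: draw the seed points for the new samples
        # with replacement from the target class.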
samples_indices = random_state.choice(
class_indices, size=n_class_samples, replace=True)
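        # Silverman's rule-of-thumb constant, (4 / ((d + 2) * n)) ** (1 / (d + 4)),
        # which asymptotically minimizes the AMISE (asymptotic mean
        # integrated squared error) of a multivariate Gaussian kernel
        # density estimate.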
minimize_amise = (4 / ((number_of_features + 2) * len(
class_indices))) ** (1 / (number_of_features + 4))
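        # Per-feature sample standard deviations form the diagonal scale
        # matrix of the kernel.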
if sparse.issparse(X):
variances = np.diagflat(
np.std(X[class_indices, :].toarray(), axis=0, ddof=1))
else:
variances = np.diagflat(
np.std(X[class_indices, :], axis=0, ddof=1))
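        # Kernel bandwidth: the AMISE-optimal bandwidth scaled by the
        # per-class shrink factor.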
h_opt = h_shrink * minimize_amise * variances
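        # Perturb each seed point with Gaussian noise drawn from the
        # kernel, i.e. sample from the smoothed bootstrap distribution.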
randoms = random_state.standard_normal(size=(n_class_samples,
number_of_features))
Xrose = np.matmul(randoms, h_opt) + X[samples_indices, :]
if sparse.issparse(X):
return sparse.csr_matrix(Xrose)
return Xrose

def _fit_resample(self, X, y):

X_resampled = X.copy()
y_resampled = y.copy()

        # Use a local mapping so that fitting does not mutate the
        # shrink_factors hyper-parameter; None is equivalent to a unit
        # shrink factor for every targeted class.
        shrink_factors = self.shrink_factors
        if shrink_factors is None:
            shrink_factors = {key: 1 for key in self.sampling_strategy_}

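        # Generate the requested number of synthetic samples for each
        # class targeted by the sampling strategy.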
        for class_sample, n_samples in self.sampling_strategy_.items():
            class_indices = np.flatnonzero(y == class_sample)
            X_new = self._make_samples(X,
                                       class_indices,
                                       n_samples,
                                       shrink_factors[class_sample])
            y_new = np.array([class_sample] * n_samples)

if sparse.issparse(X_new):
X_resampled = sparse.vstack([X_resampled, X_new])
else:
X_resampled = np.concatenate((X_resampled, X_new))

y_resampled = np.hstack((y_resampled, y_new))

return X_resampled.astype(X.dtype), y_resampled.astype(y.dtype)