Skip to content

Commit ff11e5a

Browse files
authored
Fix rare edge case with extremely inbalanced data (#1244)
* Fix rare edge case with extremely inbalanced data For dataset 360112 Auto-sklearn would fail because the data would first be sub-sampled and then contain some classes only once. In the internal splitting, the StratifiedShuffleSplit would not be able to split the dataset into train and valid, and would resort to only a ShuffleSplit. This could put the single sample for a class into the test set. At predict time we would then miss one class. This commit creates two new splitters which move a sample from the test split to the training split if a class does not exist in the train split. * fix unit test
1 parent 63808ef commit ff11e5a

File tree

3 files changed

+201
-15
lines changed

3 files changed

+201
-15
lines changed

autosklearn/evaluation/splitter.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import numpy as np
2+
3+
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
4+
from sklearn.model_selection._split import _validate_shuffle_split
5+
from sklearn.utils import indexable, check_random_state
6+
from sklearn.utils import _approximate_mode
7+
from sklearn.utils.validation import _num_samples, column_or_1d
8+
from sklearn.utils.validation import check_array
9+
from sklearn.utils.multiclass import type_of_target
10+
11+
12+
class CustomStratifiedShuffleSplit(StratifiedShuffleSplit):
13+
"""Stratified ShuffleSplit cross-validator that deals with classes with too few samples
14+
"""
15+
16+
def _iter_indices(self, X, y, groups=None): # type: ignore
17+
n_samples = _num_samples(X)
18+
y = check_array(y, ensure_2d=False, dtype=None)
19+
n_train, n_test = _validate_shuffle_split(
20+
n_samples, self.test_size, self.train_size,
21+
default_test_size=self._default_test_size)
22+
23+
if y.ndim == 2:
24+
# for multi-label y, map each distinct row to a string repr
25+
# using join because str(row) uses an ellipsis if len(row) > 1000
26+
y = np.array([' '.join(row.astype('str')) for row in y])
27+
28+
classes, y_indices = np.unique(y, return_inverse=True)
29+
n_classes = classes.shape[0]
30+
31+
class_counts = np.bincount(y_indices)
32+
# print(class_counts)
33+
34+
if n_train < n_classes:
35+
raise ValueError('The train_size = %d should be greater or '
36+
'equal to the number of classes = %d' %
37+
(n_train, n_classes))
38+
if n_test < n_classes:
39+
raise ValueError('The test_size = %d should be greater or '
40+
'equal to the number of classes = %d' %
41+
(n_test, n_classes))
42+
43+
# Find the sorted list of instances for each class:
44+
# (np.unique above performs a sort, so code is O(n logn) already)
45+
class_indices = np.split(np.argsort(y_indices, kind='mergesort'),
46+
np.cumsum(class_counts)[:-1])
47+
48+
rng = check_random_state(self.random_state)
49+
50+
for _ in range(self.n_splits):
51+
# if there are ties in the class-counts, we want
52+
# to make sure to break them anew in each iteration
53+
n_i = _approximate_mode(class_counts, n_train, rng)
54+
class_counts_remaining = class_counts - n_i
55+
t_i = _approximate_mode(class_counts_remaining, n_test, rng)
56+
train = []
57+
test = []
58+
59+
for i in range(n_classes):
60+
# print("Before", i, class_counts[i], n_i[i], t_i[i])
61+
permutation = rng.permutation(class_counts[i])
62+
perm_indices_class_i = class_indices[i].take(permutation,
63+
mode='clip')
64+
if n_i[i] == 0:
65+
n_i[i] = 1
66+
t_i[i] = t_i[i] - 1
67+
68+
# print("After", i, class_counts[i], n_i[i], t_i[i])
69+
train.extend(perm_indices_class_i[:n_i[i]])
70+
test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])
71+
72+
train = rng.permutation(train)
73+
test = rng.permutation(test)
74+
75+
yield train, test
76+
77+
78+
class CustomStratifiedKFold(StratifiedKFold):
79+
"""Stratified K-Folds cross-validator that ensures that there is always at least
80+
1 sample per class in the training set.
81+
"""
82+
83+
def _make_test_folds(self, X, y=None): # type: ignore
84+
rng = check_random_state(self.random_state)
85+
y = np.asarray(y)
86+
type_of_target_y = type_of_target(y)
87+
allowed_target_types = ('binary', 'multiclass')
88+
if type_of_target_y not in allowed_target_types:
89+
raise ValueError(
90+
'Supported target types are: {}. Got {!r} instead.'.format(
91+
allowed_target_types, type_of_target_y))
92+
93+
y = column_or_1d(y)
94+
95+
_, y_idx, y_inv = np.unique(y, return_index=True, return_inverse=True)
96+
# y_inv encodes y according to lexicographic order. We invert y_idx to
97+
# map the classes so that they are encoded by order of appearance:
98+
# 0 represents the first label appearing in y, 1 the second, etc.
99+
_, class_perm = np.unique(y_idx, return_inverse=True)
100+
y_encoded = class_perm[y_inv]
101+
102+
n_classes = len(y_idx)
103+
104+
# Determine the optimal number of samples from each class in each fold,
105+
# using round robin over the sorted y. (This can be done direct from
106+
# counts, but that code is unreadable.)
107+
y_order = np.sort(y_encoded)
108+
allocation = np.asarray(
109+
[np.bincount(y_order[i::self.n_splits], minlength=n_classes)
110+
for i in range(self.n_splits)])
111+
112+
# To maintain the data order dependencies as best as possible within
113+
# the stratification constraint, we assign samples from each class in
114+
# blocks (and then mess that up when shuffle=True).
115+
test_folds = np.empty(len(y), dtype='i')
116+
for k in range(n_classes):
117+
# since the kth column of allocation stores the number of samples
118+
# of class k in each test set, this generates blocks of fold
119+
# indices corresponding to the allocation for class k.
120+
folds_for_class = np.arange(self.n_splits).repeat(allocation[:, k])
121+
if self.shuffle:
122+
rng.shuffle(folds_for_class)
123+
test_folds[y_encoded == k] = folds_for_class
124+
return test_folds
125+
126+
def split(self, X, y=None, groups=None): # type: ignore
127+
128+
X, y, groups = indexable(X, y, groups)
129+
n_samples = _num_samples(X)
130+
if self.n_splits > n_samples:
131+
raise ValueError(
132+
("Cannot have number of splits n_splits={0} greater"
133+
" than the number of samples: n_samples={1}.")
134+
.format(self.n_splits, n_samples))
135+
136+
for train, test in super().split(X, y, groups):
137+
# print(len(np.unique(y)), len(np.unique(y[train])), len(np.unique(y[test])))
138+
all_classes = np.unique(y)
139+
train_classes = np.unique(y[train])
140+
train = list(train)
141+
test = list(test)
142+
missing_classes = set(all_classes) - set(train_classes)
143+
if len(missing_classes) > 0:
144+
# print(missing_classes)
145+
for diff in missing_classes:
146+
# print(len(train), len(test))
147+
to_move = np.where(y[test] == diff)[0][0]
148+
# print(y[test][to_move])
149+
train = train + [test[to_move]]
150+
del test[to_move]
151+
# print(len(train), len(test))
152+
train = np.array(train, dtype=int)
153+
test = np.array(test, dtype=int)
154+
# print(
155+
# len(np.unique(y)),
156+
# len(np.unique(y[train])),
157+
# len(np.unique(y[test])),
158+
# len(train), len(test),
159+
# )
160+
161+
yield train, test

autosklearn/evaluation/train_evaluator.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import multiprocessing
3+
import warnings
34
from typing import Any, Dict, List, Optional, Tuple, Union, cast
45

56
import copy
@@ -21,6 +22,7 @@
2122
TYPE_ADDITIONAL_INFO,
2223
_fit_and_suppress_warnings,
2324
)
25+
from autosklearn.evaluation.splitter import CustomStratifiedShuffleSplit, CustomStratifiedKFold
2426
from autosklearn.data.abstract_data_manager import AbstractDataManager
2527
from autosklearn.constants import (
2628
CLASSIFICATION_TASKS,
@@ -1037,15 +1039,20 @@ def get_splitter(self, D: AbstractDataManager) -> Union[BaseCrossValidator, _Rep
10371039

10381040
if shuffle:
10391041
try:
1040-
cv = StratifiedShuffleSplit(n_splits=1,
1041-
test_size=test_size,
1042-
random_state=1)
1042+
cv = StratifiedShuffleSplit(
1043+
n_splits=1,
1044+
test_size=test_size,
1045+
random_state=1,
1046+
)
10431047
test_cv = copy.deepcopy(cv)
10441048
next(test_cv.split(y, y))
10451049
except ValueError as e:
10461050
if 'The least populated class in y has only' in e.args[0]:
1047-
cv = ShuffleSplit(n_splits=1, test_size=test_size,
1048-
random_state=1)
1051+
cv = CustomStratifiedShuffleSplit(
1052+
n_splits=1,
1053+
test_size=test_size,
1054+
random_state=1,
1055+
)
10491056
else:
10501057
raise e
10511058
else:
@@ -1057,9 +1064,26 @@ def get_splitter(self, D: AbstractDataManager) -> Union[BaseCrossValidator, _Rep
10571064
elif self.resampling_strategy in ['cv', 'cv-iterative-fit', 'partial-cv',
10581065
'partial-cv-iterative-fit']:
10591066
if shuffle:
1060-
cv = StratifiedKFold(
1061-
n_splits=self.resampling_strategy_args['folds'],
1062-
shuffle=shuffle, random_state=1)
1067+
try:
1068+
with warnings.catch_warnings():
1069+
warnings.simplefilter('error')
1070+
cv = StratifiedKFold(
1071+
n_splits=self.resampling_strategy_args['folds'],
1072+
shuffle=shuffle,
1073+
random_state=1,
1074+
)
1075+
test_cv = copy.deepcopy(cv)
1076+
next(test_cv.split(y, y))
1077+
except UserWarning as e:
1078+
print(e)
1079+
if 'The least populated class in y has only' in e.args[0]:
1080+
cv = CustomStratifiedKFold(
1081+
n_splits=self.resampling_strategy_args['folds'],
1082+
shuffle=shuffle,
1083+
random_state=1,
1084+
)
1085+
else:
1086+
raise e
10631087
else:
10641088
cv = KFold(n_splits=self.resampling_strategy_args['folds'],
10651089
shuffle=shuffle)

test/test_evaluation/test_train_evaluator.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import sklearn.model_selection
2020
from smac.tae import StatusType, TAEAbortException
2121

22+
import autosklearn.evaluation.splitter
2223
from autosklearn.data.abstract_data_manager import AbstractDataManager
2324
from autosklearn.evaluation.util import read_queue
2425
from autosklearn.evaluation.train_evaluator import TrainEvaluator, \
@@ -1080,17 +1081,17 @@ def test_get_splitter(self, te_mock):
10801081
self.assertIsInstance(cv,
10811082
sklearn.model_selection.PredefinedSplit)
10821083

1083-
# holdout, binary classification, fallback to shuffle split
1084+
# holdout, binary classification, fallback to custom shuffle split
10841085
D.data['Y_train'] = np.array([0, 0, 0, 1, 1, 1, 2])
10851086
evaluator = TrainEvaluator()
10861087
evaluator.resampling_strategy = 'holdout'
10871088
evaluator.resampling_strategy_args = {}
10881089
cv = evaluator.get_splitter(D)
10891090
self.assertIsInstance(cv,
1090-
sklearn.model_selection._split.ShuffleSplit)
1091+
autosklearn.evaluation.splitter.CustomStratifiedShuffleSplit)
10911092

10921093
# cv, binary classification
1093-
D.data['Y_train'] = np.array([0, 0, 0, 1, 1, 1])
1094+
D.data['Y_train'] = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
10941095
evaluator = TrainEvaluator()
10951096
evaluator.resampling_strategy = 'cv'
10961097
evaluator.resampling_strategy_args = {'folds': 5}
@@ -1099,7 +1100,7 @@ def test_get_splitter(self, te_mock):
10991100
sklearn.model_selection._split.StratifiedKFold)
11001101

11011102
# cv, binary classification, shuffle is True
1102-
D.data['Y_train'] = np.array([0, 0, 0, 1, 1, 1])
1103+
D.data['Y_train'] = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
11031104
evaluator = TrainEvaluator()
11041105
evaluator.resampling_strategy = 'cv'
11051106
evaluator.resampling_strategy_args = {'folds': 5}
@@ -1118,14 +1119,14 @@ def test_get_splitter(self, te_mock):
11181119
sklearn.model_selection._split.KFold)
11191120
self.assertFalse(cv.shuffle)
11201121

1121-
# cv, binary classification, no fallback anticipated
1122-
D.data['Y_train'] = np.array([0, 0, 0, 1, 1, 1, 2])
1122+
# cv, binary classification, fallback to custom splitter
1123+
D.data['Y_train'] = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2])
11231124
evaluator = TrainEvaluator()
11241125
evaluator.resampling_strategy = 'cv'
11251126
evaluator.resampling_strategy_args = {'folds': 5}
11261127
cv = evaluator.get_splitter(D)
11271128
self.assertIsInstance(cv,
1128-
sklearn.model_selection._split.StratifiedKFold)
1129+
autosklearn.evaluation.splitter.CustomStratifiedKFold)
11291130

11301131
# regression, shuffle split
11311132
D.data['Y_train'] = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])

0 commit comments

Comments
 (0)