Skip to content
6 changes: 4 additions & 2 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,10 @@ def set_threshold(self, threshold):
The pairs classifier with the new threshold set.
"""
check_is_fitted(self, 'preprocessor_')

self.threshold_ = threshold
if not isinstance(threshold, (int, float)):
raise ValueError('Parameter threshold must be a real number. '
'Got {} instead.'.format(type(threshold)))
self.threshold_ = float(threshold) # int -> float
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to nitpick, but can't you just call self.threshold_ = float(threshold) at the beginning and throw the custom warning if a ValueError is raised? This would avoid calling isinstance, and would simply let bool go through (converting to 0/1), which is fine.

Copy link
Contributor Author

@mvargas33 mvargas33 Oct 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to this:

try:
  self.threshold_ = float(threshold)
except Exception:
  raise ValueError('Parameter threshold must be a real number. '
                   'Got {} instead.'.format(type(threshold)))

I guess if a warning is shown instead, the code can fail somewhere else, because it won't stop the execution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also forgot to mention that isinstance(threshold, (int, float)) is permissive for bool, because bool is a subclass of int. float(threshold) is permissive for bool as well.

return self

def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy',
Expand Down
25 changes: 25 additions & 0 deletions test/test_pairs_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ def test_set_threshold():
assert identity_pairs_classifier.threshold_ == 0.5


@pytest.mark.parametrize(
    'value',
    [
        "ABC",
        None,
        [1, 2, 3],
        {'key': None},
        (1, 2),
        set(),
        np.array([[[0.], [1.]], [[1.], [3.]]]),
        np.array([0.5]),
        np.array([[[0.5]]]),
    ],
)
def test_set_wrong_type_threshold(value):
  """Check that `set_threshold` rejects thresholds that are not real numbers.

  Each parametrized `value` must make `set_threshold` raise a `ValueError`
  whose message starts with the expected text naming the offending type.
  Afterwards an integer and a float are set, and the stored `threshold_` is
  compared against them. Booleans are intentionally not part of the invalid
  values: they are tolerated (True=1.0, False=0.0).
  """
  classifier = IdentityPairsClassifier()
  classifier.fit(np.array([[[0.], [1.]]]), np.array([1]))
  expected_prefix = ('Parameter threshold must be a real number. '
                     'Got {} instead.').format(type(value))

  # Invalid input: the call has to raise, and the message is inspected once
  # the context manager has captured the exception.
  with pytest.raises(ValueError) as raised:
    classifier.set_threshold(value)
  assert str(raised.value).startswith(expected_prefix)

  # Valid input: an integer is accepted and compares equal to its float value.
  classifier.set_threshold(1)
  assert classifier.threshold_ == 1.0

  # Valid input: a float is stored unchanged.
  classifier.set_threshold(0.1)
  assert classifier.threshold_ == 0.1


def test_f_beta_1_is_f_1():
# test that putting beta to 1 indeed finds the best threshold to optimize
# the f1_score
Expand Down