Skip to content

Commit 8e1f952

Browse files
Merge pull request #743 from lucamarini22/master
indices_tuple: Add assertion that each pair should be either positive or negative
2 parents 943c96a + 61d47c9 commit 8e1f952

File tree

2 files changed

+5
-32
lines changed

2 files changed

+5
-32
lines changed

src/pytorch_metric_learning/losses/generic_pair_loss.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def mat_based_loss(self, mat, indices_tuple):
2828
pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
2929
pos_mask[a1, p] = 1
3030
neg_mask[a2, n] = 1
31+
self._assert_either_pos_or_neg(pos_mask, neg_mask)
3132
return self._compute_loss(mat, pos_mask, neg_mask)
3233

3334
def pair_based_loss(self, mat, indices_tuple):
@@ -38,3 +39,7 @@ def pair_based_loss(self, mat, indices_tuple):
3839
if len(a2) > 0:
3940
neg_pair = mat[a2, n]
4041
return self._compute_loss(pos_pair, neg_pair, indices_tuple)
42+
43+
@staticmethod
44+
def _assert_either_pos_or_neg(pos_mask, neg_mask):
45+
assert not torch.any((pos_mask != 0) & (neg_mask != 0)), "Each pair should be either positive or negative"

tests/losses/test_cross_batch_memory.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def test_loss(self):
238238
batch_size = 32
239239
for inner_loss in [ContrastiveLoss(), MultiSimilarityLoss()]:
240240
inner_miner = MultiSimilarityMiner(0.3)
241-
outer_miner = MultiSimilarityMiner(0.2)
242241
self.loss = CrossBatchMemory(
243242
loss=inner_loss,
244243
embedding_size=self.embedding_size,
@@ -267,10 +266,6 @@ def test_loss(self):
267266
labels = torch.randint(0, num_labels, (batch_size,)).to(TEST_DEVICE)
268267
loss = self.loss(embeddings, labels)
269268
loss_with_miner = self.loss_with_miner(embeddings, labels)
270-
oa1, op, oa2, on = outer_miner(embeddings, labels)
271-
loss_with_miner_and_input_indices = self.loss_with_miner2(
272-
embeddings, labels, (oa1, op, oa2, on)
273-
)
274269
all_embeddings = torch.cat([all_embeddings, embeddings])
275270
all_labels = torch.cat([all_labels, labels])
276271

@@ -308,33 +303,6 @@ def test_loss(self):
308303
torch.isclose(loss_with_miner, correct_loss_with_miner)
309304
)
310305

311-
# loss with inner and outer miner
312-
indices_tuple = inner_miner(
313-
embeddings, labels, all_embeddings, all_labels
314-
)
315-
a1, p, a2, n = lmu.remove_self_comparisons(
316-
indices_tuple,
317-
self.loss_with_miner2.curr_batch_idx,
318-
self.loss_with_miner2.memory_size,
319-
)
320-
a1 = torch.cat([oa1, a1])
321-
p = torch.cat([op, p])
322-
a2 = torch.cat([oa2, a2])
323-
n = torch.cat([on, n])
324-
correct_loss_with_miner_and_input_indice = inner_loss(
325-
embeddings,
326-
labels,
327-
(a1, p, a2, n),
328-
all_embeddings,
329-
all_labels,
330-
)
331-
self.assertTrue(
332-
torch.isclose(
333-
loss_with_miner_and_input_indices,
334-
correct_loss_with_miner_and_input_indice,
335-
)
336-
)
337-
338306
def test_queue(self):
339307
for test_enqueue_mask in [False, True]:
340308
for dtype in TEST_DTYPES:

0 commit comments

Comments
 (0)