@@ -238,7 +238,6 @@ def test_loss(self):
238238 batch_size = 32
239239 for inner_loss in [ContrastiveLoss (), MultiSimilarityLoss ()]:
240240 inner_miner = MultiSimilarityMiner (0.3 )
241- outer_miner = MultiSimilarityMiner (0.2 )
242241 self .loss = CrossBatchMemory (
243242 loss = inner_loss ,
244243 embedding_size = self .embedding_size ,
@@ -267,10 +266,6 @@ def test_loss(self):
267266 labels = torch .randint (0 , num_labels , (batch_size ,)).to (TEST_DEVICE )
268267 loss = self .loss (embeddings , labels )
269268 loss_with_miner = self .loss_with_miner (embeddings , labels )
270- oa1 , op , oa2 , on = outer_miner (embeddings , labels )
271- loss_with_miner_and_input_indices = self .loss_with_miner2 (
272- embeddings , labels , (oa1 , op , oa2 , on )
273- )
274269 all_embeddings = torch .cat ([all_embeddings , embeddings ])
275270 all_labels = torch .cat ([all_labels , labels ])
276271
@@ -308,33 +303,6 @@ def test_loss(self):
308303 torch .isclose (loss_with_miner , correct_loss_with_miner )
309304 )
310305
311- # loss with inner and outer miner
312- indices_tuple = inner_miner (
313- embeddings , labels , all_embeddings , all_labels
314- )
315- a1 , p , a2 , n = lmu .remove_self_comparisons (
316- indices_tuple ,
317- self .loss_with_miner2 .curr_batch_idx ,
318- self .loss_with_miner2 .memory_size ,
319- )
320- a1 = torch .cat ([oa1 , a1 ])
321- p = torch .cat ([op , p ])
322- a2 = torch .cat ([oa2 , a2 ])
323- n = torch .cat ([on , n ])
324- correct_loss_with_miner_and_input_indice = inner_loss (
325- embeddings ,
326- labels ,
327- (a1 , p , a2 , n ),
328- all_embeddings ,
329- all_labels ,
330- )
331- self .assertTrue (
332- torch .isclose (
333- loss_with_miner_and_input_indices ,
334- correct_loss_with_miner_and_input_indice ,
335- )
336- )
337-
338306 def test_queue (self ):
339307 for test_enqueue_mask in [False , True ]:
340308 for dtype in TEST_DTYPES :
0 commit comments