
Commit 943c96a

Merge pull request #721 from ir2718/smooth_ap
[`feat`] Implementing SmoothAP loss
2 parents d5cfc82 + 44bbb39 commit 943c96a

File tree: 8 files changed, +327 −1 lines changed

Three of the changed files are the equation images referenced from docs/losses.md (imgs/smooth_ap_sigmoid_equation.png, imgs/smooth_ap_approx_equation.png, imgs/smooth_ap_loss_equation.png); they are binary additions and their contents are not shown in the diff below.

docs/losses.md

Lines changed: 31 additions & 0 deletions

````diff
@@ -1087,6 +1087,37 @@ losses.SignalToNoiseRatioContrastiveLoss(pos_margin=0, neg_margin=1, **kwargs):
 * **pos_loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
 * **neg_loss**: The loss per negative pair in the batch. Reduction type is ```"neg_pair"```.
 
+## SmoothAPLoss
+[Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163){target=_blank}
+
+```python
+losses.SmoothAPLoss(
+    temperature=0.01,
+    **kwargs
+)
+```
+
+**Equations**:
+
+![smooth_ap_loss_equation1](imgs/smooth_ap_sigmoid_equation.png){: style="height:100px"}
+![smooth_ap_loss_equation2](imgs/smooth_ap_approx_equation.png){: style="height:100px"}
+![smooth_ap_loss_equation3](imgs/smooth_ap_loss_equation.png){: style="height:100px"}
+
+**Parameters**:
+
+* **temperature**: The desired temperature for scaling the sigmoid function. This is denoted by $\tau$ in the first and second equations.
+
+**Other info**:
+
+* The loss requires the same number of elements for each class in the batch labels. For example, `[1, 1, 2, 2, 3, 3]` is valid, while `[1, 1, 1, 2, 2, 3, 3]` is invalid because there are `3` elements with the value `1` but only `2` for each of the other classes. Equal counts can be enforced with `samplers.MPerClassSampler` by setting its `batch_size` and `m` hyperparameters.
+
+**Default distance**:
+
+ - [```CosineSimilarity()```](distances.md#cosinesimilarity)
+     - This is the only compatible distance.
+
 ## SoftTripleLoss
 [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf){target=_blank}
 ```python
````
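The three equation images above are not reproduced in this view. As a hedged transcription (based on the Smooth-AP paper and on the `compute_loss` code in this commit, not on the image files themselves), the quantities they show are approximately:

$$\mathcal{G}(x;\tau)=\frac{1}{1+e^{-x/\tau}}$$

$$\mathcal{R}(i \mid q)\approx 1+\sum_{j\neq i}\mathcal{G}\left(s_{qj}-s_{qi};\tau\right)$$

$$\mathcal{L}_{\text{SmoothAP}}=\frac{1}{B}\sum_{q=1}^{B}\left(1-\frac{1}{|P_q|}\sum_{i\in P_q}\frac{\mathcal{R}^{+}(i\mid q)}{\mathcal{R}(i\mid q)}\right)$$

Here $s_{qj}$ is the cosine similarity between query $q$ and batch element $j$, $P_q$ is the set of other elements sharing $q$'s class, $\mathcal{R}$ is the sigmoid-smoothed rank over the whole batch, and $\mathcal{R}^{+}$ is the same smoothed rank restricted to $P_q$. This corresponds to `sims_ranks`, `sims_pos_ranks`, and the `1 - ap` computation in the new loss file below.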

src/pytorch_metric_learning/distances/dot_product_similarity.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@ def __init__(self, **kwargs):
         assert self.is_inverted
 
     def compute_mat(self, query_emb, ref_emb):
-        return torch.matmul(query_emb, ref_emb.t())
+        return torch.matmul(query_emb, ref_emb.transpose(-1, -2))
 
     def pairwise_distance(self, query_emb, ref_emb):
         return torch.sum(query_emb * ref_emb, dim=1)
```
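The switch from `.t()` to `.transpose(-1, -2)` is what lets `SmoothAPLoss` reuse `compute_mat` on a batched, 3-D tensor of per-group embeddings (`xs_norm` below); `.t()` is only defined for tensors with at most two dimensions. A minimal sketch of the difference (not part of the commit):

```python
import torch

emb_2d = torch.randn(8, 4)     # (batch, dim)
emb_3d = torch.randn(3, 8, 4)  # (groups, group_size, dim)

# Works for both 2-D and batched 3-D inputs:
sim_2d = torch.matmul(emb_2d, emb_2d.transpose(-1, -2))  # shape (8, 8)
sim_3d = torch.matmul(emb_3d, emb_3d.transpose(-1, -2))  # shape (3, 8, 8)

# emb_3d.t() would raise a RuntimeError, since t() expects a tensor with <= 2 dimensions.
print(sim_2d.shape, sim_3d.shape)
```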

src/pytorch_metric_learning/losses/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -30,6 +30,7 @@
 from .ranked_list_loss import RankedListLoss
 from .self_supervised_loss import SelfSupervisedLoss
 from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
+from .smooth_ap import SmoothAPLoss
 from .soft_triple_loss import SoftTripleLoss
 from .sphereface_loss import SphereFaceLoss
 from .subcenter_arcface_loss import SubCenterArcFaceLoss
```

src/pytorch_metric_learning/losses/smooth_ap.py

Lines changed: 103 additions & 0 deletions

```python
import torch
import torch.nn.functional as F

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class SmoothAPLoss(BaseMetricLossFunction):
    """
    Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163
    """

    def __init__(self, temperature=0.01, **kwargs):
        super().__init__(**kwargs)
        c_f.assert_distance_type(self, CosineSimilarity)
        self.temperature = temperature

    def get_default_distance(self):
        return CosineSimilarity()

    # Implementation is based on the original repository:
    # https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        # The loss expects labels such that there is the same number of elements for each class.
        # The number of classes is not important, nor their order, but the number of elements must be the same, e.g.
        #
        # The following labels are valid:
        #     [ A, A, A, B, B, B, C, C, C ]
        # The following labels are NOT valid:
        #     [ B, B, B, A, A, A, A, C, C, C ]
        #
        c_f.labels_required(labels)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)

        counts = torch.bincount(labels)
        nonzero_indices = torch.nonzero(counts, as_tuple=True)[0]
        nonzero_counts = counts[nonzero_indices]
        if nonzero_counts.unique().size(0) != 1:
            raise ValueError(
                "All classes must have the same number of elements in the labels.\n"
                "The given labels have the following number of elements: {}.\n"
                "You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
                    nonzero_counts.cpu().tolist()
                )
            )

        batch_size = embeddings.size(0)
        num_classes_batch = batch_size // torch.unique(labels).size(0)

        mask = 1.0 - torch.eye(batch_size)
        mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1)

        sims = self.distance(embeddings)

        sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1)
        sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1)
        sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device)
        sims_ranks = torch.sum(sims_sigm, dim=-1) + 1

        xs = embeddings.view(
            num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1)
        )
        pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch)
        pos_mask = (
            pos_mask.unsqueeze(dim=0)
            .unsqueeze(dim=0)
            .repeat(num_classes_batch, batch_size // num_classes_batch, 1, 1)
        )

        # Circumvent the shape check in the forward method
        xs_norm = self.distance.maybe_normalize(xs, dim=-1)
        sims_pos = self.distance.compute_mat(xs_norm, xs_norm)

        sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat(
            1, 1, batch_size // num_classes_batch, 1
        )
        sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2)

        sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to(
            sims_diff.device
        )
        sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1

        g = batch_size // num_classes_batch
        ap = torch.zeros(batch_size).to(embeddings.device)
        for i in range(num_classes_batch):
            for j in range(g):
                pos_rank = sims_pos_ranks[i, j]
                all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g]
                ap[i * g + j] = torch.sum(pos_rank / all_rank) / g

        miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
        loss = (1 - ap) * miner_weights

        return {
            "ap_loss": {
                "losses": loss,
                "indices": c_f.torch_arange_from_size(loss),
                "reduction_type": "element",
            }
        }
```
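A minimal usage sketch (not part of the commit; the embedding size, label layout, and sampler settings are illustrative):

```python
import torch

from pytorch_metric_learning import losses, samplers

loss_func = losses.SmoothAPLoss(temperature=0.01)

# Toy batch: 3 classes with 2 embeddings each; every class must appear the same number of times.
embeddings = torch.randn(6, 128, requires_grad=True)
labels = torch.tensor([1, 1, 2, 2, 3, 3])

loss = loss_func(embeddings, labels)
loss.backward()

# In a training loop, samplers.MPerClassSampler can enforce the equal-count constraint,
# e.g. m=2 samples per class with a compatible batch size:
# sampler = samplers.MPerClassSampler(dataset_labels, m=2, batch_size=32)
```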

Lines changed: 191 additions & 0 deletions

```python
import unittest

import torch
import torch.nn.functional as F

from pytorch_metric_learning.losses import SmoothAPLoss

from .. import TEST_DEVICE, TEST_DTYPES

HYPERPARAMETERS = {
    "temp": 0.01,
    "batch_size": 60,
    "num_id": 6,
    "feat_dims": 256,
}
TEST_SEEDS = [42, 1234, 5642, 9999, 3459]


# Original implementation of the SmoothAP loss taken from:
# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py
def sigmoid(tensor, temp=1.0):
    """temperature controlled sigmoid

    takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
    """
    exponent = -tensor / temp
    # clamp the input tensor for stability
    exponent = torch.clamp(exponent, min=-50, max=50)
    y = 1.0 / (1.0 + torch.exp(exponent))
    return y


def compute_aff(x):
    """computes the affinity matrix between an input vector and itself"""
    return torch.mm(x, x.t())


class SmoothAP(torch.nn.Module):
    """PyTorch implementation of the Smooth-AP loss.

    Takes as input the mini-batch of CNN-produced feature embeddings and returns the value of the
    Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.

    e.g. the labels for a mini-batch with batch size 9 and 3 represented classes (A, B, C) must look like:

        labels = (A, A, A, B, B, B, C, C, C)

    (the order of the classes, however, does not matter)

    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and
    the rest of the mini-batch is used as the retrieval set. The positive set is formed of the other
    instances in the batch from the same class. The loss returns the average Smooth-AP across all instances
    in the mini-batch.

    Args:
        anneal : float
            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the
            temperature results in a steep sigmoid that tightly approximates the Heaviside step function in
            the ranking function.
        batch_size : int
            the batch size being used during training.
        num_id : int
            the number of different classes that are represented in the batch.
        feat_dims : int
            the dimension of the input feature embeddings

    Shape:
        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
        - Output: scalar

    Examples::

        >>> loss = SmoothAP(0.01, 60, 6, 256)
        >>> input = torch.randn(60, 256, requires_grad=True).to("cuda:0")
        >>> output = loss(input)
        >>> output.backward()
    """

    def __init__(self, anneal, batch_size, num_id, feat_dims):
        """
        Parameters
        ----------
        anneal : float
            the temperature of the sigmoid that is used to smooth the ranking function
        batch_size : int
            the batch size being used
        num_id : int
            the number of different classes that are represented in the batch
        feat_dims : int
            the dimension of the input feature embeddings
        """
        super(SmoothAP, self).__init__()

        assert batch_size % num_id == 0

        self.anneal = anneal
        self.batch_size = batch_size
        self.num_id = num_id
        self.feat_dims = feat_dims

    def forward(self, preds):
        """Forward pass for all input predictions: preds - (batch_size x feat_dims)"""

        # ------ differentiable ranking of all retrieval set ------
        # compute the mask which ignores the relevance score of the query to itself
        mask = 1.0 - torch.eye(self.batch_size)
        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
        sim_all = compute_aff(preds)
        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
        # compute the difference matrix
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the sigmoid
        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.to(TEST_DEVICE)
        # compute the rankings
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1

        # ------ differentiable ranking of only positive set in retrieval set ------
        # compute the mask which only gives non-zero weights to the positive set
        xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
        pos_mask = (
            pos_mask.unsqueeze(dim=0)
            .unsqueeze(dim=0)
            .repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
        )

        # compute the relevance scores
        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(
            1, 1, int(self.batch_size / self.num_id), 1
        )
        # compute the difference matrix
        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
        # pass through the sigmoid
        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.to(TEST_DEVICE)
        # compute the rankings of the positive set
        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1

        # sum the values of the Smooth-AP for all instances in the mini-batch
        ap = torch.zeros(1).to(TEST_DEVICE)
        group = int(self.batch_size / self.num_id)
        for ind in range(self.num_id):
            pos_divide = torch.sum(
                sim_pos_rk[ind]
                / (
                    sim_all_rk[
                        (ind * group) : ((ind + 1) * group),
                        (ind * group) : ((ind + 1) * group),
                    ]
                )
            )
            ap = ap + ((pos_divide / group) / self.batch_size)

        return 1 - ap


class TestSmoothAPLoss(unittest.TestCase):
    def test_smooth_ap_loss(self):
        for dtype in TEST_DTYPES:
            for seed in TEST_SEEDS:
                torch.manual_seed(seed)
                loss = SmoothAP(
                    HYPERPARAMETERS["temp"],
                    HYPERPARAMETERS["batch_size"],
                    HYPERPARAMETERS["num_id"],
                    HYPERPARAMETERS["feat_dims"],
                )
                rand_tensor = (
                    torch.randn(
                        HYPERPARAMETERS["batch_size"],
                        HYPERPARAMETERS["feat_dims"],
                        requires_grad=True,
                    )
                    .to(TEST_DEVICE)
                    .to(dtype)
                )
                # The original code uses a model that normalizes the output vector
                input_ = F.normalize(rand_tensor, p=2.0, dim=-1)
                output = loss(input_)

                loss2 = SmoothAPLoss(temperature=HYPERPARAMETERS["temp"])
                # The original code assumes the labels are in this format
                labels = []
                for i in range(
                    HYPERPARAMETERS["batch_size"] // HYPERPARAMETERS["num_id"]
                ):
                    labels.extend([i for _ in range(HYPERPARAMETERS["num_id"])])

                labels = torch.tensor(labels)
                output2 = loss2.forward(rand_tensor, labels)
                self.assertTrue(torch.isclose(output, output2))
```
