Commit 6d17688

Error in model: only the q matrix is scaled, not the q.k.T dot product, i.e. (q.k.T) / sqrt(dim_per_head)
As per Vaswani et al., 2017, p. 4, the scores should be torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head), not a matmul applied to q / math.sqrt(dim_per_head). https://arxiv.org/pdf/1912.05372.pdf
1 parent ded1cf8 commit 6d17688

1 file changed: +1 -1 lines changed


xlm/model/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -208,8 +208,8 @@ def unshape(x):
 k, v = cache[self.layer_id]
 cache[self.layer_id] = (k, v)

-q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
 scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
+scores = scores / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
 mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
 scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
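To make the two orderings in this diff concrete, here is a minimal sketch for a single attention head, using plain Python lists instead of torch so it runs without dependencies. The helper names (`matmul_qkT`, `scores_scale_after`, `scores_scale_before`) and the sample values are hypothetical and not part of `xlm/model/transformer.py`. Because `1 / sqrt(dim_per_head)` is a scalar, scaling `q` first and scaling the scores afterwards are algebraically equal and should differ only by floating-point rounding; the changed line mainly makes the code follow the order in which the formula is written.

```python
import math


def matmul_qkT(q, k):
    """Dot product of every query row with every key row.

    q: (qlen, dim_per_head), k: (klen, dim_per_head) -> (qlen, klen)
    """
    return [[sum(qi * ki for qi, ki in zip(q_row, k_row)) for k_row in k]
            for q_row in q]


def scores_scale_after(q, k):
    """Ordering after this commit: compute q.k.T, then divide the scores."""
    dim_per_head = len(q[0])
    return [[s / math.sqrt(dim_per_head) for s in row]
            for row in matmul_qkT(q, k)]


def scores_scale_before(q, k):
    """Ordering before this commit: divide q, then compute q.k.T."""
    dim_per_head = len(q[0])
    q_scaled = [[x / math.sqrt(dim_per_head) for x in row] for row in q]
    return matmul_qkT(q_scaled, k)


# Illustrative values only: qlen=2, klen=3, dim_per_head=4.
q = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 0.0, 2.0]]
k = [[1.0, 0.0, 1.0, 0.0], [2.0, 1.0, 0.0, -1.0], [0.0, 0.0, 0.0, 1.0]]

after = scores_scale_after(q, k)
before = scores_scale_before(q, k)

# Largest element-wise gap between the two orderings.
max_diff = max(abs(a - b)
               for row_a, row_b in zip(after, before)
               for a, b in zip(row_a, row_b))
print(after)
print(max_diff)
```

With larger or less round inputs, `max_diff` may be a small nonzero value from rounding rather than exactly zero, so a tolerance-based comparison is the safer check.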
