
Commit 8642cca

Merge pull request #887 from mlcommons/dropout_fixes
fix for wmt
2 parents 24d9815 + 9325826

File tree

1 file changed: +2, -2 lines changed


algoperf/workloads/wmt/wmt_jax/models.py

Lines changed: 2 additions & 2 deletions
@@ -223,7 +223,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
-       dropout_rate=0.0,
+       dropout_rate=0.0,  # The dropout is applied at the end of this layer
        deterministic=cfg.deterministic,
    )(cfg.attention_temp * x, x, mask=encoder_mask)

@@ -286,7 +286,7 @@ def __call__(
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
-       dropout_rate=dropout_rate,
+       dropout_rate=0.0,  # Dropout applied after MultiHeadDotProductAttention
        deterministic=cfg.deterministic,
        decode=cfg.decode,
    )(cfg.attention_temp * x, x, mask=decoder_mask)
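
For context, here is a minimal, self-contained sketch of the pattern this commit settles on: the attention module is constructed with dropout_rate=0.0 and dropout is instead applied once on the attention output. It assumes Flax's nn.MultiHeadDotProductAttention and nn.Dropout; the module name, hyperparameter defaults, and shapes below are illustrative and are not taken from algoperf/workloads/wmt/wmt_jax/models.py.

import flax.linen as nn
import jax
import jax.numpy as jnp


class SelfAttentionBlock(nn.Module):
  """Illustrative block: no dropout inside attention, one dropout after it."""
  num_heads: int = 4
  qkv_dim: int = 64
  attention_temp: float = 1.0

  @nn.compact
  def __call__(self, x, mask=None, dropout_rate=0.1, deterministic=False):
    y = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        qkv_features=self.qkv_dim,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=0.0,  # no dropout inside the attention module ...
        deterministic=deterministic,
    )(self.attention_temp * x, x, mask=mask)
    # ... it is applied exactly once here, at the end of the layer.
    y = nn.Dropout(rate=dropout_rate, deterministic=deterministic)(y)
    return x + y  # residual connection


# Usage: separate RNG streams for parameter init and dropout.
block = SelfAttentionBlock()
x = jnp.ones((2, 8, 64))  # (batch, length, features)
variables = block.init({'params': jax.random.PRNGKey(0)}, x, deterministic=True)
out = block.apply(variables, x, dropout_rate=0.1, deterministic=False,
                  rngs={'dropout': jax.random.PRNGKey(1)})

The first hunk above only documents this convention for the encoder self-attention, while the second changes the decoder, which was still passing dropout_rate into the attention module and would presumably have applied dropout both inside and after that sublayer.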

0 commit comments
