11 changes: 2 additions & 9 deletions algoperf/spec.py
@@ -235,16 +235,9 @@ def model_params_types(self):
   def is_output_params(self, param_key: ParameterKey) -> bool:
     """Whether a key in ParameterContainer is the output layer parameters."""
 
-  # InitModelFn = Callable[
-  #   Tuple[RandomState, Optional[float], Optional[float]],
-  #   ParameterContainer]
+  # InitModelFn = Callable[Optional[float]], ParameterContainer]
   @abc.abstractmethod
-  def init_model_fn(
-    self,
-    rng: RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None,
-  ) -> ModelInitState:
+  def init_model_fn(self, rng: RandomState) -> ModelInitState:
     """Return (initial_params, initial_model_state)."""
 
   # ModelFn = Callable[
2 changes: 2 additions & 0 deletions algoperf/workloads/cifar/cifar_jax/workload.py
@@ -105,9 +105,11 @@ def model_fn(
     rng: spec.RandomState,
     update_batch_norm: bool,
     use_running_average_bn: Optional[bool] = None,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode
     del rng
+    del dropout_rate
     variables = {'params': params, **model_state}
     if update_batch_norm:
       logits, new_model_state = self._model.apply(
12 changes: 3 additions & 9 deletions algoperf/workloads/cifar/cifar_pytorch/workload.py
@@ -118,16 +118,8 @@ def _build_dataset(
     dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
     return dataloader
 
-  def init_model_fn(
-    self,
-    rng: spec.RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Dropout is unused."""
-    del dropout_rate
-    del aux_dropout_rate
-
     if hasattr(self, '_model'):
       if isinstance(self._model, (DDP, torch.nn.DataParallel)):
         self._model.module.reset_parameters()
@@ -158,9 +150,11 @@ def model_fn(
     mode: spec.ForwardPassMode,
     rng: spec.RandomState,
     update_batch_norm: bool,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
+    del dropout_rate
     model = params
     if mode == spec.ForwardPassMode.EVAL:
       if update_batch_norm:
1 change: 0 additions & 1 deletion algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -19,7 +19,6 @@ def init_model_fn(
     self,
     rng: spec.RandomState,
   ) -> spec.ModelInitState:
-    """aux_dropout_rate is unused."""
     fake_batch = jnp.zeros((13, 320, 320))
     self._model = UNet(
       num_pool_layers=self.num_pool_layers,
2 changes: 1 addition & 1 deletion algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -186,7 +186,7 @@ class ViT(nn.Module):
   mlp_dim: Optional[int] = None  # Defaults to 4x input dim.
   num_heads: int = 12
   rep_size: Union[int, bool] = True
-  dropout_rate: [float] = DROPOUT_RATE
+  dropout_rate: float = DROPOUT_RATE
   reinit: Optional[Sequence[str]] = None
   head_zeroinit: bool = True
   use_glu: bool = False
2 changes: 2 additions & 0 deletions algoperf/workloads/mnist/mnist_jax/workload.py
@@ -48,10 +48,12 @@ def model_fn(
     mode: spec.ForwardPassMode,
     rng: spec.RandomState,
     update_batch_norm: bool,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
     del update_batch_norm
+    del dropout_rate
     train = mode == spec.ForwardPassMode.TRAIN
     logits_batch = self._model.apply(
       {'params': params},
13 changes: 3 additions & 10 deletions algoperf/workloads/mnist/mnist_pytorch/workload.py
@@ -138,16 +138,7 @@ def shard(batch):
       }
       yield batch
 
-  def init_model_fn(
-    self,
-    rng: spec.RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
-    """Dropout is unused."""
-    del dropout_rate
-    del aux_dropout_rate
-
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     if hasattr(self, '_model'):
       if isinstance(self._model, (DDP, torch.nn.DataParallel)):
         self._model.module.reset_parameters()
@@ -178,10 +169,12 @@ def model_fn(
     mode: spec.ForwardPassMode,
     rng: spec.RandomState,
     update_batch_norm: bool,
+    dropout_rate: float = 0.0,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del rng
     del update_batch_norm
+    del dropout_rate
     model = params
     if mode == spec.ForwardPassMode.EVAL:
       model.eval()
11 changes: 2 additions & 9 deletions algoperf/workloads/wmt/wmt_jax/workload.py
@@ -239,14 +239,7 @@ def translate_and_calculate_bleu(
     bleu_score = bleu.corpus_bleu(predictions, [references]).score
     return bleu_score
 
-  def init_model_fn(
-    self,
-    rng: spec.RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None,
-  ) -> spec.ModelInitState:
-    """aux_dropout_rate is used as attention_dropout_rate."""
-
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     init_fake_batch_size = 8
     input_shape = (init_fake_batch_size, 256)
     target_shape = (init_fake_batch_size, 256)
@@ -295,7 +288,7 @@ def model_fn(
     mode: spec.ForwardPassMode,
     rng: spec.RandomState,
     update_batch_norm: bool,
-    dropout_rate: [float] = models.DROPOUT_RATE,
+    dropout_rate: float = models.DROPOUT_RATE,
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del update_batch_norm
11 changes: 4 additions & 7 deletions docs/DOCUMENTATION.md
@@ -104,12 +104,7 @@ def _build_input_queue(
 ###### Model initialization
 
 ```python
-def init_model_fn(
-    self,
-    rng: RandomState,
-    dropout_rate: Optional[float] = None,
-    aux_dropout_rate: Optional[float] = None
-) -> initial model parameters
+def init_model_fn(self, rng: RandomState) -> initial model parameters
 ```
 
 - Unlike in the *Model Track*, this function that initializes the parameters of the model, is fixed. While it can be called by the submission (e.g. to restart the model after a failed training effort) it cannot be changed.
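As a purely illustrative aside (not part of the diff), a submission might call the fixed `init_model_fn` under its new single-argument signature like this, for example to re-initialize the model after a diverged run; the `workload` object and the seed handling are assumed context, not defined by the excerpt above:

```python
import jax


def restart_model(workload, seed: int):
  """Sketch only: re-initialize a workload's model from a fresh seed.

  Only `rng` is passed now; dropout rates are no longer arguments of
  init_model_fn.
  """
  rng = jax.random.PRNGKey(seed)
  model_params, model_state = workload.init_model_fn(rng)
  return model_params, model_state
```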
@@ -125,7 +120,8 @@ def model_fn(
     mode: ForwardPassMode,  # mode \in {train, eval}
     rng: RandomState,
     hyperparameters: Hyperparameters,
-    update_batch_norm: bool
+    update_batch_norm: bool,
+    dropout_rate: float
 ) -> (logits_output_batch, new_model_state): Tuple[Tensor, ModelAuxiliaryState]
 ```
 
@@ -134,6 +130,7 @@ def model_fn(
 - `logits_output_batch` is before the output activation
 - `new_model_state` is for batch norm or similar side effects and will only be updated if `update_batch_norm` is set
 - `hyperparameters` will contain only dropout rates, which will be used in the models that support it. These can be tuned or will default to documented model-specific values. Note that adding additional dropout would be considered changing the model, which is not allowed, but the tuning of dropout in existing dropout layers can be considered a regularizer, so we allow it. There should be at most two dropout rates in a model (if there are more than two we will reuse the same values).
+- `dropout_rate` is used in the model forward pass. If not provided, the workload’s default value is used (see below for the list of defaults).
 
 ###### Loss function
 
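The `dropout_rate` plumbing documented above is what the submission-side diffs below implement: read the tuned rate from the hyperparameters and pass it through to the forward pass. As a hedged sketch (assuming a `hyperparameters` object with a `dropout_rate` field, as in the baselines; the helper names here are illustrative), a training loss closure could forward it like this:

```python
from algoperf import spec


def make_loss_fn(workload, batch, model_state, rng, hyperparameters):
  """Sketch only: forward the tuned dropout rate into the workload forward pass."""

  def _loss_fn(params):
    logits, new_model_state = workload.model_fn(
      params,
      batch,
      model_state,
      spec.ForwardPassMode.TRAIN,
      rng,
      update_batch_norm=True,
      # Omitting dropout_rate falls back to the workload's documented default.
      dropout_rate=hyperparameters.dropout_rate,
    )
    loss_dict = workload.loss_fn(
      label_batch=batch['targets'], logits_batch=logits
    )
    # Mean loss over valid examples; dict keys follow the baseline submissions.
    mean_loss = loss_dict['summed'] / loss_dict['n_valid_examples']
    return mean_loss, new_model_state

  return _loss_fn
```

The surrounding gradient computation and optimizer update are unchanged by this diff and are omitted here.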
(additional changed file)
@@ -78,6 +78,7 @@ def pmapped_train_step(
   rng,
   grad_clip,
   label_smoothing,
+  dropout_rate,
 ):
   def _loss_fn(params):
     """Loss function used for training."""
@@ -94,6 +95,7 @@ def _loss_fn(params):
       logits_batch=logits,
       mask_batch=batch.get('weights'),
       label_smoothing=label_smoothing,
+      dropout_rate=dropout_rate,
     )
     summed_loss = loss_dict['summed']
     n_valid_examples = loss_dict['n_valid_examples']
@@ -156,6 +158,7 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
+  dropout_rate = hyperparameters.dropout_rate
   outputs = pmapped_train_step(
     workload,
     opt_update_fn,
@@ -166,6 +169,7 @@ def update_params(
     per_device_rngs,
     grad_clip,
     label_smoothing,
+    dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
 
(additional changed file)
@@ -227,6 +227,7 @@ def update_params(
     mode=spec.ForwardPassMode.TRAIN,
     rng=rng,
     update_batch_norm=True,
+    dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (
(additional changed file)
@@ -69,6 +69,7 @@ def train_step(
   rng,
   grad_clip,
   label_smoothing,
+  dropout_rate,
 ):
   def _loss_fn(params):
     """Loss function used for training."""
@@ -79,6 +80,7 @@ def _loss_fn(params):
       spec.ForwardPassMode.TRAIN,
       rng,
       update_batch_norm=True,
+      dropout_rate=dropout_rate,
     )
     loss_dict = workload.loss_fn(
       label_batch=batch['targets'],
@@ -144,6 +146,7 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
+  dropout_rate = hyperparameters.dropout_rate
 
   # Set up mesh and sharding
   mesh = jax.sharding.Mesh(jax.devices(), ('batch'))
@@ -164,6 +167,7 @@ def update_params(
       replicated,  # rng
       replicated,  # grad_clip
       replicated,  # label_smoothing
+      replicated,  # dropout_rate
     ),
     out_shardings=(
       replicated,  # new_optimizer_state
@@ -185,6 +189,7 @@ def update_params(
       rng,
       grad_clip,
       label_smoothing,
+      dropout_rate,
     )
   )
 
(additional changed file)
@@ -84,6 +84,7 @@ def update_params(
     mode=spec.ForwardPassMode.TRAIN,
     rng=rng,
     update_batch_norm=True,
+    dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (
4 changes: 4 additions & 0 deletions reference_algorithms/paper_baselines/lamb/jax/submission.py
@@ -84,6 +84,7 @@ def pmapped_train_step(
   rng,
   grad_clip,
   label_smoothing,
+  dropout_rate,
 ):
   def _loss_fn(params):
     """Loss function used for training."""
@@ -94,6 +95,7 @@ def _loss_fn(params):
       spec.ForwardPassMode.TRAIN,
       rng,
       update_batch_norm=True,
+      dropout_rate=dropout_rate,
     )
     loss_dict = workload.loss_fn(
       label_batch=batch['targets'],
@@ -163,6 +165,7 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
+  dropout_rate = hyperparameters.dropout_rate
   outputs = pmapped_train_step(
     workload,
     opt_update_fn,
@@ -173,6 +176,7 @@ def update_params(
     per_device_rngs,
     grad_clip,
     label_smoothing,
+    dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
 
(additional changed file)
@@ -225,6 +225,7 @@ def update_params(
     mode=spec.ForwardPassMode.TRAIN,
     rng=rng,
     update_batch_norm=True,
+    dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (
(additional changed file)
@@ -103,6 +103,7 @@ def train_step(
   rng,
   grad_clip,
   label_smoothing,
+  dropout_rate,
 ):
   def _loss_fn(params):
     """Loss function used for training."""
@@ -113,6 +114,7 @@ def _loss_fn(params):
       spec.ForwardPassMode.TRAIN,
       rng,
       update_batch_norm=True,
+      dropout_rate=dropout_rate,
     )
     loss_dict = workload.loss_fn(
       label_batch=batch['targets'],
@@ -177,6 +179,7 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
+  dropout_rate = hyperparameters.dropout_rate
 
   # Create shardings for each argument
   replicated = jax_sharding_utils.get_replicate_sharding()  # No partitioning
@@ -195,6 +198,7 @@ def update_params(
     replicated,  # rng
     replicated,  # grad_clip
     replicated,  # label_smoothing
+    replicated,  # dropout_rate
   )
   out_shardings = (
     replicated,  # new_optimizer_state
@@ -223,6 +227,7 @@ def update_params(
     rng,
     grad_clip,
     label_smoothing,
+    dropout_rate,
   )
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
 
(additional changed file)
@@ -104,6 +104,7 @@ def update_params(
     mode=spec.ForwardPassMode.TRAIN,
     rng=rng,
     update_batch_norm=True,
+    dropout_rate=hyperparameters.dropout_rate,
   )
 
   label_smoothing = (
(additional changed file)
@@ -215,6 +215,7 @@ def train_step(
   rng,
   grad_clip,
   label_smoothing,
+  dropout_rate,
 ):
   def _loss_fn(params):
     """Loss function used for training."""
@@ -225,6 +226,7 @@ def _loss_fn(params):
       spec.ForwardPassMode.TRAIN,
       rng,
       update_batch_norm=True,
+      dropout_rate=dropout_rate,
     )
     loss_dict = workload.loss_fn(
       label_batch=batch['targets'],
@@ -289,6 +291,8 @@ def update_params(
     grad_clip = hyperparameters.grad_clip
   else:
     grad_clip = None
+  dropout_rate = hyperparameters.dropout_rate
+
   # Create shardings for each argument
   replicated = jax_sharding_utils.get_replicate_sharding()  # No partitioning
   sharded = (
@@ -306,6 +310,7 @@ def update_params(
     replicated,  # rng
     replicated,  # grad_clip
     replicated,  # label_smoothing
+    replicated,  # dropout_rate
   )
   out_shardings = (
     replicated,  # new_optimizer_state
@@ -335,6 +340,7 @@ def update_params(
       rng,
       grad_clip,
       label_smoothing,
+      dropout_rate,
     )
   )
 