training/src/anemoi/training/losses/base.py (5 additions, 8 deletions)

@@ -123,14 +123,11 @@ def scale(
         else:
             scale_tensor = self.scaler.without_by_dim(without_scalers)
 
-        scaler = scale_tensor.get_scaler(x.ndim)
-
-        if grid_shard_slice is not None:
-            scaler = scaler[:, :, grid_shard_slice, :]
-
-        scaler = scaler.expand_as(x)
-
-        return x[subset_indices] * scaler[subset_indices]
+        return scale_tensor.scale_iteratively(
+            x,
+            subset_indices=subset_indices,
+            grid_shard_slice=grid_shard_slice,
+        )
 
     def reduce(
         self,
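Note on the refactor: the old code built the fully resolved scaler with `get_scaler`, sliced the grid dimension at a hard-coded position (`[:, :, grid_shard_slice, :]`), and expanded it to the input's shape before multiplying. The new code delegates to `ScaleTensor.scale_iteratively`, which applies each scaler in turn, so nothing larger than the (possibly sharded) input is ever materialised. A minimal standalone sketch of the difference; shapes and names are illustrative, not the library's API:

import torch

# x: (batch, ensemble, grid, variable)
x = torch.randn(2, 3, 1000, 5)

# Per-dimension scalers, each cheap to store on its own.
grid_scaler = torch.rand(1000)      # over the grid dim
variable_scaler = torch.rand(5)     # over the variable dim

# Old style: build the full broadcast product up front, then slice out
# the grid shard.
full = (grid_scaler[:, None] * variable_scaler[None, :]).expand(2, 3, 1000, 5)
shard = slice(0, 500)
old = x[:, :, shard, :] * full[:, :, shard, :]

# Iterative style: slice each scaler to the shard first, then multiply
# one scaler at a time.
x_shard = x[:, :, shard, :]
new = x_shard * grid_scaler[shard].view(1, 1, -1, 1)
new = new * variable_scaler.view(1, 1, 1, -1)

torch.testing.assert_close(old, new)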
training/src/anemoi/training/losses/scaler_tensor.py (64 additions, 2 deletions)

@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from anemoi.training.utils.enums import TensorDim
+
 if TYPE_CHECKING:
     from collections.abc import Callable
     from collections.abc import Sequence
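The new import is the only change to the module header: `TensorDim` names the dimension indices that were previously hard-coded slice positions. The enum itself is not shown in this diff; a minimal sketch of what it plausibly looks like, with values inferred from the old `[:, :, grid_shard_slice, :]` slice (grid as dimension 2 of 4), would be:

from enum import IntEnum

class TensorDim(IntEnum):
    # Hypothetical reconstruction: the real enum lives in
    # anemoi.training.utils.enums and may differ.
    BATCH = 0
    ENSEMBLE = 1
    GRID = 2
    VARIABLE = 3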
@@ -515,7 +517,60 @@ def resolve(self, ndim: int) -> ScaleTensor:
 
         return ScaleTensor(**resolved_scalers)
 
-    def scale(self, tensor: torch.Tensor) -> torch.Tensor:
+    def scale_iteratively(
+        self,
+        x: torch.Tensor,
+        subset_indices: tuple[int, ...] | None = None,
+        *,
+        grid_shard_slice: slice | None = None,
+    ) -> torch.Tensor:
+        """Apply the scalers iteratively to the input tensor.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor to scale
+        subset_indices : tuple[int, ...] | None, optional
+            Indices to subset the input tensor, by default None
+        grid_shard_slice : slice | None, optional
+            Slice to apply to the grid dimension, by default None
+
+        Returns
+        -------
+        torch.Tensor
+            Scaled tensor
+        """
+        x_subset = x[subset_indices] if subset_indices is not None else x
+        out = x_subset.clone()
+        ndim = x.ndim
+        tensors = self.resolve(ndim).tensors
+
+        for dims, scaler in tensors.values():
+            # Slice grid-dependent scalers down to the local shard first.
+            if TensorDim.GRID in dims and grid_shard_slice is not None:
+                grid_index = dims.index(TensorDim.GRID)
+                if scaler.shape[grid_index] > 1:
+                    slices = [slice(None)] * len(dims)
+                    slices[grid_index] = grid_shard_slice
+                    scaler = scaler[tuple(slices)]
+
+            # Align the scaler with the full rank of x: prepend singleton axes
+            # for the missing dims, then move each axis into position.
+            missing_dims = [d for d in range(ndim) if d not in dims]
+            reshape = [1] * len(missing_dims)
+            reshape.extend(scaler.shape)
+
+            reshaped_scaler = scaler.reshape(reshape)
+            reshaped_scaler = torch.moveaxis(reshaped_scaler, list(range(ndim)), (*missing_dims, *dims))
+
+            reshaped_scaler = reshaped_scaler.expand_as(x)
+
+            if subset_indices is not None:
+                reshaped_scaler = reshaped_scaler[subset_indices]
+
+            out = out * reshaped_scaler
+
+        return out
+
+    def scale(
+        self,
+        x: torch.Tensor,
+        subset_indices: tuple[int, ...] | None = None,
+        *,
+        grid_shard_slice: slice | None = None,
+    ) -> torch.Tensor:
         """Scale a given tensor by the scalers.
 
         Parameters
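The reshape-and-moveaxis step is the subtle part of `scale_iteratively`: a scaler stored over only a subset of dimensions must be aligned with the full-rank input before it can broadcast. A self-contained sketch of that alignment for a hypothetical ensemble-dimension scaler (shapes here are illustrative):

import torch

# A scaler defined only over dim 1 (ensemble) of a 4-D tensor.
x = torch.randn(2, 3, 7, 5)            # (batch, ensemble, grid, variable)
dims = (1,)
scaler = torch.rand(3)

ndim = x.ndim
missing_dims = [d for d in range(ndim) if d not in dims]   # [0, 2, 3]

# Prepend one singleton axis per missing dim: (3,) -> (1, 1, 1, 3) ...
reshaped = scaler.reshape([1] * len(missing_dims) + list(scaler.shape))
# ... then move the scaler's axis (currently last) to position 1, pushing
# the singletons onto the missing dims: (1, 1, 1, 3) -> (1, 3, 1, 1).
reshaped = torch.moveaxis(reshaped, list(range(ndim)), (*missing_dims, *dims))

assert reshaped.shape == (1, 3, 1, 1)
out = x * reshaped                      # broadcasts over batch, grid, variable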
@@ -528,7 +583,14 @@ def scale(self, tensor: torch.Tensor) -> torch.Tensor:
         torch.Tensor
             Scaled tensor
         """
-        return tensor * self.get_scaler(tensor.ndim, device=tensor.device)
+        x_subset = x[subset_indices] if subset_indices is not None else x
+        scaler = self.get_scaler(x_subset.ndim)
+        if grid_shard_slice is not None and scaler.shape[TensorDim.GRID] > 1:
+            slices = [slice(None)] * x_subset.ndim
+            slices[TensorDim.GRID] = grid_shard_slice
+            scaler = scaler[tuple(slices)]
+
+        return x_subset * scaler
 
     def get_scaler(self, ndim: int, device: str | None = None) -> torch.Tensor:
         """Get completely resolved scaler tensor.
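With `TensorDim.GRID` naming the grid axis, the sharding logic in `scale` reads directly: slice the resolved scaler along the grid dimension only when it actually varies there (`shape[GRID] > 1`), since a size-1 grid axis already broadcasts across every shard. A hedged usage sketch, assuming `TensorDim.GRID == 2` as implied by the old hard-coded slice:

import torch

# Suppose a fully resolved scaler of shape (1, 1, 1000, 5) and a loss input
# that only holds the second half of the grid on this rank.
scaler = torch.rand(1, 1, 1000, 5)
x_shard = torch.randn(2, 3, 500, 5)
grid_shard_slice = slice(500, 1000)

GRID = 2  # stand-in for TensorDim.GRID

if scaler.shape[GRID] > 1:              # a size-1 grid axis needs no slicing
    slices = [slice(None)] * x_shard.ndim
    slices[GRID] = grid_shard_slice
    scaler = scaler[tuple(slices)]      # -> (1, 1, 500, 5)

out = x_shard * scaler                  # broadcasts over batch and ensemble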
training/tests/unit/train/test_scaler.py (2 additions, 0 deletions)

@@ -119,6 +119,7 @@ def test_scale_tensor_one_dim(
     scale.add_scaler(*scaler)
 
     torch.testing.assert_close(scale.scale(input_tensor), output)
+    torch.testing.assert_close(scale.scale_iteratively(input_tensor), output)
 
 
 def test_invalid_dim_sizes() -> None:
@@ -173,6 +174,7 @@ def test_scale_tensor_two_dim(
     output = torch.tensor(output, dtype=torch.float32)
 
     torch.testing.assert_close(scale.scale(input_tensor), output)
+    torch.testing.assert_close(scale.scale_iteratively(input_tensor), output)
 
 
 @pytest.mark.parametrize("subset_id", ["test", 0])
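The updated tests assert that `scale` and `scale_iteratively` agree on unsharded inputs. A sketch of the analogous parity check for the sharded path, hypothetical rather than part of this PR, assuming `add_scaler(dim, tensor)` as used in the existing tests and a grid axis at position 2:

import torch

from anemoi.training.losses.scaler_tensor import ScaleTensor

def test_scale_iteratively_grid_shard_parity() -> None:
    # Hypothetical test: scaling a grid shard with grid_shard_slice should
    # match slicing the fully scaled tensor.
    scale = ScaleTensor()
    scale.add_scaler(2, torch.rand(8))      # scaler over the grid dim
    scale.add_scaler(3, torch.rand(4))      # scaler over the variable dim

    x = torch.randn(2, 3, 8, 4)
    shard = slice(4, 8)

    full = scale.scale_iteratively(x)
    sharded = scale.scale_iteratively(x[:, :, shard, :], grid_shard_slice=shard)

    torch.testing.assert_close(full[:, :, shard, :], sharded)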