33 commits
142fa66
multi loss implementation
ssmmnn11 Apr 25, 2025
4b104b8
config update
ssmmnn11 Apr 25, 2025
08eb3f5
fix logging
ssmmnn11 Apr 25, 2025
f51937d
example setup in debug_ens
ssmmnn11 Jun 18, 2025
05a9d42
Merge branch 'kcrps_mloss' into feat/kcrps-multi-scale-loss
ssmmnn11 Jun 25, 2025
8f1ca0c
fix for channel sharding
ssmmnn11 Jun 25, 2025
d61ea5f
Merge branch 'fix-channel-sharding' into feat/kcrps-multi-scale-loss
ssmmnn11 Jun 25, 2025
1497ddc
multi-scale loss improvements
ssmmnn11 Jun 27, 2025
fa04281
Merge remote-tracking branch 'origin/main' into feat/kcrps-multi-scal…
ssmmnn11 Jul 17, 2025
765b53e
pydantic and some documentation
ssmmnn11 Jul 17, 2025
9c33ce9
docu update
ssmmnn11 Jul 17, 2025
0417e77
fix
ssmmnn11 Jul 17, 2025
652507d
Merge branch 'main' into feat/kcrps-multi-scale-loss
ssmmnn11 Jul 25, 2025
7bf3e8f
merged main
ssmmnn11 Sep 26, 2025
ad42367
fix for single GPU training
ssmmnn11 Sep 26, 2025
0f64a73
fix
ssmmnn11 Sep 26, 2025
35f04fe
fix: skip sharding when running on single gpu
theissenhelen Oct 16, 2025
27ff0c5
refactor: factor truncation operations out of model
theissenhelen Oct 16, 2025
83e17bb
more refactoring
theissenhelen Oct 16, 2025
e2a31f6
WIP
theissenhelen Oct 23, 2025
4e04123
instantiation of multiscale working
theissenhelen Oct 23, 2025
47ccbcc
MultiscaleLoss working
theissenhelen Oct 24, 2025
6e3725c
WIP
theissenhelen Nov 4, 2025
f1dcdfd
use kwargs for multiscale
theissenhelen Nov 4, 2025
e8aa6f5
Schema for multiscale loss
theissenhelen Nov 4, 2025
5634add
add multiscale to configs
theissenhelen Nov 4, 2025
cf63ffa
Merge remote-tracking branch 'origin/main' into feat/kcrps-multi-scal…
theissenhelen Nov 4, 2025
10b92d4
fix: mscale loss (#661)
ssmmnn11 Nov 17, 2025
e79cf27
remove unused entries
theissenhelen Nov 17, 2025
bd2c919
Merge remote-tracking branch 'origin/main' into feat/kcrps-multi-scal…
theissenhelen Nov 17, 2025
8a907ab
fix mloss accum missing
theissenhelen Nov 17, 2025
2c7cbc2
add truncation to integration tests
theissenhelen Nov 17, 2025
f44037a
adjust weights
theissenhelen Nov 17, 2025
48 changes: 48 additions & 0 deletions models/src/anemoi/models/truncation.py
@@ -0,0 +1,48 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import numpy as np
import torch


def make_truncation_matrix(A, data_type=torch.float32):
A_ = torch.sparse_coo_tensor(
torch.tensor(np.vstack(A.nonzero()), dtype=torch.long),
torch.tensor(A.data, dtype=data_type),
size=A.shape,
).coalesce()
return A_


def truncate_fields(x, A, batch_size=None, auto_cast=False):
if not batch_size:
batch_size = x.shape[0]
out = []
with torch.amp.autocast(device_type="cuda", enabled=auto_cast):
for i in range(batch_size):
out.append(multiply_sparse(x[i, ...], A))
return torch.stack(out)


def multiply_sparse(x, A):
if torch.cuda.is_available():
with torch.amp.autocast(device_type="cuda", enabled=False):
out = torch.sparse.mm(A, x)
else:
with torch.amp.autocast(device_type="cpu", enabled=False):
out = torch.sparse.mm(A, x)
return out


def interpolate_batch(batch: torch.Tensor, intp_matrix: torch.Tensor) -> torch.Tensor:
input_shape = batch.shape # e.g. (batch steps ensemble grid vars) or (batch steps grid vars)
batch = batch.reshape(-1, *input_shape[-2:])
batch = truncate_fields(batch, intp_matrix) # to coarse resolution
return batch.reshape(*input_shape)
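For orientation, a minimal usage sketch of these helpers (not part of the PR). The SciPy matrix below is a random placeholder for a pre-computed truncation matrix; note that it is square, since interpolate_batch reshapes the output back to the input shape, so the matrix must map the grid onto itself.

# Hypothetical usage sketch; the random square matrix stands in for a
# pre-computed truncation matrix loaded from the configured .npz file.
import numpy as np
import scipy.sparse as sp
import torch

from anemoi.models.truncation import interpolate_batch, make_truncation_matrix

n_grid, n_vars = 100, 8
A = sp.random(n_grid, n_grid, density=0.05, format="csr", dtype=np.float64)

A_torch = make_truncation_matrix(A)  # sparse COO tensor, float32 by default

# (batch, steps, ensemble, grid, vars) -> same layout, grid dimension filtered
batch = torch.randn(2, 3, 4, n_grid, n_vars)
smoothed = interpolate_batch(batch, A_torch)
print(smoothed.shape)  # torch.Size([2, 3, 4, 100, 8])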
46 changes: 34 additions & 12 deletions training/docs/user-guide/models.rst
@@ -149,22 +149,44 @@ For detailed information and examples, see
******************

Field truncation is a pre-processing step applied during autoregressive
rollout. It smooths the input data which helps maintain stability during
rollout.
rollout. It smooths the skip connection data, which helps maintain
stability during rollout, and can also be used for multi-scale loss
computation.

The truncation process relies on pre-computed transformation matrices
which can be specified in the configuration:
**********
Overview
**********

.. code:: yaml
Truncation matrices are sparse transformation matrices that filter
high-frequency components from the input data. This process serves two
main purposes:

#. **Stability Enhancement**: Smoothing the skip connection data
helps maintain numerical stability during long autoregressive
rollouts by reducing noise amplification.

#. **Multi-scale Loss Computation**: For ensemble training, truncation
matrices can be used to compute losses at different scales.

**************
Matrix Types
**************

The truncation system supports several types of transformation matrices:

**Truncation Matrix (``truncation``)**
The forward transformation matrix that applies the truncation filter
to the skip connection.

path:
truncation: /path/to/truncation/matrix
files:
truncation: truncation_matrix.pt
truncation_inv: truncation_matrix_inv.pt
**Inverse Truncation Matrix (``truncation_inv``)**
The inverse transformation matrix, mapping the truncated fields back
to the original grid resolution.

Once set, the truncation matrices are used automatically during the
rollout.
**Loss Truncation Matrices (``truncation_loss``)**
A list of matrices used for multi-scale loss computation during
ensemble training only. Each matrix corresponds to a different scale
at which the loss is evaluated. The matrices must be ordered so that
the first corresponds to the largest scales, with each subsequent
matrix including progressively smaller scales.
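
A hypothetical configuration sketch (the ``truncation_loss`` file names
are illustrative placeholders, not shipped defaults) showing how the
three matrix types can be set under ``hardware``:

.. code:: yaml

   hardware:
     paths:
       truncation: /path/to/truncation/matrices
     files:
       truncation: ${data.resolution}-o32-linear.mat.npz
       truncation_inv: o32-${data.resolution}-linear.mat.npz
       truncation_loss:
         - loss-matrix-large-scales.mat.npz    # largest scales first
         - loss-matrix-smaller-scales.mat.npz  # then progressively smaller scales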

.. note::

4 changes: 2 additions & 2 deletions training/docs/user-guide/yaml/example_crps_config.yaml
@@ -13,8 +13,8 @@ config_validation: True
# Changes in hardware
hardware:
files:
truncation: ${data.resolution}-O32-linear.mat.npz
truncation_inv: O32-${data.resolution}-linear.mat.npz
truncation: ${data.resolution}-o32-linear.mat.npz
truncation_inv: o32-${data.resolution}-linear.mat.npz
num_gpus_per_ensemble: 1
num_gpus_per_node: 1
num_nodes: 1
@@ -2,6 +2,7 @@ dataset: ???
graph: ???
truncation: null
truncation_inv: null
truncation_loss: [False]
checkpoint:
every_n_epochs: anemoi-by_epoch-epoch_{epoch:03d}-step_{step:06d}
every_n_train_steps: anemoi-by_step-epoch_{epoch:03d}-step_{step:06d}
67 changes: 53 additions & 14 deletions training/src/anemoi/training/config/training/ensemble.yaml
@@ -64,18 +64,33 @@ strategy:
# don't enable this by default until it's been tested and proven beneficial
loss_gradient_scaling: False


# loss function for the model
# To train without multiscale loss, set it to the desired loss directly
training_loss:
# loss class to initialise, can be anything subclassing torch.nn.Module
_target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
# Scalers to include in loss calculation
# A selection of available scalers are listed in training/scalers.
# '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded
# add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added.
scalers: ['pressure_level', 'general_variable', 'node_weights']
# other kwargs
ignore_nans: False
alpha: 1.0
_target_: anemoi.training.losses.MultiscaleLossWrapper
truncation_path: ${hardware.paths.truncation}
filenames: ${hardware.files.truncation_loss}
weights:
- 1.0
- 1.0
keep_batch_sharded: ${model.keep_batch_sharded}

internal_loss:
_target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']

# Scalers to include in loss calculation
# A selection of available scalers are listed in training/scalers.
# '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded
# add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added.
# scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
ignore_nans: False
no_autocast: True
alpha: 0.95



# Validation metrics calculation,
# This may be a list, in which case all metrics will be calculated
@@ -84,11 +99,35 @@ training_loss:
# have undergone postprocessing.
validation_metrics:
# loss class to initialise, can be anything subclassing torch.nn.Module
fkcrps:
_target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
scalers: ['node_weights']
ignore_nans: False
alpha: 1.0
# fkcrps:
# _target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
# scalers: ['node_weights']
# ignore_nans: False
# alpha: 1.0

multiscale:
_target_: anemoi.training.losses.MultiscaleLossWrapper

truncation_path: ${hardware.paths.truncation}
filenames: ${hardware.files.truncation_loss}
keep_batch_sharded: ${model.keep_batch_sharded}
weights:
- 1.0
- 1.0

internal_loss:
_target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
scalers: ['node_weights']

# Scalers to include in loss calculation
# A selection of available scalers are listed in training/scalers.
# '*' is a valid entry to use all `scalers` given, if a scaler is to be excluded
# add `!scaler_name`, i.e. ['*', '!scaler_1'], and `scaler_1` will not be added.
# scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
ignore_nans: False
no_autocast: True
alpha: 1.0


# Variable groups definition for scaling
# The variable level scaling methods are defined under training/scalers
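For reference, a sketch of the non-multiscale form that the ``training_loss`` comment above refers to, mirroring the previous default removed in this hunk:

# Hypothetical non-multiscale alternative: set training_loss to the desired loss directly.
training_loss:
  _target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
  scalers: ['pressure_level', 'general_variable', 'node_weights']
  ignore_nans: False
  alpha: 1.0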
3 changes: 3 additions & 0 deletions training/src/anemoi/training/losses/__init__.py
@@ -15,6 +15,7 @@
from .loss import get_loss_function
from .mae import MAELoss
from .mse import MSELoss
from .multiscale import MultiscaleLossWrapper
from .rmse import RMSELoss
from .weighted_mse import WeightedMSELoss

@@ -26,6 +27,8 @@
"LogCoshLoss",
"MAELoss",
"MSELoss",
"MultiscaleLoss",
"MultiscaleLossWrapper",
"RMSELoss",
"WeightedMSELoss",
"get_loss_function",
2 changes: 2 additions & 0 deletions training/src/anemoi/training/losses/base.py
@@ -198,6 +198,7 @@ def forward(
without_scalers: list[str] | list[int] | None = None,
grid_shard_slice: slice | None = None,
group: ProcessGroup | None = None,
**kwargs,
) -> torch.Tensor:
"""Calculates the area-weighted scaled loss.

Expand Down Expand Up @@ -255,6 +256,7 @@ def forward(
without_scalers: list[str] | list[int] | None = None,
grid_shard_slice: slice | None = None,
group: ProcessGroup | None = None,
**kwargs, # noqa: ARG002
) -> torch.Tensor:
"""Calculates the area-weighted scaled loss.

2 changes: 2 additions & 0 deletions training/src/anemoi/training/losses/kcrps.py
@@ -80,6 +80,7 @@ def forward(
without_scalers: list[str] | list[int] | None = None,
grid_shard_slice: slice | None = None,
group: ProcessGroup | None = None,
**kwargs, # noqa: ARG002
) -> torch.Tensor:
is_sharded = grid_shard_slice is not None

@@ -174,6 +175,7 @@ def forward(
without_scalers: list[str] | list[int] | None = None,
grid_shard_slice: slice | None = None,
group: ProcessGroup | None = None,
**kwargs, # noqa: ARG002
) -> torch.Tensor:
is_sharded = grid_shard_slice is not None

22 changes: 19 additions & 3 deletions training/src/anemoi/training/losses/loss.py
@@ -22,6 +22,8 @@
from anemoi.training.utils.variables_metadata import ExtractVariableGroupAndLevel

METRIC_RANGE_DTYPE = dict[str, list[int]]

NESTED_LOSSES = ["anemoi.training.losses.MultiscaleLossWrapper"]
LOGGER = logging.getLogger(__name__)


@@ -65,18 +67,34 @@ def get_loss_function(
loss_config = OmegaConf.to_container(config, resolve=True)
scalers_to_include = loss_config.pop("scalers", [])

if "_target_" in loss_config and loss_config["_target_"] in NESTED_LOSSES:
internal_loss_config = loss_config.pop("internal_loss")
internal_loss = get_loss_function(OmegaConf.create(internal_loss_config), scalers, data_indices)
return instantiate(loss_config, internal_loss=internal_loss, **kwargs)

if scalers is None:
scalers = {}

if "*" in scalers_to_include:
scalers_to_include = [s for s in list(scalers.keys()) if f"!{s}" not in scalers_to_include]

# Instantiate the loss function with the loss_init_config
loss_function = instantiate(loss_config, **kwargs, _recursive_=False)

if not isinstance(loss_function, BaseLoss):
error_msg = f"Loss must be a subclass of 'BaseLoss', not {type(loss_function)}"
raise TypeError(error_msg)
_apply_scalers(loss_function, scalers_to_include, scalers, data_indices)

return loss_function


def _apply_scalers(
loss_function: BaseLoss,
scalers_to_include: list,
scalers: dict[str, TENSOR_SPEC] | None,
data_indices: dict | None,
) -> None:
"""Attach scalers to a loss function and set data indices if needed."""
for key in scalers_to_include:
if key not in scalers or []:
error_msg = f"Scaler {key!r} not found in valid scalers: {list(scalers.keys())}"
@@ -93,8 +111,6 @@ def get_loss_function(
if hasattr(loss_function, "set_data_indices"):
loss_function.set_data_indices(data_indices)

return loss_function


def _get_metric_ranges(
extract_variable_group_and_level: ExtractVariableGroupAndLevel,
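To illustrate the new nested-loss dispatch in get_loss_function, here is a simplified, standalone sketch (an illustration only, not the anemoi-training API): a wrapper listed in NESTED_LOSSES has its internal_loss built first and passed in, while scalers are attached to the inner loss rather than the wrapper.

# Simplified standalone sketch of the nested-loss dispatch shown above.
from typing import Any

NESTED_LOSSES = ["anemoi.training.losses.MultiscaleLossWrapper"]


def build_loss(cfg: dict[str, Any]) -> dict[str, Any]:
    """Mimic get_loss_function: wrappers build their internal loss recursively."""
    cfg = dict(cfg)  # work on a copy, as OmegaConf.to_container would return
    scalers = cfg.pop("scalers", [])
    if cfg.get("_target_") in NESTED_LOSSES:
        internal = build_loss(cfg.pop("internal_loss"))
        return {"wrapper": cfg["_target_"], "internal_loss": internal}
    return {"loss": cfg["_target_"], "scalers": scalers}


cfg = {
    "_target_": "anemoi.training.losses.MultiscaleLossWrapper",
    "weights": [1.0, 1.0],
    "internal_loss": {
        "_target_": "anemoi.training.losses.kcrps.AlmostFairKernelCRPS",
        "scalers": ["node_weights"],
    },
}
print(build_loss(cfg))
# {'wrapper': 'anemoi.training.losses.MultiscaleLossWrapper',
#  'internal_loss': {'loss': 'anemoi.training.losses.kcrps.AlmostFairKernelCRPS',
#                    'scalers': ['node_weights']}}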