
Conversation


@ssmmnn11 ssmmnn11 commented Jun 27, 2025

Implement the multiscale loss, as proposed in
https://arxiv.org/abs/2506.10868

The MultiscaleLossWrapper wraps around already implemented losses and applies them at user-defined scales. The user provides "truncation matrices" that decrease the resolution at the individual scales. The following is an example loss config:

training_loss:
  _target_: anemoi.training.losses.MultiscaleLossWrapper
  truncation_path: ${hardware.paths.truncation}
  filenames: ${hardware.files.truncation_loss}
  weights:
    - 1.0
    - 1.0
  keep_batch_sharded: ${model.keep_batch_sharded}

  internal_loss:
    _target_: anemoi.training.losses.kcrps.AlmostFairKernelCRPS
    scalers: ['node_weights']
    ignore_nans: False
    no_autocast: True
    alpha: 0.95

Of note, the corresponding filenames always require a None value added to the truncation matrices.
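To illustrate how such a wrapper can work, here is a minimal sketch (not the anemoi implementation; the class name MultiscaleLossSketch, its arguments and the shape conventions are illustrative assumptions): a None entry stands for the native resolution, every other entry is a projection matrix to a coarser scale, and the wrapped loss is evaluated per scale and combined with the configured weights.

```python
import torch
from torch import nn


class MultiscaleLossSketch(nn.Module):
    """Illustrative sketch of a multi-scale loss wrapper (not the anemoi code).

    Each entry of `truncation_matrices` is either None (native resolution)
    or an (n_coarse, n_fine) matrix projecting fields to a coarser scale.
    """

    def __init__(self, internal_loss, truncation_matrices, weights):
        super().__init__()
        assert len(truncation_matrices) == len(weights)
        self.internal_loss = internal_loss
        # Kept as a plain list for simplicity; a real implementation would
        # register these as buffers so they move with the module's device.
        self.truncation_matrices = truncation_matrices
        self.weights = weights

    def _project(self, x, trunc):
        # x has shape (..., n_gridpoints, n_vars); project along the grid dimension.
        if trunc is None:
            return x
        return torch.einsum("og,...gv->...ov", trunc, x)

    def forward(self, pred, target):
        # Weighted sum of the wrapped loss evaluated at every scale.
        total = pred.new_zeros(())
        for trunc, weight in zip(self.truncation_matrices, self.weights):
            total = total + weight * self.internal_loss(
                self._project(pred, trunc), self._project(target, trunc)
            )
        return total
```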

The multi-scale loss will now be treated as the default loss for the ensemble model.

The aggregated loss over the various scales will be tracked in the validation metrics. A future PR will focus on implementing loss tracking for individual variables.


📚 Documentation preview 📚: https://anemoi-training--388.org.readthedocs.build/en/388/


📚 Documentation preview 📚: https://anemoi-graphs--388.org.readthedocs.build/en/388/


📚 Documentation preview 📚: https://anemoi-models--388.org.readthedocs.build/en/388/

@ssmmnn11 ssmmnn11 marked this pull request as draft July 4, 2025 14:53
@github-project-automation github-project-automation bot moved this to Now In Progress in Anemoi-dev Jul 10, 2025
@ssmmnn11 ssmmnn11 requested a review from JPXKQX July 12, 2025 07:16
@ssmmnn11 ssmmnn11 marked this pull request as ready for review July 17, 2025 14:23
@ssmmnn11 ssmmnn11 changed the title from Feat/kcrps multi scale loss to feat/kcrps multi scale loss Jul 17, 2025
@ssmmnn11 ssmmnn11 added the documentation Improvements or additions to documentation label Jul 17, 2025
Member

@JPXKQX JPXKQX left a comment


Thanks for the PR, Simon! I have two general questions,

  1. Do you think this should be specific to the ensemble model? I am not sure if a multi-scale MSE would produce good results, but it could still be interesting to see it used as a validation metric.
  2. What are your thoughts on where this functionality should be located (as part of the forecaster or as a new loss function)? Are there any performance issues we should be aware of?

@ssmmnn11
Member Author

Thanks for the PR, Simon! I have two general questions,

  1. Do you think this should be specific to the ensemble model? I am not sure if a multi-scale MSE would produce good results, but it could still be interesting to see it used as a validation metric.

Probably does not make much sense for the single model. But could be tested, and having it as an option would not hurt.

  2. What are your thoughts on where this functionality should be located (as part of the forecaster or as a new loss function)? Are there any performance issues we should be aware of?

It has a performance impact if it is used - this is unavoidable, so we should be able to switch it off. It could be a new loss function, but the way it is implemented now means it can be used with any loss function - therefore I see it more as a wrapper around a loss function.
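As an illustration of the wrapper idea, the hedged MultiscaleLossSketch from the sketch above works unchanged with a plain MSE as the internal loss (again purely illustrative, not an anemoi API):

```python
import torch

# Any loss with a (pred, target) signature can be wrapped, which is what
# makes this a wrapper rather than a dedicated loss function.
loss_fn = MultiscaleLossSketch(
    internal_loss=torch.nn.MSELoss(),
    truncation_matrices=[None, torch.full((4, 8), 1.0 / 8)],  # native + one coarse scale
    weights=[1.0, 1.0],
)
pred = torch.randn(2, 8, 3)    # (batch, grid points, variables)
target = torch.randn(2, 8, 3)
print(loss_fn(pred, target))
```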

@mchantry mchantry added the ATS Approval Needed Approval needed by ATS label Sep 9, 2025
@theissenhelen theissenhelen changed the title from feat/kcrps multi scale loss to feat: kcrps multi scale loss Nov 4, 2025
@mchantry mchantry changed the title from feat: kcrps multi scale loss to feat: multi-scale loss implementations (including multi-scale kcrps) Nov 12, 2025
ssmmnn11 and others added 3 commits November 17, 2025 11:41
Additional fixes to make multiple scales and metrics work

---------

Co-authored-by: theissenhelen <[email protected]>
@anaprietonem anaprietonem added the ATS Approved Approved by ATS label Nov 19, 2025
@github-actions github-actions bot added the bug Something isn't working label Nov 19, 2025
@OpheliaMiralles
Contributor

What is a scale in this context? Filtering on grid points for specific areas and re-weighting the total loss? There is a FilteringLossWrapper that was used for filtering out some variables and could be used the same way for grid points, then wrapped in a CombinedLoss to associate a weight with each spatial area. But maybe I didn't get it exactly and it is doing more than that.

@ssmmnn11
Member Author

What is a scale in this context? Filtering on grid points for specific areas and re-weighting the total loss? There is a FilteringLossWrapper that was used for filtering out some variables and could be used the same way for grid points, then wrapped in a CombinedLoss to associate a weight with each spatial area. But maybe I didn't get it exactly and it is doing more than that.

This is for scale-aware training; we explain it here: https://arxiv.org/abs/2506.10868

@OpheliaMiralles
Contributor

What is a scale in this context? Filtering on grid points for specific areas and re-weighting the total loss? There is a FilteringLossWrapper that was used for filtering out some variables and could be used the same way for grid points, then wrapped in a CombinedLoss to associate a weight with each spatial area. But maybe I didn't get it exactly and it is doing more than that.

This is for scale-aware training; we explain it here: https://arxiv.org/abs/2506.10868

I saw the paper, I just wonder how this is different from constraining the loss on the spatial dimension to a subset of grid points, doing that for a set of truncation matrices, and aggregating with weights. If that is the case, it could maybe be done with a CombinedLoss + FilteringLossWrapper, because that is how I worked with it on the variable dimension, and it would not change much to have truncation matrices instead of specific variable names. Did anyone try to implement it using existing code?
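For illustration, the difference between the two approaches can be sketched in a few lines of plain PyTorch (illustrative only, not anemoi code): filtering selects a subset of grid points and leaves their values untouched, while a truncation matrix builds every coarse value as a weighted combination of fine-grid values, i.e. a smoothing projection rather than a selection.

```python
import torch

x = torch.randn(8, 3)  # (grid points, variables)

# Filtering: select a subset of grid points; the values are untouched.
mask = torch.tensor([0, 2, 4, 6])
filtered = x[mask]                        # shape (4, 3)

# Truncation: each coarse point is a weighted average over fine points,
# i.e. a low-pass projection rather than a selection.
trunc = torch.full((4, 8), 1.0 / 8)
truncated = trunc @ x                     # shape (4, 3), smoothed values
```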

Collaborator

@jakob-schloer jakob-schloer left a comment


Great work! I've left a couple of minor comments. I also think that some more tests and a small section in the documentation would be great to have for this.

- loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
+ loss = torch.zeros(1, dtype=batch[0].dtype, device=self.device, requires_grad=False)

  if self.loss.name == "MultiscaleLossWrapper":

Do we have this if-block because we later want to track the losses for the different scales? In my opinion, we should avoid adding complexity to the Trainer class for the callbacks. Could we not move this logic to the RolloutEval class in diagnostics.callbacks.evaluation?


JPXKQX commented Nov 20, 2025

PR #670 has now been merged into main. This includes the SparseProjector class, which could be useful here. Please let me know if you would like any help with this.
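For reference, applying a sparse truncation/projection matrix in plain PyTorch looks roughly like this (a generic sketch, not the SparseProjector API from PR #670):

```python
import torch

# Sparse (coarse x fine) projection matrix: each coarse point averages two fine points.
indices = torch.tensor([[0, 0, 1, 1],     # coarse-grid row indices
                        [0, 1, 2, 3]])    # fine-grid column indices
values = torch.tensor([0.5, 0.5, 0.5, 0.5])
trunc = torch.sparse_coo_tensor(indices, values, size=(2, 4)).coalesce()

x = torch.randn(4, 3)                  # (fine grid points, variables)
x_coarse = torch.sparse.mm(trunc, x)   # (2, 3) coarse-scale fields
```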

@mchantry mchantry removed the ATS Approval Needed Approval needed by ATS label Nov 27, 2025