-
Notifications
You must be signed in to change notification settings - Fork 68
feat: multi-scale loss implementations (including multi-scale kcrps) #388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
JPXKQX
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR, Simon! I have two general questions,
- Do you think this should be specific to the ensemble model? I am not sure if a multi-scale MSE would produce good results, but it could still be interesting to see it used as a validation metric.
- What are your thoughts on where this functionality should be located (as part of the forecaster or as a new loss function)? are there any performance issues we should be aware of?
Probably does not make much sense for the single model. But could be tested, and having it as an option would not hurt.
It has a performance impact if it is used - this is unavoidable. So we should be able to switch it off. It could be a new loss function. the way it is implemented now means it can be used with any loss function though - therefore I see it more as a wrapper of a loss function, or something like that. |
Additional fixes to make multiple scales and metrics working --------- Co-authored-by: theissenhelen <[email protected]>
|
What is scale in this context? Filtering on grid points for specific areas and re-weight the total loss? Because there is a FilteringLossWrapper that was used for filtering out some variables and could be used the same for grid points, then wrapped in a CombinedLoss to associate weight to each spatial area? But maybe I didn't get exactly and it is doing more than that |
this is for scale aware training, we explain it here: https://arxiv.org/abs/2506.10868 |
I saw the paper, I just wonder how is this different than constraining the loss on the spatial dimension to a subset of gridpoints, do that for a set of truncation matrices and aggregate with weights. If it is the case, it could maybe be done with a CombinedLoss + FilteringLossWrapper, because that is how I worked with it on the variable dimension and it wouldn't change much to have truncation matrices instead of specific variable names. Did anyone try to implement it using existing code? |
jakob-schloer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! I've left a couple of minor comments. I also think that some more tests and a small section to the documentation would be great to have for that.
| loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) | ||
| loss = torch.zeros(1, dtype=batch[0].dtype, device=self.device, requires_grad=False) | ||
|
|
||
| if self.loss.name == "MultiscaleLossWrapper": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have this if-block because we later want to track the losses for the different scales? In my opinion, we should avoid adding complexity to the Trainer class for the callbacks. Could we not move this logic to the RolloutEval class in diagnostics.callbacks.evaluation?
|
PR #670 has now been merged into main. This includes the |
Implement the multiscale loss, as proposed in
https://arxiv.org/abs/2506.10868
The MultiScaleLossWrapper wraps around already implemented loss capabilities and applies them to user-defined scales. The user defines "truncation matrices" to decrease the resolution at the individual scales. The following is an example loss config:
Of note, the corresponding filenames always require a
Nonevalued added to the truncation matrices.The multi scale loss will now be treated as the default loss for the ensemble model.
The aggregated loss of the various scale will be tracked in the validation metrics. A future PR will focus on implementing loss tracking for individual variables.
📚 Documentation preview 📚: https://anemoi-training--388.org.readthedocs.build/en/388/
📚 Documentation preview 📚: https://anemoi-graphs--388.org.readthedocs.build/en/388/
📚 Documentation preview 📚: https://anemoi-models--388.org.readthedocs.build/en/388/