-
Notifications
You must be signed in to change notification settings - Fork 67
feat: introduce spectral losses module #678
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
I wouldn't say it starts the work on spectral loss since there are some already implemented in the anemoi-core losses... |
Starts the work for a generic module for spectral losses, which has a broader goal of defining abstractions. But yes absolutely, it also draws on what's already implemented in the |
I assigned myself as a reviewer so I'm willing to contribute and give feedback :) |
7a90c1e to
4154104
Compare
| return (pred - target) ** 2 | ||
|
|
||
|
|
||
| class FourierCorrelationLoss(SpectralLoss): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also add the log version of the FFT2 norm
Following #599 and after discussing with @theissenhelen I am opening this PR to start the work on a generic spectral losses module. The focus of this PR is to establish the design of the module and aims to define the right abstractions and tests. Only very basic spectral losses will be actually implemented here for testing purposes. The work will also build on the existing spectral losses implemented in spatial.py. Feedback is of course more than welcome!
Design proposal
A main concern when designing spectral losses is the separation of the logic of the transformation to spectral domain from the logic for computing the error metric itself. Therefore, we introduce two respective abstract classes,
SpectralTransformandSpectralLossfor this goal. The former is then assigned as thetransformattribute of the latter.SpectralTransform: This class requires all subclasses to implement the__call__()method. As an example, I already implementedFFT2Dby inheriting from theSpectralTransformclass.SHTis added as a placeholder. Note that this is temporarily implemented inside the spectral losses module, but might be moved at some point intoanemoi-models. An important assumption of this class is that the input and output tensors will always be assumed to be of shape [batch, time, [ensemble], points, variable], so any reshaping required for the transform happens inside and is undone before returning the value. It's still unclear to me whether this generalizes well (see second point in open questions).SpectralLoss: All subclasses that directly inherit from this class will have to implement theforward()method, which will contain the entire logic for the loss computation, including scaling and reduction. For now, only one loss has been implemented (ported from thespatial.pymodule).FunctionalSpectralLoss: Following an existing pattern in thelossespackage, we also introduce aFunctionalSpectralLossclass that inherits from bothFunctionalLossandSpectralLossand allows for substantial code deduplication. An example implementation is implemented asSpectralL2Loss.Open questions