-
Notifications
You must be signed in to change notification settings - Fork 68
feat: multi-scale loss implementations (including multi-scale kcrps) #388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ssmmnn11
wants to merge
33
commits into
main
Choose a base branch
from
feat/kcrps-multi-scale-loss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
142fa66
multi loss implementation
ssmmnn11 4b104b8
config update
ssmmnn11 08eb3f5
fix logging
ssmmnn11 f51937d
example setup in debug_ens
ssmmnn11 05a9d42
Merge branch 'kcrps_mloss' into feat/kcrps-multi-scale-loss
ssmmnn11 8f1ca0c
fix for channel sharding
ssmmnn11 d61ea5f
Merge branch 'fix-channel-sharding' into feat/kcrps-multi-scale-loss
ssmmnn11 1497ddc
multi-scale loss improvements
ssmmnn11 fa04281
Merge remote-tracking branch 'origin/main' into feat/kcrps-multi-scal…
ssmmnn11 765b53e
pydantic and some documentation
ssmmnn11 9c33ce9
docu update
ssmmnn11 0417e77
fix
ssmmnn11 652507d
Merge branch 'main' into feat/kcrps-multi-scale-loss
ssmmnn11 7bf3e8f
merged main
ssmmnn11 ad42367
fix for single GPU training
ssmmnn11 0f64a73
fix
ssmmnn11 35f04fe
fix: skip sharding when running on single gpu
theissenhelen 27ff0c5
refactor: factor truncation operations out of model
theissenhelen 83e17bb
more refactoring
theissenhelen e2a31f6
WIP
theissenhelen 4e04123
instantiation of multiscale working
theissenhelen 47ccbcc
MultiscaleLoss working
theissenhelen 6e3725c
WIP
theissenhelen f1dcdfd
use kwargs for multiscale
theissenhelen e8aa6f5
Schema for multiscale loss
theissenhelen 5634add
add multiscale to configs
theissenhelen cf63ffa
Merge remote-tracking branch 'origin/main' into feat/kcrps-multi-scal…
theissenhelen 10b92d4
fix: mscale loss (#661)
ssmmnn11 e79cf27
remove unused entries
theissenhelen bd2c919
Merge remote-tracking branch 'origin/main' into feat/kcrps-multi-scal…
theissenhelen 8a907ab
fix mloss accum missing
theissenhelen 2c7cbc2
add truncation to integration tests
theissenhelen f44037a
adjust weights
theissenhelen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| # (C) Copyright 2024 Anemoi contributors. | ||
| # | ||
| # This software is licensed under the terms of the Apache Licence Version 2.0 | ||
| # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
| # | ||
| # In applying this licence, ECMWF does not waive the privileges and immunities | ||
| # granted to it by virtue of its status as an intergovernmental organisation | ||
| # nor does it submit to any jurisdiction. | ||
|
|
||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
|
|
||
| def make_truncation_matrix(A, data_type=torch.float32): | ||
| A_ = torch.sparse_coo_tensor( | ||
| torch.tensor(np.vstack(A.nonzero()), dtype=torch.long), | ||
| torch.tensor(A.data, dtype=data_type), | ||
| size=A.shape, | ||
| ).coalesce() | ||
| return A_ | ||
|
|
||
|
|
||
| def truncate_fields(x, A, batch_size=None, auto_cast=False): | ||
| if not batch_size: | ||
| batch_size = x.shape[0] | ||
| out = [] | ||
| with torch.amp.autocast(device_type="cuda", enabled=auto_cast): | ||
| for i in range(batch_size): | ||
| out.append(multiply_sparse(x[i, ...], A)) | ||
| return torch.stack(out) | ||
|
|
||
|
|
||
| def multiply_sparse(x, A): | ||
| if torch.cuda.is_available(): | ||
| with torch.amp.autocast(device_type="cuda", enabled=False): | ||
| out = torch.sparse.mm(A, x) | ||
| else: | ||
| with torch.amp.autocast(device_type="cpu", enabled=False): | ||
| out = torch.sparse.mm(A, x) | ||
| return out | ||
|
|
||
|
|
||
| def interpolate_batch(batch: torch.Tensor, intp_matrix: torch.Tensor) -> torch.Tensor: | ||
| input_shape = batch.shape # e.g. (batch steps ensemble grid vars) or (batch steps grid vars) | ||
| batch = batch.reshape(-1, *input_shape[-2:]) | ||
| batch = truncate_fields(batch, intp_matrix) # to coarse resolution | ||
| return batch.reshape(*input_shape) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.