fix(training): scaler memory usage #391
Conversation
HCookie left a comment
A great refactor, thanks
@japols, do you have some runs/plots from before and after the fix that you can post to confirm it's working as expected?
Is the bottom line that this was not caught in the benchmarking because the memory usage was very small at n320?
ssmmnn11 left a comment
good!
Description
Avoid explicitly materialising a single large combined scaling tensor in memory; instead, apply the scalers iteratively via broadcasted multiplication.
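A minimal sketch of the two patterns (illustrative only; the tensor names and shapes below are assumptions, not the actual `ScaleTensor` code):

```python
import torch

# Illustrative sketch only: tensor names and shapes are assumptions,
# not the actual ScaleTensor implementation in anemoi-training.
node_scaler = torch.rand(1, 2048, 1)      # one weight per grid point
variable_scaler = torch.rand(1, 1, 100)   # one weight per variable
loss = torch.rand(4, 2048, 100)           # (batch, grid points, variables)

# Old pattern: materialise one combined scaling tensor first. Its broadcast
# shape is (1, grid points, variables); at high resolution this is a large
# extra allocation held alongside the loss in the forward/backward pass.
combined = node_scaler * variable_scaler  # shape (1, 2048, 100)
scaled_old = loss * combined

# New pattern: multiply each small scaler onto the loss directly via
# broadcasting, so the combined tensor is never materialised.
scaled_new = loss
for scaler in (node_scaler, variable_scaler):
    scaled_new = scaled_new * scaler

assert torch.allclose(scaled_old, scaled_new)
```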
What problem does this change solve?
Combining all scalers in `ScaleTensor.get_scaler()` can lead to big allocations, up to 17 GB when training at 4 km resolution.
Additional notes
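A rough back-of-the-envelope for why the combined tensor gets this large (the grid-point and variable counts are illustrative assumptions, not figures taken from the PR):

```python
# Back-of-the-envelope estimate; the counts below are assumptions.
grid_points = 40_000_000   # order of magnitude for a ~4 km global grid
variables = 100            # rough number of scaled variables
bytes_per_float32 = 4

combined_bytes = grid_points * variables * bytes_per_float32
print(f"~{combined_bytes / 1e9:.0f} GB")  # ~16 GB, the same order as the reported 17 GB
```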
Memory snapshot of the loss (fwd/bwd) when training at 9 km, before (left) and after (right) the changes:
Small test run of old version vs new: mlflow
Tagging @HCookie @sahahner
As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/
By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.