feat(models,training): multiple datasets (WIP) #441
Closed

Commits (187):
714506b  data-handlers wip (floriankrb)
310a5a0  data-handlers wip (JPXKQX)
4d36e16  add test (JPXKQX)
17a11eb  Merge branch 'main' into refactor/multiple-datasets (JPXKQX)
9624ba2  WIP EncProcDec dictionairy batch input (havardhhaugen)
cdde2cb  group tests into tests/refactor (JPXKQX)
015b417  First data handler version (JPXKQX)
42213b6  Merge branch 'main' into refactor/multiple-datasets (JPXKQX)
023fb21  WIP loss function / forecaster multiple datasets (havardhhaugen)
9a5c492  Fix training_step only wants batch as input (havardhhaugen)
9385e7f  clean dataloader (JPXKQX)
c4b708d  style (JPXKQX)
3bf9966  Merge branch 'refactor/multiple-datasets' of https://github.com/ecmwf… (JPXKQX)
404bb53  refactor (JPXKQX)
333a9a8  graph update (JPXKQX)
5d2f83f  wip (JPXKQX)
ce3cf12  update config with sample_providers (JPXKQX)
6bd02be  wip II (JPXKQX)
9e5a01e  data loading works ... (JPXKQX)
7198d93  Add Sampler to synchronise reference date (JPXKQX)
ea09afb  clean train/val/test stages (JPXKQX)
fb10ba6  Merge branch 'main' into refactor/multiple-datasets (JPXKQX)
2aee9bb  keep refactoring with Florian (JPXKQX)
d5ed404  works: train --config-name=debug_multiple_datasets (JPXKQX)
0df0ac0  missing changes (JPXKQX)
6176e1a  pre-commit (JPXKQX)
4c8a7b3  models: pre-commit (JPXKQX)
cb24aa4  add draft classes for Tensor/arrays (JPXKQX)
396bc70  things, stackedThings, and groupedthings (JPXKQX)
54bc4a3  Merge branch 'main' into feature/lat-weighted-attr (JPXKQX)
a0fe6d5  back to dict[tensors] (JPXKQX)
1d3bf4e  reduce num batches in debug config (VeraChristina)
a3a3eae  basic dict loss,works: train --config-name=debug_multiple_datasets (VeraChristina)
ada3c3b  update normalizer to use datahandler information (VeraChristina)
1d74b87  remove additional scalars from DictLoss (VeraChristina)
e3f0b0c  Merge branch 'main' into refactor/multiple-datasets (JPXKQX)
1e32df0  Merge branch 'refactor/multiple-datasets-dictloss' into refactor/mult… (VeraChristina)
80b75a5  works: anemoi-training train --config-name=debug_multiple_datasets (VeraChristina)
ecc02a4  remove non-selected variables from processors in SelectedDataHandler (VeraChristina)
a70d4b4  statistics property for RecordProvider (JPXKQX)
7c3bcc3  dict preprocessors (floriankrb)
6500a92  fix (floriankrb)
0549324  works: (floriankrb)
418600a  Merge branch 'main' into refactor/multiple-datasets (JPXKQX)
c3f45e3  provide base configs (JPXKQX)
8b3ed1e  fix: data handlers config (JPXKQX)
37cc932  minor fixes (JPXKQX)
a13e20f  update (JPXKQX)
a8c9765  temporary (JPXKQX)
bffd874  expand config (JPXKQX)
38d492e  using one type for DataHandlers. does not run. (floriankrb)
14834c0  renaming groups (JPXKQX)
cb566e0  add grouped and nongrouped data handlers (JPXKQX)
4cc2eb0  update data handlers (JPXKQX)
785ee38  small refactor (JPXKQX)
2dd4abc  fix config typo (JPXKQX)
ffd278d  style (JPXKQX)
cc786ab  move to timedelta (JPXKQX)
a46af98  minor fix (JPXKQX)
d551e1c  use __getitem__ notation (JPXKQX)
1991aa2  blank spaces (JPXKQX)
98a03ff  datetime as _getitem_ arg for sample and record providers (JPXKQX)
8c9dc91  index as int instead of np.int (JPXKQX)
cd67840  implentation idea for sample provider. tests do not pass (floriankrb)
e15eb86  up (floriankrb)
d964b61  draft (floriankrb)
c7cb41b  user-friendly config (JPXKQX)
8ee288b  mapping from user-friendly yaml to config dict (JPXKQX)
bd59aef  dop draft (floriankrb)
02625b8  hard-coded fix to avoid using data at [-1] (floriankrb)
829560e  up (floriankrb)
3d9cac6  up (floriankrb)
052e1e1  dop now importing from anemoi-training. (floriankrb)
b096124  dop script runnning (floriankrb)
63f40e7  dop script runnning (floriankrb)
168b4f0  Merge branch 'refactor/multiple-datasets-2' into refactor/multiple-da… (floriankrb)
6811306  simplify (floriankrb)
6b133d1  simplify (floriankrb)
17e6871  up (floriankrb)
d0f4497  add shuffle (floriankrb)
62622e3  renamed to SampleProvider (floriankrb)
9382d67  refactoring to new structure (JPXKQX)
3237b5f  how about we add in latitudes, longitudes and timedeltas as a dict (floriankrb)
988aee4  update the config (floriankrb)
0db7148  give seconds (int64) because pytorch does not like timedeltas (floriankrb)
65930b4  add "set_group" to data_handlers config (JPXKQX)
d9f20d4  test_data_loading works (JPXKQX)
ce85715  test_data_loading works (with era5 variable selection) (JPXKQX)
e20d443  update with sample_provider.latitude(i) (floriankrb)
9be4b06  implement sample_provider.latitude(i) (floriankrb)
fb23966  test_data_loading works (it returns dict[dict[list[arrays]]]) (JPXKQX)
52291dd  introducing processors (JPXKQX)
e3d3eb3  renamed "groups" into "dictionary" (floriankrb)
90e75ab  use utils to convert to frequency (floriankrb)
feae96a  remove non-sense every-thing-as-a-dict (floriankrb)
753ead9  stack tensors with the "tensor" keyword (floriankrb)
3f79330  Breaking change in the config. "tensor" keyword. changing config (floriankrb)
756f107  keywork STEPS disappears. adding "tuple" (floriankrb)
e2d4f86  update dop (floriankrb)
1b3ff34  processors working (JPXKQX)
7f29dfd  Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm… (JPXKQX)
77a4e40  add shapes (JPXKQX)
abb1c28  added length of a sample provider (floriankrb)
c463391  added "timdeltas" shortcut (floriankrb)
fd01710  clean up (floriankrb)
c5913e6  clean up (floriankrb)
47169ed  moving processors (floriankrb)
4f9c9a4  rename sample_factory to sample_provider_factory (floriankrb)
46064df  up (floriankrb)
e4c1f2b  shuffling samples (floriankrb)
033f2c8  fix: processors (JPXKQX)
f696a90  fix (floriankrb)
6440da5  clean (floriankrb)
4accfa3  type hints (JPXKQX)
7ab71a3  Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm… (JPXKQX)
8ec2720  include num_channels (JPXKQX)
45eb8b7  include downscaling config (JPXKQX)
d29d675  update (JPXKQX)
1a43e86  update num_channels (JPXKQX)
8e802e2  training (JPXKQX)
0c31f2d  more logs (floriankrb)
f2f4135  remove old code (floriankrb)
e6ae018  training sources of data (floriankrb)
9cea226  black (floriankrb)
ad85d24  update config (floriankrb)
360e684  renamed data_config into sources (floriankrb)
a7e7277  fix: configs & test (JPXKQX)
2ef2c5c  new configs: downscaling (JPXKQX)
0ec6b3d  new configs: downscaling (JPXKQX)
9af5df2  fixing workflow (JPXKQX)
41e0165  refactor with breaking changes in the config. (floriankrb)
7cd8029  feat: updating downscaling workflow (JPXKQX)
1fab855  freq (JPXKQX)
784027a  update configs (JPXKQX)
b883778  include graph (JPXKQX)
62ba9f0  clean (floriankrb)
1bae774  top level request gets priority (floriankrb)
f08f417  added configs and configs.* in request (floriankrb)
543cac7  qa (floriankrb)
eccfcbd  update models (JPXKQX)
16b4693  Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm… (JPXKQX)
0677d06  more logs (floriankrb)
aa5769f  more updates (JPXKQX)
c03493e  allow sub group in variables list (floriankrb)
805876a  minor (JPXKQX)
e9c869b  pre-commit (JPXKQX)
256fd7e  move name_to_index as an attribute of the sample provider. wip (floriankrb)
98f8f17  re-add name_to_index in request (floriankrb)
e6f338e  fix lenght (floriankrb)
d136f94  added statistics (floriankrb)
0288d6e  refactor and clean (floriankrb)
37a160c  up (floriankrb)
c397965  fix (floriankrb)
58ab546  bringing dynamic mappers (JPXKQX)
3063942  the sample provider provides now always a dict (floriankrb)
bfd7144  apply function (floriankrb)
6c26add  remove request from datamodule (JPXKQX)
ef6a57f  cleanup (floriankrb)
d8f73b0  added get_obj (floriankrb)
d956105  added get_native (floriankrb)
ac280c8  use structure (JPXKQX)
874c265  Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm… (JPXKQX)
83bbbdb  WIP (mishooax)
b57012f  Merge branch 'refactor/multiple-datasets-3' of github.com:ecmwf/anemo… (mishooax)
342cb37  add function structure (floriankrb)
0870c5e  clean (floriankrb)
77a7c64  inputer --> imputer (JPXKQX)
de43262  fix: factory (JPXKQX)
ef96dfa  works: downscaling (JPXKQX)
1356eb8  fix: dims in loss (JPXKQX)
db50a2c  add apply function as attriibute to Structure (JPXKQX)
53ebd3c  suppor regional models (JPXKQX)
9ddb758  Merge branch 'main' into refactor/multiple-datasets-4 (JPXKQX)
4379086  feat: working version (downscaling & multiple datasets) (JPXKQX)
5043048  pre-commit (JPXKQX)
8938c61  configs (JPXKQX)
411b17b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
192c664  remove unused (JPXKQX)
7254f9b  style (JPXKQX)
c1d7237  Merge branch 'main' into refactor/multiple-datasets-4 (JPXKQX)
e1e7a26  feat: autoencoder configs (JPXKQX)
328f87a  Merge branch 'main' into refactor/multiple-datasets-4 (JPXKQX)
e343ba3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
4e114f7  pre-commit (JPXKQX)
eebaf96  Update training/src/anemoi/training/data/refactor/datamodule.py (JPXKQX)
6d4953b  Merge branch 'refactor/multiple-datasets-4' of https://github.com/ecm… (JPXKQX)
60dffed  imports (JPXKQX)
New file (220 lines added):

```python
import json
import os
import random
from typing import Optional

import numpy as np
import torch
import yaml
from torch.utils.data import IterableDataset
from torch.utils.data import get_worker_info

from anemoi.training.data.refactor.draft import sample_provider_factory

# Draft configuration: three observation sources for training and a nested
# sample specification consumed by the sample provider factory.
CONFIG_YAML = """
sources:
  training:
    # era5:
    #   dataset:
    #     dataset: aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8
    #   set_group: era5
    #   # preprocessors:
    #   #   tp:
    #   #     - normalizer: mean-std
    snow:
      dataset: observations-testing-2018-2018-6h-v1-one-month
    metop_a:
      dataset: observations-testing-2018-2018-6h-v1-one-month
    amsr2_h180:
      dataset: observations-testing-2018-2018-6h-v1-one-month
  validation:
    todo:
sample:
  dictionary:
    input:
      dictionary:
        ascat_metop_a:
          tuple:
            - timedelta: "-6h"
              variables:
                metop_a: ["scatss_1", "scatss_2"]
        snow:
          tuple:
            - timedelta: "0h"
              variables:
                snow: ["sdepth_0"]
        amsr2:
          tuple:
            - timedelta: "-6h"
              variables:
                amsr2_h180: ["rawbt_1", "rawbt_2", "rawbt_3", "rawbt_4"]
"""

CONFIG = yaml.safe_load(CONFIG_YAML)


def show_yaml(structure):
    return yaml.dump(structure, indent=2, sort_keys=False)


def show_json(structure):
    return json.dumps(structure, indent=2, default=shorten_numpy)


def shorten_numpy(structure):
    if isinstance(structure, np.ndarray):
        return f"np.array({structure.shape})"
    return structure


def get_base_seed():
    """Get a base seed for random number generation.

    This is a placeholder function; replace with actual logic to get a base seed.
    """
    return 42  # Example fixed seed, replace with actual logic as needed


class DOPDataset(IterableDataset):
    def __init__(
        self,
        # config: dict,
        shuffle: bool = True,
        rollout: int = 1,
        multistep: int = 1,
        task: str = "training",
    ) -> None:
        self.shuffle = shuffle
        # self.config = config
        self.rollout = rollout
        self.multistep = multistep
        self.task = task

        # lazy init
        self.n_samples_per_epoch_total: int = 0
        self.n_samples_per_epoch_per_worker: int = 0

        # additional state vars (lazy init)
        self.n_samples_per_worker = 0
        self.chunk_index_range: Optional[np.ndarray] = None
        self.rng: Optional[np.random.Generator] = None
        self.worker_id: int = -1

        # "full" shuffling
        self.data_indices: Optional[np.ndarray] = None

        self.seed_comm_group_id = 0
        self.seed_comm_num_groups = 1

        training_context = {
            "name": "training",
            "sources": CONFIG["sources"]["training"],
            "start": "2018-11-02",
            "end": "2018-11-01",
        }

        self._sample_provider = sample_provider_factory(context=training_context, **CONFIG["sample"])
        self._sample_provider = self._sample_provider.shuffle(seed=42)

        # self.len = len(self._sample_provider)

    def __get_sample(self, index: int):
        """Get a sample from the dataset."""
        return self._sample_provider[index]

    def per_worker_init(self, n_workers: int, worker_id: int) -> None:
        """Called by worker_init_func on each copy of the dataset.

        This initialises after the worker process has been spawned.

        Parameters
        ----------
        n_workers : int
            Number of workers
        worker_id : int
            Worker ID
        """
        self.worker_id = worker_id

        length = len(self._sample_provider)
        # Divide this equally across shards (one shard per group!)
        shard_size = length // self.seed_comm_num_groups
        shard_start = self.seed_comm_group_id * shard_size
        shard_end = min((self.seed_comm_group_id + 1) * shard_size, length)

        shard_len = shard_end - shard_start
        self.n_samples_per_worker = shard_len // n_workers

        low = shard_start + worker_id * self.n_samples_per_worker
        high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end)
        self.chunk_index_range = np.arange(low, high, dtype=np.uint32)

        seed = get_base_seed()  # all workers get the same seed (so they all get the same index shuffle)
        torch.manual_seed(seed)
        random.seed(seed)
        self.rng = np.random.default_rng(seed=seed)
        sanity_rnd = self.rng.random(1)
        print("Sanity check random number:", sanity_rnd)

    def __iter__(self):
        # no shuffle, just iterate over the chunk indices
        for idx in self.chunk_index_range:
            print(
                f"{self.task.upper()}: Worker {self.worker_id} (pid {os.getpid()}) fetching sample index {idx} ...",
            )
            yield self.__get_sample(idx)


def worker_init_func(worker_id: int) -> None:
    """Configures each dataset worker process.

    Calls DOPDataset.per_worker_init() on each dataset object.

    Parameters
    ----------
    worker_id : int
        Worker ID

    Raises
    ------
    RuntimeError
        If worker_info is None
    """
    worker_info = get_worker_info()  # information specific to each worker process
    if worker_info is None:
        print("worker_info is None! Set num_workers > 0 in your dataloader!")
        raise RuntimeError
    dataset_obj = worker_info.dataset  # the copy of the dataset held by this worker process
    dataset_obj.per_worker_init(
        n_workers=worker_info.num_workers,
        worker_id=worker_id,
    )


if __name__ == "__main__":
    ds = DOPDataset(
        # CONFIG,
        shuffle=False,
        rollout=1,
        multistep=1,
        task="training",
    )

    loader_params = {
        "batch_size": 1,  # must be 1 for the time being
        "batch_sampler": None,
        "num_workers": 2,
        "pin_memory": False,
        "worker_init_fn": worker_init_func,
        # "collate_fn": None,  # collator_wrapper(return_original_metadata=cfg_.dataloader.return_dates),
    }

    dl = torch.utils.data.DataLoader(ds, **loader_params, sampler=None)

    for batch_idx, batch in enumerate(dl):
        print(batch)
        if batch_idx >= 1:
            break
```
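The sharding in `per_worker_init` above first splits the sample index space into one shard per communication group, then splits each shard across dataloader workers. Below is a minimal standalone sketch of that arithmetic, not part of the PR; the helper name `shard_indices` and the toy sizes (10 samples, one group, two workers) are ours, introduced only for illustration:

```python
import numpy as np


def shard_indices(length: int, num_groups: int, group_id: int, n_workers: int, worker_id: int) -> np.ndarray:
    """Mirror the index-range computation in DOPDataset.per_worker_init."""
    shard_size = length // num_groups  # samples per communication group
    shard_start = group_id * shard_size
    shard_end = min((group_id + 1) * shard_size, length)
    n_samples_per_worker = (shard_end - shard_start) // n_workers
    low = shard_start + worker_id * n_samples_per_worker
    high = min(shard_start + (worker_id + 1) * n_samples_per_worker, shard_end)
    return np.arange(low, high, dtype=np.uint32)


# 10 samples, 1 communication group, 2 workers:
print(shard_indices(10, 1, 0, 2, 0))  # -> [0 1 2 3 4]
print(shard_indices(10, 1, 0, 2, 1))  # -> [5 6 7 8 9]
```

Each worker ends up with a disjoint, contiguous index range; any remainder left by the two integer divisions is silently dropped.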