Closed
Changes from 177 commits
187 commits
714506b
data-handlers wip
floriankrb Apr 10, 2025
310a5a0
data-handlers wip
floriankrb Apr 10, 2025
4d36e16
add test
JPXKQX Apr 11, 2025
17a11eb
Merge branch 'main' into refactor/multiple-datasets
JPXKQX Apr 23, 2025
9624ba2
WIP EncProcDec dictionairy batch input
havardhhaugen Apr 23, 2025
cdde2cb
group tests into tests/refactor
JPXKQX Apr 24, 2025
015b417
First data handler version
JPXKQX Apr 25, 2025
42213b6
Merge branch 'main' into refactor/multiple-datasets
JPXKQX Apr 25, 2025
023fb21
WIP loss function / forecaster multiple datasets
havardhhaugen Apr 25, 2025
9a5c492
Fix training_step only wants batch as input
havardhhaugen Apr 25, 2025
9385e7f
clean dataloader
JPXKQX Apr 25, 2025
c4b708d
style
JPXKQX Apr 25, 2025
3bf9966
Merge branch 'refactor/multiple-datasets' of https://github.com/ecmwf…
JPXKQX Apr 25, 2025
404bb53
refactor
JPXKQX Apr 25, 2025
333a9a8
graph update
JPXKQX Apr 25, 2025
5d2f83f
wip
JPXKQX May 1, 2025
ce3cf12
update config with sample_providers
JPXKQX May 1, 2025
6bd02be
wip II
JPXKQX May 2, 2025
9e5a01e
data loading works ...
JPXKQX May 2, 2025
7198d93
Add Sampler to synchronise reference date
JPXKQX May 7, 2025
ea09afb
clean train/val/test stages
JPXKQX May 7, 2025
fb10ba6
Merge branch 'main' into refactor/multiple-datasets
JPXKQX May 7, 2025
2aee9bb
keep refactoring with Florian
JPXKQX May 8, 2025
d5ed404
works: train --config-name=debug_multiple_datasets
JPXKQX May 9, 2025
0df0ac0
missing changes
JPXKQX May 9, 2025
6176e1a
pre-commit
JPXKQX May 9, 2025
4c8a7b3
models: pre-commit
JPXKQX May 9, 2025
cb24aa4
add draft classes for Tensor/arrays
JPXKQX May 14, 2025
396bc70
things, stackedThings, and groupedthings
JPXKQX May 15, 2025
54bc4a3
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX May 16, 2025
a0fe6d5
back to dict[tensors]
JPXKQX May 27, 2025
1d3bf4e
reduce num batches in debug config
VeraChristina May 27, 2025
a3a3eae
basic dict loss,works: train --config-name=debug_multiple_datasets
VeraChristina May 27, 2025
ada3c3b
update normalizer to use datahandler information
VeraChristina May 29, 2025
1d74b87
remove additional scalars from DictLoss
VeraChristina May 29, 2025
e3f0b0c
Merge branch 'main' into refactor/multiple-datasets
JPXKQX May 29, 2025
1e32df0
Merge branch 'refactor/multiple-datasets-dictloss' into refactor/mult…
VeraChristina May 29, 2025
80b75a5
works: anemoi-training train --config-name=debug_multiple_datasets
VeraChristina May 29, 2025
ecc02a4
remove non-selected variables from processors in SelectedDataHandler
VeraChristina May 30, 2025
a70d4b4
statistics property for RecordProvider
JPXKQX Jun 2, 2025
7c3bcc3
dict preprocessors
floriankrb Jun 3, 2025
6500a92
fix
floriankrb Jun 3, 2025
0549324
works:
floriankrb Jun 3, 2025
418600a
Merge branch 'main' into refactor/multiple-datasets
JPXKQX Jun 3, 2025
c3f45e3
provide base configs
JPXKQX Jun 3, 2025
8b3ed1e
fix: data handlers config
JPXKQX Jun 4, 2025
37cc932
minor fixes
JPXKQX Jun 4, 2025
a13e20f
update
JPXKQX Jun 4, 2025
a8c9765
temporary
JPXKQX Jun 5, 2025
bffd874
expand config
JPXKQX Jun 5, 2025
38d492e
using one type for DataHandlers. does not run.
floriankrb Jun 10, 2025
14834c0
renaming groups
JPXKQX Jun 12, 2025
cb566e0
add grouped and nongrouped data handlers
JPXKQX Jun 12, 2025
4cc2eb0
update data handlers
JPXKQX Jun 18, 2025
785ee38
small refactor
JPXKQX Jun 18, 2025
2dd4abc
fix config typo
JPXKQX Jun 18, 2025
ffd278d
style
JPXKQX Jun 18, 2025
cc786ab
move to timedelta
JPXKQX Jun 18, 2025
a46af98
minor fix
JPXKQX Jun 18, 2025
d551e1c
use __getitem__ notation
JPXKQX Jun 19, 2025
1991aa2
blank spaces
JPXKQX Jun 19, 2025
98a03ff
datetime as _getitem_ arg for sample and record providers
JPXKQX Jun 19, 2025
8c9dc91
index as int instead of np.int
JPXKQX Jun 20, 2025
cd67840
implentation idea for sample provider. tests do not pass
floriankrb Jun 24, 2025
e15eb86
up
floriankrb Jun 24, 2025
d964b61
draft
floriankrb Jun 30, 2025
c7cb41b
user-friendly config
JPXKQX Jul 2, 2025
8ee288b
mapping from user-friendly yaml to config dict
JPXKQX Jul 2, 2025
bd59aef
dop draft
floriankrb Jul 2, 2025
02625b8
hard-coded fix to avoid using data at [-1]
floriankrb Jul 2, 2025
829560e
up
floriankrb Jul 2, 2025
3d9cac6
up
floriankrb Jul 2, 2025
052e1e1
dop now importing from anemoi-training.
floriankrb Jul 2, 2025
b096124
dop script runnning
floriankrb Jul 2, 2025
63f40e7
dop script runnning
floriankrb Jul 2, 2025
168b4f0
Merge branch 'refactor/multiple-datasets-2' into refactor/multiple-da…
floriankrb Jul 2, 2025
6811306
simplify
floriankrb Jul 2, 2025
6b133d1
simplify
floriankrb Jul 2, 2025
17e6871
up
floriankrb Jul 2, 2025
d0f4497
add shuffle
floriankrb Jul 2, 2025
62622e3
renamed to SampleProvider
floriankrb Jul 2, 2025
9382d67
refactoring to new structure
JPXKQX Jul 2, 2025
3237b5f
how about we add in latitudes, longitudes and timedeltas as a dict
floriankrb Jul 2, 2025
988aee4
update the config
floriankrb Jul 2, 2025
0db7148
give seconds (int64) because pytorch does not like timedeltas
floriankrb Jul 2, 2025
65930b4
add "set_group" to data_handlers config
JPXKQX Jul 3, 2025
d9f20d4
test_data_loading works
JPXKQX Jul 3, 2025
ce85715
test_data_loading works (with era5 variable selection)
JPXKQX Jul 3, 2025
e20d443
update with sample_provider.latitude(i)
floriankrb Jul 3, 2025
9be4b06
implement sample_provider.latitude(i)
floriankrb Jul 3, 2025
fb23966
test_data_loading works (it returns dict[dict[list[arrays]]])
JPXKQX Jul 3, 2025
52291dd
introducing processors
JPXKQX Jul 3, 2025
e3d3eb3
renamed "groups" into "dictionary"
floriankrb Jul 3, 2025
90e75ab
use utils to convert to frequency
floriankrb Jul 4, 2025
feae96a
remove non-sense every-thing-as-a-dict
floriankrb Jul 4, 2025
753ead9
stack tensors with the "tensor" keyword
floriankrb Jul 4, 2025
3f79330
Breaking change in the config. "tensor" keyword. changing config
floriankrb Jul 4, 2025
756f107
keywork STEPS disappears. adding "tuple"
floriankrb Jul 4, 2025
e2d4f86
update dop
floriankrb Jul 4, 2025
1b3ff34
processors working
JPXKQX Jul 4, 2025
7f29dfd
Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm…
JPXKQX Jul 4, 2025
77a4e40
add shapes
JPXKQX Jul 4, 2025
abb1c28
added length of a sample provider
floriankrb Jul 4, 2025
c463391
added "timdeltas" shortcut
floriankrb Jul 7, 2025
fd01710
clean up
floriankrb Jul 7, 2025
c5913e6
clean up
floriankrb Jul 7, 2025
47169ed
moving processors
floriankrb Jul 7, 2025
4f9c9a4
rename sample_factory to sample_provider_factory
floriankrb Jul 7, 2025
46064df
up
floriankrb Jul 7, 2025
e4c1f2b
shuffling samples
floriankrb Jul 7, 2025
033f2c8
fix: processors
JPXKQX Jul 7, 2025
f696a90
fix
floriankrb Jul 7, 2025
6440da5
clean
floriankrb Jul 7, 2025
4accfa3
type hints
JPXKQX Jul 7, 2025
7ab71a3
Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm…
JPXKQX Jul 7, 2025
8ec2720
include num_channels
JPXKQX Jul 7, 2025
45eb8b7
include downscaling config
JPXKQX Jul 7, 2025
d29d675
update
JPXKQX Jul 8, 2025
1a43e86
update num_channels
JPXKQX Jul 8, 2025
8e802e2
training
JPXKQX Jul 8, 2025
0c31f2d
more logs
floriankrb Jul 9, 2025
f2f4135
remove old code
floriankrb Jul 9, 2025
e6ae018
training sources of data
floriankrb Jul 9, 2025
9cea226
black
floriankrb Jul 9, 2025
ad85d24
update config
floriankrb Jul 9, 2025
360e684
renamed data_config into sources
floriankrb Jul 9, 2025
a7e7277
fix: configs & test
JPXKQX Jul 14, 2025
2ef2c5c
new configs: downscaling
JPXKQX Jul 14, 2025
0ec6b3d
new configs: downscaling
JPXKQX Jul 14, 2025
9af5df2
fixing workflow
JPXKQX Jul 14, 2025
41e0165
refactor with breaking changes in the config.
floriankrb Jul 11, 2025
7cd8029
feat: updating downscaling workflow
JPXKQX Jul 15, 2025
1fab855
freq
JPXKQX Jul 15, 2025
784027a
update configs
JPXKQX Jul 15, 2025
b883778
include graph
JPXKQX Jul 15, 2025
62ba9f0
clean
floriankrb Jul 15, 2025
1bae774
top level request gets priority
floriankrb Jul 15, 2025
f08f417
added configs and configs.* in request
floriankrb Jul 15, 2025
543cac7
qa
floriankrb Jul 15, 2025
eccfcbd
update models
JPXKQX Jul 15, 2025
16b4693
Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm…
JPXKQX Jul 15, 2025
0677d06
more logs
floriankrb Jul 15, 2025
aa5769f
more updates
JPXKQX Jul 15, 2025
c03493e
allow sub group in variables list
floriankrb Jul 15, 2025
805876a
minor
JPXKQX Jul 15, 2025
e9c869b
pre-commit
JPXKQX Jul 15, 2025
256fd7e
move name_to_index as an attribute of the sample provider. wip
floriankrb Jul 15, 2025
98f8f17
re-add name_to_index in request
floriankrb Jul 15, 2025
e6f338e
fix lenght
floriankrb Jul 15, 2025
d136f94
added statistics
floriankrb Jul 15, 2025
0288d6e
refactor and clean
floriankrb Jul 15, 2025
37a160c
up
floriankrb Jul 15, 2025
c397965
fix
floriankrb Jul 16, 2025
58ab546
bringing dynamic mappers
JPXKQX Jul 16, 2025
3063942
the sample provider provides now always a dict
floriankrb Jul 16, 2025
bfd7144
apply function
floriankrb Jul 16, 2025
6c26add
remove request from datamodule
JPXKQX Jul 17, 2025
ef6a57f
cleanup
floriankrb Jul 18, 2025
d8f73b0
added get_obj
floriankrb Jul 18, 2025
d956105
added get_native
floriankrb Jul 18, 2025
ac280c8
use structure
JPXKQX Jul 18, 2025
874c265
Merge branch 'refactor/multiple-datasets-3' of https://github.com/ecm…
JPXKQX Jul 18, 2025
83bbbdb
WIP
mishooax Jul 18, 2025
b57012f
Merge branch 'refactor/multiple-datasets-3' of github.com:ecmwf/anemo…
mishooax Jul 18, 2025
342cb37
add function structure
floriankrb Jul 18, 2025
0870c5e
clean
floriankrb Jul 18, 2025
77a7c64
inputer --> imputer
JPXKQX Jul 21, 2025
de43262
fix: factory
JPXKQX Jul 21, 2025
ef96dfa
works: downscaling
JPXKQX Jul 22, 2025
1356eb8
fix: dims in loss
JPXKQX Jul 22, 2025
db50a2c
add apply function as attriibute to Structure
JPXKQX Jul 24, 2025
53ebd3c
suppor regional models
JPXKQX Jul 29, 2025
9ddb758
Merge branch 'main' into refactor/multiple-datasets-4
JPXKQX Jul 29, 2025
4379086
feat: working version (downscaling & multiple datasets)
JPXKQX Jul 29, 2025
5043048
pre-commit
JPXKQX Jul 29, 2025
8938c61
configs
JPXKQX Jul 29, 2025
411b17b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2025
192c664
remove unused
JPXKQX Jul 29, 2025
7254f9b
style
JPXKQX Jul 29, 2025
c1d7237
Merge branch 'main' into refactor/multiple-datasets-4
JPXKQX Jul 29, 2025
e1e7a26
feat: autoencoder configs
JPXKQX Jul 29, 2025
328f87a
Merge branch 'main' into refactor/multiple-datasets-4
JPXKQX Aug 4, 2025
e343ba3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2025
4e114f7
pre-commit
JPXKQX Aug 4, 2025
eebaf96
Update training/src/anemoi/training/data/refactor/datamodule.py
JPXKQX Aug 4, 2025
6d4953b
Merge branch 'refactor/multiple-datasets-4' of https://github.com/ecm…
JPXKQX Aug 4, 2025
60dffed
imports
JPXKQX Aug 4, 2025
220 changes: 220 additions & 0 deletions dop_dataset.py
@@ -0,0 +1,220 @@
import json
import os
import random
from typing import Optional

import numpy as np
import torch
import yaml
from torch.utils.data import IterableDataset
from torch.utils.data import get_worker_info

CONFIG_YAML = """
sources:
training:
# era5:
# dataset:
# dataset: aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8
# set_group: era5
# # preprocessors:
# # tp:
# # - normalizer: mean-std
snow:
dataset: observations-testing-2018-2018-6h-v1-one-month
metop_a:
dataset: observations-testing-2018-2018-6h-v1-one-month
amsr2_h180:
dataset: observations-testing-2018-2018-6h-v1-one-month
validation:
todo:
sample:
dictionary:
input:
dictionary:
ascat_metop_a:
tuple:
- timedelta: "-6h"
variables:
metop_a: ["scatss_1", "scatss_2"]
snow:
tuple:
- timedelta: "0h"
variables:
snow: ["sdepth_0"]
amsr2:
tuple:
- timedelta: "-6h"
variables:
amsr2_h180: ["rawbt_1", "rawbt_2", "rawbt_3", "rawbt_4"]
"""

CONFIG = yaml.safe_load(CONFIG_YAML)


from anemoi.training.data.refactor.draft import sample_provider_factory


def show_yaml(structure):
return yaml.dump(structure, indent=2, sort_keys=False)


def show_json(structure):
return json.dumps(structure, indent=2, default=shorten_numpy)


def shorten_numpy(structure):
if isinstance(structure, np.ndarray):
return f"np.array({structure.shape})"
return structure


def get_base_seed():
"""Get a base seed for random number generation.
This is a placeholder function; replace with actual logic to get a base seed.
"""
return 42 # Example fixed seed, replace with actual logic as needed


class DOPDataset(IterableDataset):
def __init__(
self,
# config: dict,
shuffle: bool = True,
rollout: int = 1,
multistep: int = 1,
task: str = "training",
) -> None:

self.shuffle = shuffle
# self.config = config
self.rollout = rollout
self.multistep = multistep
self.task = task

# lazy init
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0

# additional state vars (lazy init)
self.n_samples_per_worker = 0
self.chunk_index_range: Optional[np.ndarray] = None
self.shuffle = shuffle
self.rng: Optional[np.random.Generator] = None
self.worker_id: int = -1

# "full" shuffling
self.data_indices: Optional[np.ndarray] = None

self.seed_comm_group_id = 0
self.seed_comm_num_groups = 1

training_context = {
"name": "training",
"sources": CONFIG["sources"]["training"],
"start": "2018-11-02",
"end": "2018-11-01",
}

self._sample_provider = sample_provider_factory(context=training_context, **CONFIG["sample"])
self._sample_provider = self._sample_provider.shuffle(seed=42)

# self.len = len(self._sample_provider)

def __get_sample(self, index: int):
"""Get a sample from the dataset."""
return self._sample_provider[index]

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.

This initialises after the worker process has been spawned.

Parameters
----------
n_workers : int
Number of workers
worker_id : int
Worker ID
"""
self.worker_id = worker_id

        length = len(self._sample_provider)
        # Divide this equally across shards (one shard per group!)
        shard_size = length // self.seed_comm_num_groups
        shard_start = self.seed_comm_group_id * shard_size
        shard_end = min((self.seed_comm_group_id + 1) * shard_size, length)

shard_len = shard_end - shard_start
self.n_samples_per_worker = shard_len // n_workers

low = shard_start + worker_id * self.n_samples_per_worker
high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end)
self.chunk_index_range = np.arange(low, high, dtype=np.uint32)

seed = get_base_seed() # all workers get the same seed (so they all get the same index shuffle)
torch.manual_seed(seed)
random.seed(seed)
self.rng = np.random.default_rng(seed=seed)
sanity_rnd = self.rng.random(1)
print("Sanity check random number:", sanity_rnd)

def __iter__(self):
# no shuffle, just iterate over the chunk indices
for idx in self.chunk_index_range:
print(
f"VALIDATION: Worker {self.worker_id} (pid {os.getpid()}) fetching sample index {idx} ...",
)
yield self.__get_sample(idx)


def worker_init_func(worker_id: int) -> None:
"""Configures each dataset worker process.

    Calls DOPDataset.per_worker_init() on each dataset object.

Parameters
----------
worker_id : int
Worker ID

Raises
------
RuntimeError
If worker_info is None
"""
worker_info = get_worker_info() # information specific to each worker process
if worker_info is None:
print("worker_info is None! Set num_workers > 0 in your dataloader!")
raise RuntimeError
dataset_obj = worker_info.dataset # the copy of the dataset held by this worker process.
dataset_obj.per_worker_init(
n_workers=worker_info.num_workers,
worker_id=worker_id,
)


if __name__ == "__main__":

ds = DOPDataset(
# CONFIG,
shuffle=False,
rollout=1,
multistep=1,
task="training",
)

loader_params = {
"batch_size": 1, # must be 1 for the time being
"batch_sampler": None,
"num_workers": 2,
"pin_memory": False,
"worker_init_fn": worker_init_func,
# "collate_fn": None, # collator_wrapper(return_original_metadata=cfg_.dataloader.return_dates),
}

dl = torch.utils.data.DataLoader(ds, **loader_params, sampler=None)

for batch_idx, batch in enumerate(dl):
print("%s", batch)
if batch_idx >= 1:
break
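For reference, a minimal standalone sketch of the index-sharding arithmetic used in per_worker_init above; the sample count and the group/worker numbers are illustrative assumptions, not values taken from this PR.

import numpy as np

# Illustrative values only: pretend we have 1000 samples, 2 comm groups
# and 3 dataloader workers per group.
n_samples = 1000
num_groups, group_id = 2, 1   # this process belongs to comm group 1
n_workers, worker_id = 3, 0   # first dataloader worker in that group

shard_size = n_samples // num_groups                                  # 500 samples per group
shard_start = group_id * shard_size                                   # 500
shard_end = min((group_id + 1) * shard_size, n_samples)               # 1000

n_per_worker = (shard_end - shard_start) // n_workers                 # 166
low = shard_start + worker_id * n_per_worker                          # 500
high = min(shard_start + (worker_id + 1) * n_per_worker, shard_end)   # 666
chunk = np.arange(low, high, dtype=np.uint32)
print(chunk[0], chunk[-1], chunk.size)                                # 500 665 166
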
45 changes: 24 additions & 21 deletions models/src/anemoi/models/interface/__init__.py
@@ -19,10 +19,20 @@
from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.shapes import apply_shard_shapes
from anemoi.models.distributed.shapes import get_shard_shapes
from anemoi.models.models.mult_encoder_processor_decoder import AnemoiMultiModel
from anemoi.models.preprocessing import Processors
from anemoi.utils.config import DotDict


def processor_factory(name_to_index, statistics, processors, **kwargs) -> list[list]:
from anemoi.models.preprocessing.normalizer import InputNormalizer

return [
[name, instantiate(cfg, name_to_index=name_to_index["variables"], statistics=statistics["variables"])]
for name, cfg in processors.items()
]


class AnemoiModelInterface(torch.nn.Module):
"""An interface for Anemoi models.

@@ -59,46 +69,39 @@ def __init__(
self,
*,
config: DotDict,
sample_provider,
graph_data: HeteroData,
statistics: dict,
data_indices: dict,
# data_indices: dict,
metadata: dict,
supporting_arrays: dict = None,
truncation_data: dict,
) -> None:
super().__init__()
self.config = config
self.id = str(uuid.uuid4())
self.multi_step = self.config.training.multistep_input
self.sample_provider = sample_provider
self.graph_data = graph_data
self.statistics = statistics
self.truncation_data = truncation_data
self.metadata = metadata
self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {}
self.data_indices = data_indices
self.supporting_arrays = {}
self._build_model()

def _build_model(self) -> None:
"""Builds the model and pre- and post-processors."""
# Instantiate processors
processors = [
[name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)]
for name, processor in self.config.data.processors.items()
]
preprocessors = self.sample_provider.apply(processor_factory)

# Assign the processor list pre- and post-processors
self.pre_processors = Processors(processors)
self.post_processors = Processors(processors, inverse=True)
self.input_pre_processors = Processors(preprocessors["input"].processor_factory)
self.target_pre_processors = Processors(preprocessors["target"].processor_factory)
self.target_post_processors = Processors(preprocessors["target"].processor_factory, inverse=True)
        # TODO: Implement structure.processor_factory (not only at LeafStructure)

# Instantiate the model
self.model = instantiate(
self.config.model.model,
self.model = AnemoiMultiModel(
# self.config.model.model,
model_config=self.config,
data_indices=self.data_indices,
statistics=self.statistics,
sample_provider=self.sample_provider,
graph_data=self.graph_data,
truncation_data=self.truncation_data,
_recursive_=False, # Disables recursive instantiation by Hydra
# truncation_data=self.truncation_data,
# _recursive_=False, # Disables recursive instantiation by Hydra
)

# Use the forward method of the model directly
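
For context, a small sketch of the [name, processor] pairs that processor_factory above produces. ToyNormalizer and the toy config are hypothetical stand-ins for the real preprocessor classes, and Hydra's instantiate is replaced by a direct constructor call so the example stays self-contained.

class ToyNormalizer:
    """Hypothetical stand-in for a real preprocessor such as InputNormalizer."""

    def __init__(self, name_to_index, statistics):
        self.name_to_index = name_to_index
        self.statistics = statistics


def toy_processor_factory(name_to_index, statistics, processors, **kwargs):
    # Same shape as processor_factory above: one [name, processor] pair per config entry.
    return [
        [name, ToyNormalizer(name_to_index=name_to_index["variables"], statistics=statistics["variables"])]
        for name in processors
    ]


pairs = toy_processor_factory(
    name_to_index={"variables": {"2t": 0, "10u": 1}},
    statistics={"variables": {"mean": [280.0, 0.1], "stdev": [15.0, 5.0]}},
    processors={"normalizer": {}},
)
print([name for name, _ in pairs])  # ['normalizer']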