262 changes: 262 additions & 0 deletions models/src/anemoi/models/schemas/data_processor.py
@@ -0,0 +1,262 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from collections.abc import Iterable
from enum import Enum
from typing import Any
from typing import Union

from pydantic import Field
from pydantic import RootModel
from pydantic import TypeAdapter
from pydantic import field_validator
from pydantic import model_validator

from anemoi.utils.schemas import BaseModel


class NormalizerSchema(BaseModel):
default: Union[str, None] = Field(literals=["mean-std", "std", "min-max", "max", "none"])
"""Normalizer default method to apply"""
remap: Union[dict[str, str], None] = Field(default_factory=dict)
"""Dictionary for remapping variables"""
std: Union[list[str], None] = Field(default_factory=list)
"""Variables to normalise with std"""
mean_std: Union[list[str], None] = Field(default_factory=list, alias="mean-std")
"""Variables to mormalize with mean-std"""
min_max: Union[list[str], None] = Field(default_factory=list, alias="min-max")
"""Variables to normalize with min-max."""
max: Union[list[str], None] = Field(default_factory=list)
"""Variables to normalize with max."""
none: Union[list[str], None] = Field(default_factory=list)
"""Variables not to be normalized."""


class ImputerSchema(BaseModel):
default: str = Field(literals=["none", "mean", "stdev"])
"Imputer default method to apply."
maximum: Union[list[str], None]
minimum: Union[list[str], None]
none: Union[list[str], None] = Field(default_factory=list)
"Variables not to be imputed."


class ConstantImputerSchema(RootModel[dict[Any, Any]]):
"""Schema for ConstantImputer.

Expects the config to have numeric keys that are the constant fill values,
each mapping to a list of variables to impute with that constant, plus
optional "default" and "none" entries:
```
default: "none"
1:
- y
5.0:
- x
3.14:
- q
none:
- z
- other
```
"""

@field_validator("root")
@classmethod
def validate_entries(cls, values: dict[Union[int, float, str], Union[str, list[str]]]) -> dict[Any, Any]:

for k, v in values.items():
if k == "default":
if not isinstance(v, (int, float)):
if v is None or v == "none" or v == "None":
continue
msg = f'"default" must map to a float or None, got {type(v).__name__}'
raise TypeError(msg)
elif k == "none":
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
msg = f'"none" must map to a list of strings, got {v}'
raise TypeError(msg)

# Accept numeric keys as int or float
elif isinstance(k, (int, float)):
if not isinstance(v, Iterable) or isinstance(v, (str, bytes)):
msg = f'Key "{k}" must map to a list of strings, got {v}'
raise TypeError(msg)
if not all(isinstance(i, str) for i in v):
msg = f'Key "{k}" must map to a list of strings, got {v}'
raise TypeError(msg)

# Reject all other keys
else:
msg = f'Key "{k}" must be either a number, "none" or "default", got {type(k).__name__}'
raise TypeError(msg)

return values
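

# Illustrative sketch (not part of the PR): the mixed-key mapping from the docstring
# above passes the validator; numeric keys carry the constant used for imputation.
def _constant_imputer_example() -> ConstantImputerSchema:
    return ConstantImputerSchema.model_validate(
        {"default": "none", 1: ["y"], 5.0: ["x"], 3.14: ["q"], "none": ["z", "other"]}
    )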


class PostprocessorSchema(BaseModel):
default: str = Field(literals=["none", "relu", "hardtanh", "hardtanh_0_1"])
"Postprocessor default method to apply."
relu: Union[list[str], None] = Field(default_factory=list)
"Variables to postprocess with relu."
hardtanh: Union[list[str], None] = Field(default_factory=list)
"Variables to postprocess with hardtanh."
hardtanh_0_1: Union[list[str], None] = Field(default_factory=list)
"Variables to postprocess with hardtanh in range [0, 1]."
none: Union[list[str], None] = Field(default_factory=list)
"Variables not to be postprocessed."


class NormalizedReluPostprocessorSchema(RootModel[dict[Any, Any]]):
"""Schema for the NormalizedReluPostProcessor.

Expects the config to have numeric keys that are the customizable thresholds,
each mapping to a list of variables to postprocess, plus a "normalizer" entry
naming the normalization applied to those thresholds:
```
normalizer: 'mean-std'
1:
- y
0:
- x
3.14:
- q
```
"""

@field_validator("root")
@classmethod
def validate_entries(cls, values: dict[Union[int, float, str], Union[str, list[str]]]) -> dict[Any, Any]:

for k, v in values.items():

if k == "normalizer":
if not isinstance(v, str):
msg = f'"normalizer" must map to a string, got {v}'
raise TypeError(msg)
if v not in ["none", "mean-std", "std", "min-max", "max"]:
msg = f'"normalizer" must be one of "none", "mean-std", "std", "min-max", "max", got {v}'
raise ValueError(msg)

# Accept numeric keys as int or float
elif isinstance(k, (int, float)):
if not isinstance(v, Iterable) or isinstance(v, (str, bytes)):
msg = f'Key "{k}" must map to a list of strings, got {v}'
raise TypeError(msg)
if not all(isinstance(i, str) for i in v):
msg = f'Key "{k}" must map to a list of strings, got {v}'
raise TypeError(msg)

# Reject all other keys
else:
msg = f'Key "{k}" must be either a number, "normalizer", got {type(k).__name__}'
raise TypeError(msg)

return values
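

# Illustrative sketch (not part of the PR): numeric keys are the thresholds and
# "normalizer" selects how they are normalized; any other key is rejected above.
def _normalized_relu_postprocessor_example() -> NormalizedReluPostprocessorSchema:
    return NormalizedReluPostprocessorSchema.model_validate(
        {"normalizer": "mean-std", 1: ["y"], 0: ["x"], 3.14: ["q"]}
    )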


class ConditionalZeroPostprocessorSchema(RootModel[dict[Any, Any]]):
"""Schema for ConditionalZeroPostProcessor.

Expects the config to have numeric keys that are the replacement values,
each mapping to a list of variables to postprocess, plus a "remap" entry
naming the conditioning variable and an optional "default" value:

```
default: "none"
remap: "x"
0:
- y
5.0:
- x
3.14:
- q
```

If "x" is zero, "y" will be postprocessed with 0, "x" with 5.0 and "q" with 3.14.
"""

@field_validator("root")
@classmethod
def validate_entries(cls, values: dict[Union[int, float, str], Union[str, list[str]]]) -> dict[Any, Any]:

for k, v in values.items():
if k == "default":
if not isinstance(v, (int, float)):
if v is None or v == "none" or v == "None":
continue
msg = f'"default" must map to a float or None, got {type(v).__name__}'
raise TypeError(msg)
elif k == "remap":
if not isinstance(v, str):
msg = f'"remap" must map to a strings, got {v}'
raise TypeError(msg)

# Accept numeric keys as int or float
elif isinstance(k, (int, float)):
if not isinstance(v, Iterable) or isinstance(v, (str, bytes)):
msg = f'Key "{k}" must map to a list of strings, got {v}'
raise TypeError(msg)
if not all(isinstance(i, str) for i in v):
msg = f'Key "{k}" must map to a list of strings, got {v}'
raise TypeError(msg)

# Reject all other keys
else:
msg = f'Key "{k}" must be either a number, "none" or "default", got {type(k).__name__}'
raise TypeError(msg)

return values
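

# Illustrative sketch (not part of the PR): "remap" names the conditioning variable,
# "default" the fallback value, and numeric keys map replacement values to variables.
def _conditional_zero_postprocessor_example() -> ConditionalZeroPostprocessorSchema:
    return ConditionalZeroPostprocessorSchema.model_validate(
        {"default": "none", "remap": "x", 0: ["y"], 5.0: ["x"], 3.14: ["q"]}
    )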


class RemapperSchema(BaseModel):
default: str = Field(literals=["none", "log1p", "sqrt", "boxcox"])
"Remapper default method to apply."
none: Union[list[str], None] = Field(default_factory=list)
"Variables not to be remapped."


class PreprocessorTarget(str, Enum):
normalizer = "anemoi.models.preprocessing.normalizer.InputNormalizer"
imputer = "anemoi.models.preprocessing.imputer.InputImputer"
const_imputer = "anemoi.models.preprocessing.imputer.ConstantImputer"
remapper = "anemoi.models.preprocessing.remapper.Remapper"
postprocessor = "anemoi.models.preprocessing.postprocessor.Postprocessor"
conditional_zero_postprocessor = "anemoi.models.preprocessing.postprocessor.ConditionalZeroPostprocessor"
normalized_relu_postprocessor = "anemoi.models.preprocessing.postprocessor.NormalizedReluPostprocessor"


target_to_schema = {
PreprocessorTarget.normalizer: NormalizerSchema,
PreprocessorTarget.imputer: ImputerSchema,
PreprocessorTarget.const_imputer: ConstantImputerSchema,
PreprocessorTarget.remapper: RemapperSchema,
PreprocessorTarget.postprocessor: PostprocessorSchema,
PreprocessorTarget.conditional_zero_postprocessor: ConditionalZeroPostprocessorSchema,
PreprocessorTarget.normalized_relu_postprocessor: NormalizedReluPostprocessorSchema,
}
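

# Illustrative sketch (not part of the PR): resolving the schema class for a target
# and validating a raw config against it, mirroring what PreprocessorSchema does below.
def _validate_for_target(target: PreprocessorTarget, config: dict) -> Any:
    schema_cls = target_to_schema[target]
    return TypeAdapter(schema_cls).validate_python(config)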


class PreprocessorSchema(BaseModel, validate_assignment=False):
target_: PreprocessorTarget = Field(..., alias="_target_")
"Processor object from anemoi.models.preprocessing.[normalizer|imputer|remapper]."
config: Union[dict, NormalizerSchema, ImputerSchema, PostprocessorSchema, RemapperSchema]
"Target schema containing processor methods."

@model_validator(mode="after")
def schema_consistent_with_target(self) -> "PreprocessorSchema":
schema_cls = target_to_schema.get(self.target_)
if schema_cls is None:
error_msg = f"Unknown target: {self.target_}"
raise ValueError(error_msg)

validated = TypeAdapter(schema_cls).validate_python(self.config)
# If it's a RootModel (like ConstantImputerSchema), extract the root dict
if hasattr(validated, "root"):
self.config = validated.root

return self
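

# Illustrative sketch (not part of the PR): constructing a PreprocessorSchema through
# the "_target_" alias; the model validator above checks the config against the schema
# registered for that target (here the normalizer schema).
def _preprocessor_example() -> PreprocessorSchema:
    return PreprocessorSchema(
        _target_="anemoi.models.preprocessing.normalizer.InputNormalizer",
        config={"default": "mean-std", "none": ["z"]},
    )
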
6 changes: 3 additions & 3 deletions models/src/anemoi/models/schemas/models.py
@@ -124,8 +124,8 @@ def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedR
return self


class LeakyNormalizedReluBoundingSchema(NormalizedReluBoundingSchema):
target_: Literal["anemoi.models.layers.bounding.LeakyNormalizedReluBounding"] = Field(..., alias="_target_")
class NormalizedLeakyReluBoundingSchema(NormalizedReluBoundingSchema):
target_: Literal["anemoi.models.layers.bounding.NormalizedLeakyReluBounding"] = Field(..., alias="_target_")
"Leaky normalized Relu bounding object defined in anemoi.models.layers.bounding."


@@ -138,7 +138,7 @@ class LeakyNormalizedReluBoundingSchema(NormalizedReluBoundingSchema):
HardtanhBoundingSchema,
LeakyHardtanhBoundingSchema,
NormalizedReluBoundingSchema,
LeakyNormalizedReluBoundingSchema,
NormalizedLeakyReluBoundingSchema,
],
Field(discriminator="target_"),
]
90 changes: 90 additions & 0 deletions models/tests/schemas/test_data_processors_schemas.py
@@ -0,0 +1,90 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from anemoi.models.schemas.data_processor import ImputerSchema
from anemoi.models.schemas.data_processor import NormalizerSchema
from anemoi.models.schemas.data_processor import PostprocessorSchema
from anemoi.models.schemas.data_processor import PreprocessorSchema
from anemoi.models.schemas.data_processor import PreprocessorTarget
from anemoi.models.schemas.data_processor import RemapperSchema


def test_preprocessor_with_raw_dict():
raw_config = {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]}
schema = PreprocessorSchema(_target_="anemoi.models.preprocessing.normalizer.InputNormalizer", config=raw_config)

assert schema.target_ == "anemoi.models.preprocessing.normalizer.InputNormalizer"
assert schema.config == raw_config


def test_preprocessor_with_normalizer_instance():
normalizer_instance = NormalizerSchema(default="std", remap={"x": "z", "y": "x"})
schema = PreprocessorSchema(
_target_="anemoi.models.preprocessing.normalizer.InputNormalizer", config=normalizer_instance
)

assert schema.target_ == "anemoi.models.preprocessing.normalizer.InputNormalizer"
assert isinstance(schema.config, NormalizerSchema)
assert schema.config.default == "std"
assert schema.config.remap == {"x": "z", "y": "x"}


def test_preprocessor_with_imputer_dict():
raw_config = {"default": "none", "maximum": ["x"], "none": ["z"], "minimum": ["q"]}
schema = PreprocessorSchema(_target_=PreprocessorTarget.imputer, config=raw_config)
assert schema.target_ == PreprocessorTarget.imputer
assert schema.config["default"] == "none"


def test_preprocessor_with_imputer_instance():
instance = ImputerSchema(default="none", maximum=["x"], minimum=["q"], none=["z"])
schema = PreprocessorSchema(_target_=PreprocessorTarget.imputer, config=instance)
assert isinstance(schema.config, ImputerSchema)
assert schema.config.maximum == ["x"]


def test_preprocessor_with_constant_imputer_dict():
raw_config = {"default": 1.0, 1.0: ["x"], 5.0: ["y"], "none": ["z"]}
schema = PreprocessorSchema(_target_=PreprocessorTarget.const_imputer, config=raw_config)
assert schema.target_ == PreprocessorTarget.const_imputer
assert schema.config["default"] == 1.0
assert schema.config[1.0] == ["x"]


def test_preprocessor_with_postprocessor_dict():
raw_config = {"default": "hardtanh_0_1", "hardtanh_0_1": ["x"], "none": ["y"]}
schema = PreprocessorSchema(_target_=PreprocessorTarget.postprocessor, config=raw_config)
assert schema.config["default"] == "hardtanh_0_1"


def test_preprocessor_with_postprocessor_instance():
instance = PostprocessorSchema(default="relu", relu=["x"], none=["z"])
schema = PreprocessorSchema(_target_=PreprocessorTarget.postprocessor, config=instance)
assert isinstance(schema.config, PostprocessorSchema)
assert schema.config.relu == ["x"]


def test_preprocessor_with_conditional_zero_postprocessor_dict():
raw_config = {"default": 0.0, "remap": "ref_var", 0: ["x"], 1: ["y"]}
schema = PreprocessorSchema(_target_=PreprocessorTarget.conditional_zero_postprocessor, config=raw_config)
assert schema.config["remap"] == "ref_var"


def test_preprocessor_with_normalized_relu_postprocessor_dict():
raw_config = {"normalizer": "mean-std", 1.0: ["x"], 0: ["y"]}
schema = PreprocessorSchema(_target_=PreprocessorTarget.normalized_relu_postprocessor, config=raw_config)
assert schema.config["normalizer"] == "mean-std"


def test_preprocessor_with_remapper_instance():
instance = RemapperSchema(default="log1p", none=["d", "q"])
schema = PreprocessorSchema(_target_=PreprocessorTarget.remapper, config=instance)
assert isinstance(schema.config, RemapperSchema)
assert "d" in schema.config.none
@@ -20,6 +20,7 @@
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only

from anemoi.training.utils.checkpoint import check_classes
from anemoi.utils.checkpoints import save_metadata

if TYPE_CHECKING:
@@ -151,6 +152,14 @@ def _get_inference_checkpoint_filepath(self, filepath: str) -> str:
"""Defines the filepath for the inference checkpoint."""
return Path(filepath).parent / Path("inference-" + str(Path(filepath).name))

def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Check that model's metadata does not contain Pydantic schemas references."""
del pl_module

if trainer.is_global_zero:
model = self._torch_drop_down(trainer)
check_classes(model)

def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None:
if trainer.is_global_zero:
model = self._torch_drop_down(trainer)