Commits (56)
2f5b3d0
bugfix: interpolator was previously interpolating 36h windows at 6h i…
Rilwan-Adewoyin Sep 1, 2025
6410fb0
bulk update towardsorking accumulations
Rilwan-Adewoyin Sep 2, 2025
2535027
bulk update towardsorking accumulations
Rilwan-Adewoyin Sep 2, 2025
174476b
improve energy accum logic in order to work with inference
Rilwan-Adewoyin Sep 6, 2025
2fa1140
Refactor modules: LeftBoundaryZero, interpolator config, accumulation…
Rilwan-Adewoyin Sep 14, 2025
4c978f5
bugfix: update SetToZero arguments and the related Schema for pydanti…
Rilwan-Adewoyin Sep 15, 2025
d97d5c2
Delete training/src/anemoi/training/config/test_benchmark.py
Rilwan-Adewoyin Sep 18, 2025
8d3540a
Update evaluation.yaml - remove null from diagnostics.log.wand.entity
Rilwan-Adewoyin Sep 18, 2025
bcea966
Update evaluation.yaml - remove null from diagnostics.log.mlflow.trac…
Rilwan-Adewoyin Sep 18, 2025
d43fe0f
Delete models/src/anemoi/models/models/test_benchmark.py
Rilwan-Adewoyin Sep 18, 2025
63cadfe
Update interpolator.py fixing errors and improving the annotation res…
Rilwan-Adewoyin Sep 18, 2025
1d854ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2025
c6930d9
Merge branch 'main' into feat/402-time-interpolation-mass-conserving-…
Rilwan-Adewoyin Sep 18, 2025
4875b2d
move data subconfig for interpolator to seperate file
Rilwan-Adewoyin Oct 3, 2025
28322ee
fix: corrected assertion statement
Rilwan-Adewoyin Oct 3, 2025
cc41ad1
Changing default latent_skip to True, reflecting experiments which sh…
Rilwan-Adewoyin Oct 3, 2025
dedb040
Add pydantic checks which ensure the explicit_times target and input …
Rilwan-Adewoyin Oct 3, 2025
42802c1
added pydantic checks for SetToZero overwriter
Rilwan-Adewoyin Oct 3, 2025
b3711da
change name from SetToZero to ZeroOverwriter
Rilwan-Adewoyin Oct 3, 2025
bba4129
added tests for Zero Overwriter
Rilwan-Adewoyin Oct 3, 2025
f380454
fix merge conflict
Rilwan-Adewoyin Oct 3, 2025
63d6f0d
Add test to ensure mass conservation works
Rilwan-Adewoyin Oct 3, 2025
0fa03a3
Add pydantic checks to ensure correctness of whole config if mass con…
Rilwan-Adewoyin Oct 7, 2025
2402778
bugfix: updated pydantic schema checks for ZeroOverwirte to cover any…
Rilwan-Adewoyin Oct 8, 2025
fcee603
bugfix: schema for ZeroOverwrite assumed list instead of Mutable Sequ…
Rilwan-Adewoyin Oct 8, 2025
67090f2
bugfix - unindented function
Rilwan-Adewoyin Oct 8, 2025
ef50380
nomenclature update: replace input_times with num_input_times
Rilwan-Adewoyin Oct 8, 2025
e68d1b9
nomenclature update: replace input_times with num_input_times
Rilwan-Adewoyin Oct 8, 2025
41b5ac5
Use MutableSequence in assertions, Add Interppolation Model to ModelS…
Rilwan-Adewoyin Oct 9, 2025
4d5c206
fix the MutableMapping import
Rilwan-Adewoyin Oct 9, 2025
bb8bbd7
Added docstring to ZeroOverwriter class
Rilwan-Adewoyin Oct 10, 2025
8dbdd97
Adding docs for ZeroOverwriter and interpolator set up
Rilwan-Adewoyin Oct 10, 2025
a0b9f47
polish docs - interpolator setup
Rilwan-Adewoyin Oct 10, 2025
c1ac819
polish docs - interpolator setup
Rilwan-Adewoyin Oct 10, 2025
f4211bf
polishing the ZeroOverwriter docs
Rilwan-Adewoyin Oct 10, 2025
e8778aa
Reinstate utility of timeincrememnt for interpolator dataset/dataloading
Rilwan-Adewoyin Oct 10, 2025
db40988
revert: edits to schema which removed def schema_consistent_with_target
Rilwan-Adewoyin Oct 10, 2025
e086f89
Remove redundant note and LOGGER.warning from ZeroOverwriter instanti…
Rilwan-Adewoyin Oct 10, 2025
da2b4f2
polish docs for interplator setup
Rilwan-Adewoyin Oct 10, 2025
2291d37
updated ordering of params in test
Rilwan-Adewoyin Oct 10, 2025
574e18a
updated ordering of params in test
Rilwan-Adewoyin Oct 10, 2025
663b886
added error messages
Rilwan-Adewoyin Oct 10, 2025
2256a63
add a test for setting up the mass conservation accumulation
Rilwan-Adewoyin Oct 16, 2025
01a72c3
add tp cp to std normalization
Rilwan-Adewoyin Oct 16, 2025
50d6f78
Updating getting started schema
Rilwan-Adewoyin Oct 17, 2025
26e8d94
add ability for pydantic cross subschemea validation and refactor int…
Rilwan-Adewoyin Oct 17, 2025
3041ced
update nomenclature for cross subschema validation
Rilwan-Adewoyin Oct 17, 2025
37bd2fb
nomenclature update: time_index to time_indicies
Rilwan-Adewoyin Oct 17, 2025
185a03a
replace Logger.warnings with assertions
Rilwan-Adewoyin Oct 17, 2025
6885b6e
Merge remote-tracking branch 'origin/main' into feat/402-time-interpo…
Rilwan-Adewoyin Oct 17, 2025
e99f47a
bugfix: ensure time interpolator integration tests work
Rilwan-Adewoyin Oct 17, 2025
efc7a06
Merge branch 'main' into feat/402-time-interpolation-mass-conserving-…
Rilwan-Adewoyin Oct 17, 2025
befe686
Merge branch 'main' into feat/402-time-interpolation-mass-conserving-…
Rilwan-Adewoyin Oct 22, 2025
acfb6f2
removed text
Rilwan-Adewoyin Oct 22, 2025
99923bc
Currect the signature of the _step function
Rilwan-Adewoyin Oct 22, 2025
b16e868
change default model for interpolator from transformer to graphtransf…
Rilwan-Adewoyin Oct 22, 2025
41 changes: 41 additions & 0 deletions models/docs/modules/preprocessing.rst
@@ -69,3 +69,44 @@ The module contains the following classes:
:members:
:no-undoc-members:
:show-inheritance:

****************
ZeroOverwriter
****************

Overwrite selected timesteps of specified input variables with zero.

This preprocessor operates on inputs before the model and is
model-independent. It is useful whenever a variable should be reset to
zero at certain timesteps within each input window (for example, for
accumulated or windowed variables).

Example
=======

Include the following in the data config (Hydra/YAML) to set var_a and var_b
to zero at timesteps 0 and 3, and var_c and var_d to zero at timesteps 0 and 4:

.. code:: yaml

processors:
zero_overwriter:
_target_: anemoi.models.preprocessing.overwriter.ZeroOverwriter
config:
groups:
- vars:
- "var_a"
- "var_b"
time_indices: [0, 3]
- vars:
- "var_c"
- "var_d"
time_indices: [0, 4]
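
The preprocessor simply zeroes the configured variable channels at the listed
time indices of each input window. Below is a minimal stand-alone sketch of
that effect in plain PyTorch; the tensor layout and channel indices are
illustrative assumptions, not taken from a real configuration.

.. code:: python

   import torch

   # Hypothetical input window: (batch, time, ensemble, grid, variables)
   x = torch.randn(2, 6, 1, 1024, 8)

   # Illustrative channel positions of "var_a" and "var_b"
   var_indices = [0, 1]
   time_indices = [0, 3]

   # Zero the selected variables at the selected timesteps, mirroring the
   # first group in the config above.
   for t in time_indices:
       x[:, t, ..., var_indices] = 0.0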

API
===

.. autoclass:: anemoi.models.preprocessing.overwriter.ZeroOverwriter
:members:
:no-undoc-members:
:show-inheritance:
249 changes: 247 additions & 2 deletions models/src/anemoi/models/models/interpolator.py
@@ -13,10 +13,14 @@
import einops
import torch
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch.nn import functional as F
from torch_geometric.data import HeteroData

from anemoi.models.distributed.graph import gather_tensor
from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.shapes import apply_shard_shapes
from anemoi.models.distributed.shapes import get_or_apply_shard_shapes
from anemoi.models.distributed.shapes import get_shard_shapes
from anemoi.models.models import AnemoiModelEncProcDec
@@ -52,7 +56,7 @@ def __init__(
self.num_target_forcings = (
len(model_config.training.target_forcing.data) + model_config.training.target_forcing.time_fraction
)
self.input_times = len(model_config.training.explicit_times.input)
self.num_input_times = len(model_config.training.explicit_times.input)
super().__init__(
model_config=model_config,
data_indices=data_indices,
@@ -64,10 +68,12 @@
self.latent_skip = model_config.model.latent_skip
self.grid_skip = model_config.model.grid_skip

self.setup_mass_conserving_accumulations(data_indices, model_config)

# Overwrite base class
def _calculate_input_dim(self):
return (
self.input_times * self.num_input_channels
self.num_input_times * self.num_input_channels
+ self.node_attributes.attr_ndims[self._graph_name_data]
+ self.num_target_forcings
)
@@ -184,3 +190,242 @@ def forward(
x_out = self._assemble_output(x_out, x_skip, batch_size, ensemble_size, x.dtype)

return x_out

def predict_step(
self,
batch: torch.Tensor,
pre_processors: nn.Module,
post_processors: nn.Module,
multi_step: int,
model_comm_group: Optional[ProcessGroup] = None,
gather_out: bool = True,
**kwargs,
) -> Tensor:
"""Prediction step for the model.

Base implementation applies pre-processing, performs a forward pass, and applies post-processing.
Subclasses can override this for different behavior (e.g., sampling for diffusion models).

Parameters
----------
batch : torch.Tensor
Input batched data (before pre-processing)
pre_processors : nn.Module
Pre-processing module
post_processors : nn.Module
Post-processing module
multi_step : int
Number of input timesteps
model_comm_group : Optional[ProcessGroup]
Process group for distributed training
gather_out : bool
Whether to gather output tensors across distributed processes
**kwargs
Additional arguments

Returns
-------
Tensor
Model output (after post-processing)
"""
with torch.no_grad():

assert (
len(batch.shape) == 5
), f"The input tensor has an incorrect shape: expected a 5-dimensional tensor, got {batch.shape}!"

x_boundaries = pre_processors(batch, in_place=False)  # batch is expected to already contain only the input variables

# Handle distributed processing
grid_shard_shapes = None
if model_comm_group is not None:
shard_shapes = get_shard_shapes(x_boundaries, -2, model_comm_group)
grid_shard_shapes = [shape[-2] for shape in shard_shapes]
x_boundaries = shard_tensor(x_boundaries, -2, shard_shapes, model_comm_group)

target_forcing = kwargs.get(
"target_forcing", None
) # shape(bs, interpolation_steps, ens, grid, forcing_dim)
interpolation_steps = target_forcing.shape[1]

output_shape = (
batch.shape[0],
target_forcing.shape[1],
batch.shape[2],
batch.shape[3],
)
# Perform forward pass
# TODO: add the same logic as in _step here e.g. iterative forwards to get the multiple y_hats

for i in range(interpolation_steps):
y_pred = self.forward(
x_boundaries,
model_comm_group=model_comm_group,
grid_shard_shapes=grid_shard_shapes,
target_forcing=target_forcing[:, i],
)

if i == 0:
output_shape = (
batch.shape[0],
target_forcing.shape[1],
batch.shape[2],
batch.shape[3],
y_pred.shape[-1],
)
y_preds = batch.new_zeros(output_shape)

y_preds[:, i] = y_pred

include_right_boundary = kwargs.get("include_right_boundary", False)
if self.map_accum_indices is not None:

y_preds = self.resolve_mass_conservations(
y_preds, x_boundaries, include_right_boundary=include_right_boundary
)
elif include_right_boundary:
y_preds = torch.cat([y_preds, x_boundaries[:, -1:, ...]], dim=1)

# Apply post-processing
y_preds = post_processors(y_preds, in_place=False)

# Gather output if needed
if gather_out and model_comm_group is not None:
y_preds = gather_tensor(
y_preds, -2, apply_shard_shapes(y_preds, -2, grid_shard_shapes), model_comm_group
)

return y_preds

def resolve_mass_conservations(self, y_preds, x_input, include_right_boundary=False) -> torch.Tensor:
"""Enforce a mass-conservation constraint for selected output variables by redistributing a
known total (taken from input constraints) across the time dimension using softmax weights
derived from the model's logits.

Parameters
----------
y_preds : torch.Tensor
Model outputs of shape (B, T, E, G, V_out).
The channels at `target_indices` within V_out are the accumulated variables constrained to
sum to the input totals.
x_input : torch.Tensor
Input tensor compatible with `y_preds`. Constraint totals are read from the right boundary:
x_input[:, -1:, ..., input_constraint_indxs] with shape (B, 1, E, G, V_acc).
include_right_boundary : bool, optional
If False, distribute the constraint over the existing T steps.
If True, append an extra (T+1)-th step representing the right boundary and distribute over T+1 steps;
non-target outputs at that boundary are copied from inputs.

Returns
-------
torch.Tensor
`y_preds` with the constrained target variables replaced by a softmax-weighted allocation that conserves
the total. Shape is (B, T, E, G, V_out) or (B, T+1, E, G, V_out) if `include_right_boundary` is True.
"""

# Indices mapping:
# - input_constraint_indxs: channels in x_input containing the total "mass" to conserve
# - target_indices: corresponding output channels in y_preds that must sum to that mass
input_constraint_indxs = self.map_accum_indices["constraint_idxs"]
target_indices = self.map_accum_indices["target_idxs"]

# Extract logits for the "accumulated" target variables: (B, T, E, G, V_acc)
logits = y_preds[..., target_indices]

# Create a zero "anchor" slice along the time axis (B, 1, E, G, V_acc).
# Appending this before softmax creates T+1 slots whose weights sum to 1.
# The last slot can represent the right-boundary share.
zeros = torch.zeros_like(logits[:, 0:1])

# Compute normalized weights along time: (B, T+1, E, G, V_acc)
# Note: softmax dim=1 (time). Weights across time sum to 1 per (B, E, G, V_acc).
weights = F.softmax(torch.cat([logits, zeros], dim=1), dim=1)

if not include_right_boundary:
# We are *not* including the explicit right-boundary step in the output,
# so drop the last (boundary) slot and keep T weights: (B, T, E, G, V_acc)
weights = weights[:, :-1]

# The constraint "total mass" comes from the *last* input time slice:
# shape (B, 1, E, G, V_acc). This broadcasts over time when multiplied by weights.
constraints = x_input[:, -1:, ..., input_constraint_indxs]

# Replace target outputs with the softmax-weighted allocation of the constraint.
# For each (B, E, G, V_acc), the T values sum to <= the total constraint because
# we dropped the (T+1)-th boundary slot.
y_preds[..., target_indices] = weights * constraints

else:
# --- Include the right boundary as an explicit (T+1)-th step ---

# Identify output channels that are *not* in the target set,
# so we can copy their right-boundary values from inputs.
y_index_ex_target_indices = [
outp_idx
for vname, outp_idx in self.data_indices.model.output.name_to_index.items()
if outp_idx not in target_indices
]

# Map those same (non-target) outputs to their positions in the model *input*
# so we can source their boundary values from x_input.
data_indices_model_input_model_output = [
self.data_indices.model.input.name_to_index[vname]
for vname, outp_idx in self.data_indices.model.output.name_to_index.items()
if outp_idx not in target_indices
]

# Append an all-zero frame along time so y_preds has T+1 steps,
# matching the weights shape (B, T+1, E, G, V_*).
y_preds = torch.cat([y_preds, torch.zeros_like(y_preds[:, 0:1])], dim=1)

# Right-boundary constraint totals (B, 1, E, G, V_acc)
constraints = x_input[:, -1:, ..., input_constraint_indxs]

# Allocate constraint across T+1 steps for the target variables.
# For each (B, E, G, V_acc), these (T+1) values sum to the total constraint.
y_preds_accum = weights * constraints # (B, T+1, E, G, V_acc)

# For *non-target* variables, set the (T+1)-th step to the right-boundary input values.
# This preserves/copies boundary conditions for outputs we are not re-allocating.
y_preds[:, -1:, ..., y_index_ex_target_indices] = x_input[
:, -1:, ..., data_indices_model_input_model_output
]

# Write the allocated values back into the target channels across all T+1 steps.
y_preds[..., target_indices] = y_preds_accum

return y_preds

def setup_mass_conserving_accumulations(self, data_indices: dict, config: dict):

# Mass-conserving accumulations: expose the config mapping on the underlying model and
# prepare aligned index lists. Each mapping pairs an output variable (prediction target)
# with an input constraint variable (accumulation/forcing), which we validate and index below.
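# An illustrative (hypothetical) config entry, mapping an output variable to
# the input variable carrying its accumulated total; the names below are
# placeholders, not taken from a real configuration:
#   model:
#     mass_conserving_accumulations:
#       tp: tp_accum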
self.map_mass_conserving_accums = getattr(config.model, "mass_conserving_accumulations", None)
if self.map_mass_conserving_accums is None:
self.map_accum_indices = None
else:
target_idx_list: list[int] = []
constraint_idx_list: list[int] = []
for output_varname, input_constraint_varname in self.map_mass_conserving_accums.items():
assert (
input_constraint_varname in data_indices.data._forcing
), f"Input constraint variable {input_constraint_varname} not found in data indices forcing variables."
assert (
output_varname in data_indices.model.output.name_to_index
), f"Output variable {output_varname} not found in data indices output variables."

target_idx_list.append(data_indices.model.output.name_to_index[output_varname])
constraint_idx_list.append(data_indices.model.input.name_to_index[input_constraint_varname])

self.map_accum_indices = torch.nn.ParameterDict(
{
"target_idxs": torch.nn.Parameter(
torch.tensor(target_idx_list, dtype=torch.long), requires_grad=False
),
"constraint_idxs": torch.nn.Parameter(
torch.tensor(constraint_idx_list, dtype=torch.long),
requires_grad=False,
),
},
)