Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
671eaca
feat: add lat weighted attribute
JPXKQX Mar 10, 2025
a6e5554
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 2, 2025
0f41de4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2025
2579abb
add imports
JPXKQX Apr 2, 2025
1a654f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2025
1611fd3
add area weights for rectilinear grids
JPXKQX Apr 8, 2025
ecc31ee
Merge branch 'feature/lat-weighted-attr' of https://github.com/ecmwf/…
JPXKQX Apr 8, 2025
05433d7
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 8, 2025
7ac5ae9
fix: update dtypes
JPXKQX Apr 8, 2025
b9356b2
udpated schema
JPXKQX Apr 8, 2025
8eaf1d3
remove Polynomial area weigts
JPXKQX Apr 8, 2025
4ead97d
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 15, 2025
47dba5e
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 23, 2025
3ee0ca3
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX May 14, 2025
71d190e
add assert to test
JPXKQX May 14, 2025
4e1ec15
add tests
JPXKQX May 14, 2025
7a9da78
fix: switch lat_1 and lat_2
JPXKQX May 14, 2025
f7e6107
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX May 21, 2025
b0915d3
assert latitude sorting
JPXKQX Jun 4, 2025
aae6d5d
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 4, 2025
229b52a
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 5, 2025
d8bc939
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 13, 2025
650d97c
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 18, 2025
bd873b0
add class description
JPXKQX Jun 18, 2025
f608af6
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jul 11, 2025
ba18d09
Merge branch 'main' into feature/lat-weighted-attr
mchantry Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions graphs/src/anemoi/graphs/nodes/attributes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from .area_weights import CosineLatWeightedAttribute
from .area_weights import IsolatitudeAreaWeights
from .area_weights import PlanarAreaWeights
from .area_weights import SphericalAreaWeights
from .area_weights import UniformWeights
Expand All @@ -25,4 +27,6 @@
"BooleanAndMask",
"BooleanNot",
"BooleanOrMask",
"CosineLatWeightedAttribute",
"IsolatitudeAreaWeights",
]
89 changes: 89 additions & 0 deletions graphs/src/anemoi/graphs/nodes/attributes/area_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from __future__ import annotations

import logging
from abc import ABC
from abc import abstractmethod

import numpy as np
import torch
Expand All @@ -18,6 +20,7 @@
from scipy.spatial import Voronoi
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
from anemoi.graphs.nodes.attributes.base_attributes import BaseNodeAttribute

Expand Down Expand Up @@ -253,3 +256,89 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
result.sum(),
)
return torch.from_numpy(result)


class BaseLatWeightedAttribute(BaseNodeAttribute, ABC):

@abstractmethod
def compute_latitude_weight(self, latitudes: np.ndarray) -> np.ndarray: ...

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> torch.Tensor:
lats_rad = nodes.x[:, 0].cpu().numpy()
weights = self.compute_latitude_weight(lats_rad)
return torch.from_numpy(weights)


class CosineLatWeightedAttribute(BaseLatWeightedAttribute):
"""Latitude-weighting of the node attributes as a function of a polynomial.

Attributes
----------
min_value : float
Minimum value of the weights when the latitude is -pi/2 or pi/2 radians.
max_value : float
Maximum value of the weights when the latitude is 0 radians.
norm : str
Normalisation of the weights.

Methods
-------
compute(self, graph, nodes_name)
Compute the area attributes for each node.
"""

def __init__(
self,
min_value: float = 1e-3,
max_value: float = 1,
norm: str | None = None,
dtype: str = "float32",
) -> None:
super().__init__(norm, dtype)
self.min_value = min_value
self.max_value = max_value

def compute_latitude_weight(self, latitudes: np.ndarray) -> np.ndarray:
return (self.max_value - self.min_value) * np.cos(latitudes) + self.min_value


class IsolatitudeAreaWeights(BaseLatWeightedAttribute):
"""Latitude-weighted area weights for rectilinear grids.

Attributes
----------
norm : str
Normalisation of the weights.

Methods
-------
compute(self, graph, nodes_name)
Compute the area attributes for each node.

Notes
------
The area of a latitude band is
.. math::
A = 2\pi R(\sin(lat_2) - \sin(lat_1))
where R is the earth radius and lat_1, lat_2 are in radians.
"""

def compute_latitude_weight(self, latitudes: np.ndarray) -> np.ndarray:
# Get the latitudes defining the bands
unique_lats = np.sort(np.unique(latitudes))
divisory_lats = (unique_lats[1:] + unique_lats[:-1]) / 2
divisory_lats = np.concatenate([[-np.pi / 2], divisory_lats, [np.pi / 2]])

# Compute the latitude band area
lat_1 = divisory_lats[1:]
lat_2 = divisory_lats[:-1]
ring_area_km = 2 * np.pi * EARTH_RADIUS * (np.sin(lat_2) - np.sin(lat_1))

# Compute the number of points/nodes at each latitude band
lat_to_ring = {lat: idx for idx, lat in enumerate(unique_lats)}
lat_rings = np.array([lat_to_ring[lat] for lat in latitudes])
lat_counts = np.bincount(lat_rings, minlength=len(unique_lats))

# Compute the area of each node
area_km = dict(zip(unique_lats, ring_area_km / lat_counts))
return np.array([area_km[lat] for lat in latitudes])
10 changes: 10 additions & 0 deletions graphs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def graph_with_nodes() -> HeteroData:
return graph


@pytest.fixture
def graph_with_rectilinear_nodes() -> HeteroData:
graph = HeteroData()
num_lons, num_lats = 10, 10
lat_grid, lon_grid = np.meshgrid(np.linspace(-np.pi / 2, np.pi / 2, num_lats), np.linspace(0, 2 * np.pi, num_lons))
coords = torch.tensor(np.array([lat_grid.ravel(), lon_grid.ravel()]).T)
graph["test_nodes"].x = coords
return graph


@pytest.fixture
def graph_with_isolated_nodes() -> HeteroData:
graph = HeteroData()
Expand Down
17 changes: 17 additions & 0 deletions graphs/tests/nodes/attributes/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from typing import Type

import pytest
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.attributes import CosineLatWeightedAttribute
from anemoi.graphs.nodes.attributes import IsolatitudeAreaWeights
from anemoi.graphs.nodes.attributes import PlanarAreaWeights
from anemoi.graphs.nodes.attributes import SphericalAreaWeights
from anemoi.graphs.nodes.attributes import UniformWeights
from anemoi.graphs.nodes.attributes.base_attributes import BaseNodeAttribute


def test_uniform_weights(graph_with_nodes: HeteroData):
Expand Down Expand Up @@ -63,3 +68,15 @@ def test_spherical_area_weights_wrong_fill_value(fill_value: str):
"""Test attribute builder for SphericalAreaWeights with invalid fill_value."""
with pytest.raises(AssertionError):
SphericalAreaWeights(fill_value=fill_value)


@pytest.mark.parametrize("attr_class", [IsolatitudeAreaWeights, CosineLatWeightedAttribute])
def test_latweighted(attr_class: Type[BaseNodeAttribute], graph_with_rectilinear_nodes):
"""Test attribute builder for Lat with different fill values."""
node_attr_builder = attr_class(norm="l1")
weights = node_attr_builder.compute(graph_with_rectilinear_nodes, "test_nodes")

assert weights is not None
assert isinstance(weights, torch.Tensor)
assert weights.shape[0] == graph_with_rectilinear_nodes["test_nodes"].x.shape[0]
assert weights.dtype == node_attr_builder.dtype
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class PlanarAreaWeightSchema(BaseModel):
"anemoi.graphs.nodes.attributes.AreaWeights",
"anemoi.graphs.nodes.attributes.PlanarAreaWeights",
"anemoi.graphs.nodes.attributes.UniformWeights",
"anemoi.graphs.nodes.attributes.CosineLatWeightedAttribute",
"anemoi.graphs.nodes.attributes.IsolatitudeAreaWeights",
] = Field(..., alias="_target_")
"Implementation of the area of the nodes as the weights from anemoi.graphs.nodes.attributes."
norm: Literal["unit-max", "l1", "l2", "unit-sum", "unit-std"] = Field(example="unit-max")
Expand Down