Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
671eaca
feat: add lat weighted attribute
JPXKQX Mar 10, 2025
a6e5554
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 2, 2025
0f41de4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2025
2579abb
add imports
JPXKQX Apr 2, 2025
1a654f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2025
1611fd3
add area weights for rectilinear grids
JPXKQX Apr 8, 2025
ecc31ee
Merge branch 'feature/lat-weighted-attr' of https://github.com/ecmwf/…
JPXKQX Apr 8, 2025
05433d7
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 8, 2025
7ac5ae9
fix: update dtypes
JPXKQX Apr 8, 2025
b9356b2
udpated schema
JPXKQX Apr 8, 2025
8eaf1d3
remove Polynomial area weigts
JPXKQX Apr 8, 2025
4ead97d
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 15, 2025
47dba5e
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Apr 23, 2025
3ee0ca3
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX May 14, 2025
71d190e
add assert to test
JPXKQX May 14, 2025
4e1ec15
add tests
JPXKQX May 14, 2025
7a9da78
fix: switch lat_1 and lat_2
JPXKQX May 14, 2025
f7e6107
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX May 21, 2025
b0915d3
assert latitude sorting
JPXKQX Jun 4, 2025
aae6d5d
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 4, 2025
229b52a
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 5, 2025
d8bc939
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 13, 2025
650d97c
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jun 18, 2025
bd873b0
add class description
JPXKQX Jun 18, 2025
f608af6
Merge branch 'main' into feature/lat-weighted-attr
JPXKQX Jul 11, 2025
ba18d09
Merge branch 'main' into feature/lat-weighted-attr
mchantry Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions graphs/src/anemoi/graphs/nodes/attributes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .area_weights import CosineLatWeightedAttribute
from .area_weights import IsolatitudeAreaWeights
from .area_weights import MaskedPlanarAreaWeights
from .area_weights import PlanarAreaWeights
from .area_weights import SphericalAreaWeights
from .area_weights import UniformWeights
Expand Down
2 changes: 1 addition & 1 deletion graphs/src/anemoi/graphs/nodes/attributes/area_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def compute_latitude_weight(self, latitudes: np.ndarray) -> np.ndarray:
# Compute the latitude band area
lat_1 = divisory_lats[1:]
lat_2 = divisory_lats[:-1]
ring_area_km = 2 * np.pi * EARTH_RADIUS * (np.sin(lat_2) - np.sin(lat_1))
ring_area_km = 2 * np.pi * EARTH_RADIUS * (np.sin(lat_1) - np.sin(lat_2))

# Compute the number of points/nodes at each latitude band
lat_to_ring = {lat: idx for idx, lat in enumerate(unique_lats)}
Expand Down
28 changes: 26 additions & 2 deletions graphs/tests/nodes/attributes/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from anemoi.graphs.nodes.attributes import CosineLatWeightedAttribute
from anemoi.graphs.nodes.attributes import IsolatitudeAreaWeights
from anemoi.graphs.nodes.attributes import MaskedPlanarAreaWeights
from anemoi.graphs.nodes.attributes import PlanarAreaWeights
from anemoi.graphs.nodes.attributes import SphericalAreaWeights
from anemoi.graphs.nodes.attributes import UniformWeights
Expand Down Expand Up @@ -71,12 +72,35 @@ def test_spherical_area_weights_wrong_fill_value(fill_value: str):


@pytest.mark.parametrize("attr_class", [IsolatitudeAreaWeights, CosineLatWeightedAttribute])
def test_latweighted(attr_class: Type[BaseNodeAttribute], graph_with_rectilinear_nodes):
@pytest.mark.parametrize("norm", [None, "l1", "unit-max"])
def test_latweighted(attr_class: Type[BaseNodeAttribute], graph_with_rectilinear_nodes, norm: str):
"""Test attribute builder for Lat with different fill values."""
node_attr_builder = attr_class(norm="l1")
node_attr_builder = attr_class(norm=norm)
weights = node_attr_builder.compute(graph_with_rectilinear_nodes, "test_nodes")

assert weights is not None
assert isinstance(weights, torch.Tensor)
assert torch.all(weights >= 0)
assert weights.shape[0] == graph_with_rectilinear_nodes["test_nodes"].x.shape[0]
assert weights.dtype == node_attr_builder.dtype


def test_masked_planar_area_weights(graph_with_nodes: HeteroData):
"""Test attribute builder for PlanarAreaWeights."""
node_attr_builder = MaskedPlanarAreaWeights(mask_node_attr_name="interior_mask")
weights = node_attr_builder.compute(graph_with_nodes, "test_nodes")

assert weights is not None
assert isinstance(weights, torch.Tensor)
assert weights.shape[0] == graph_with_nodes["test_nodes"].x.shape[0]
assert weights.dtype == node_attr_builder.dtype

mask = graph_with_nodes["test_nodes"]["interior_mask"]
assert torch.all(weights[~mask] == 0)


def test_masked_planar_area_weights_fail(graph_with_nodes: HeteroData):
"""Test attribute builder for AreaWeights with invalid radius."""
with pytest.raises(AssertionError):
node_attr_builder = MaskedPlanarAreaWeights(mask_node_attr_name="nonexisting")
node_attr_builder.compute(graph_with_nodes, "test_nodes")