Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
7426210
refactor knn
JPXKQX Jun 19, 2025
c3b6a34
abstract masking to a new class
JPXKQX Jun 19, 2025
411ae46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2025
b9a7e9f
feat: implement Reversed edge builders
JPXKQX Jun 19, 2025
8bdb911
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2025
72678ec
pre-commit
JPXKQX Jun 19, 2025
bb6af94
Merge branch 'fix/undo-masking-torchcluster' of https://github.com/ec…
JPXKQX Jun 19, 2025
a081395
fix: method name
JPXKQX Jun 20, 2025
1d04aad
types
JPXKQX Jun 20, 2025
8390d73
fix: avoid recursion. Call super().method
JPXKQX Jun 20, 2025
b132e59
refactor
JPXKQX Jun 20, 2025
9cfe396
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2025
8c0dfaa
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 1, 2025
1cbc3e9
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 8, 2025
8113e8d
added docstrings
JPXKQX Jul 11, 2025
b4facd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2025
cc6a996
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 11, 2025
90fe710
abstract get_unmasking_mapping()
JPXKQX Jul 11, 2025
b94a39e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2025
84c638d
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 14, 2025
4c90d21
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 16, 2025
fa1382b
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 16, 2025
25537fa
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 17, 2025
87a8be0
update schemas
JPXKQX Jul 17, 2025
1254e39
add entrypoint to edges/__init__.py
JPXKQX Jul 17, 2025
5f964bd
docs: expand docs
JPXKQX Jul 17, 2025
3c864b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2025
957a934
tests: added
JPXKQX Jul 17, 2025
fb6dd76
Merge branch 'fix/undo-masking-torchcluster' of https://github.com/ec…
JPXKQX Jul 17, 2025
c623bfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2025
ebed06f
use torch_geom imports for knn and radius
JPXKQX Jul 24, 2025
d5c0637
fix: undo edge _index
JPXKQX Jul 24, 2025
aa5429d
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 24, 2025
4abbefa
add typo
JPXKQX Jul 24, 2025
1153282
Merge branch 'fix/undo-masking-torchcluster' of https://github.com/ec…
JPXKQX Jul 24, 2025
fff6023
fix: cutoff
JPXKQX Jul 24, 2025
af4c2c9
Update graphs/src/anemoi/graphs/edges/builders/cutoff.py
JPXKQX Jul 28, 2025
15f47d1
Update graphs/docs/graphs/edges/cutoff.rst
JPXKQX Jul 28, 2025
487653f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2025
3785de1
Merge branch 'main' into fix/undo-masking-torchcluster
JPXKQX Jul 28, 2025
a836047
Update graphs/src/anemoi/graphs/edges/builders/cutoff.py
JPXKQX Jul 28, 2025
97a34ce
remove path
JPXKQX Jul 28, 2025
8948e62
update schema
JPXKQX Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions graphs/docs/graphs/edges/cutoff.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ YAML configuration:
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges
cutoff_factor: 0.6
# max_num_neighbours: 64

The optional argument ``max_num_neighbours`` (default: 64) can be set to
limit the maximum number of neighbours each node can connect to.

.. note::

Expand All @@ -52,3 +56,27 @@ YAML configuration:
will be the lowest value without isolated nodes. This optimal value
depends on the node distribution, so it is recommended to tune it for
each case.

#########################
Reversed Cut-off radius
#########################

The reversed cut-off method (``ReversedCutOffEdges``) is similar to the
standard cut-off method, but instead establishes connections based on
the neighbourhood of each source node. The role of source and target
nodes is similarly reversed in :math:`\text{nodes_reference_dist}`.
Given two sets of nodes, (`source`, `target`), the
``ReversedCutOffEdges`` method connects all sources nodes to all target
nodes within a cut-off radius.

To use this method to create your connections, you can use the following
YAML configuration:

.. code:: yaml

edges:
- source_name: source
target_name: destination
edge_builders:
- _target_: anemoi.graphs.edges.ReversedCutOffEdges
cutoff_factor: 0.6
23 changes: 23 additions & 0 deletions graphs/docs/graphs/edges/knn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,26 @@ YAML configuration:

The ``KNNEdges`` method is recommended for the decoder edges, to
connect all target nodes with the surrounding source nodes.

###############################
Reversed K-Nearest Neighbours
###############################

The reversed k-nearest neighbours (``ReversedKNNEdges``) method is
similar to the standard KNN method, but instead establishes connections
based on the nearest neighbours of each source node. Given two sets of
nodes, (`source`, `target`), the ``ReversedKNNEdges`` method connects
all source nodes to their ``num_nearest_neighbours`` nearest target
nodes.

To use this method to build your connections, you can use the following
YAML configuration:

.. code:: yaml

edges:
- source_name: source
target_name: target
edge_builders:
- _target_: anemoi.graphs.edges.ReversedKNNEdges
num_nearest_neighbours: 3
2 changes: 1 addition & 1 deletion graphs/docs/modules/edge_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

.. automodule:: anemoi.graphs.edges.builder
:members:
:exclude-members: BaseEdgeBuilder,NodeMaskingMixin
:exclude-members: BaseEdgeBuilder,BaseDistanceEdgeBuilders,NodeMaskingMixin
:no-undoc-members:
:show-inheritance:
4 changes: 4 additions & 0 deletions graphs/src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@
# nor does it submit to any jurisdiction.

from .builders.cutoff import CutOffEdges
from .builders.cutoff import ReversedCutOffEdges
from .builders.icon import ICONTopologicalDecoderEdges
from .builders.icon import ICONTopologicalEncoderEdges
from .builders.icon import ICONTopologicalProcessorEdges
from .builders.knn import KNNEdges
from .builders.knn import ReversedKNNEdges
from .builders.multi_scale import MultiScaleEdges

__all__ = [
"KNNEdges",
"CutOffEdges",
"MultiScaleEdges",
"ReversedCutOffEdges",
"ReversedKNNEdges",
"ICONTopologicalProcessorEdges",
"ICONTopologicalEncoderEdges",
"ICONTopologicalDecoderEdges",
Expand Down
46 changes: 46 additions & 0 deletions graphs/src/anemoi/graphs/edges/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
import time
from abc import ABC
from abc import abstractmethod
from importlib.util import find_spec

import numpy as np
import torch
from hydra.utils import instantiate
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs.edges.builders.masking import NodeMaskingMixin
from anemoi.graphs.utils import concat_edges
from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)

TORCH_CLUSTER_AVAILABLE = find_spec("torch_cluster") is not None


class BaseEdgeBuilder(ABC):
"""Base class for edge builders."""
Expand Down Expand Up @@ -137,3 +142,44 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
LOGGER.debug("Time to register edge attribute (%s): %.2f s", self.__class__.__name__, t1 - t0)

return graph


class BaseDistanceEdgeBuilders(BaseEdgeBuilder, NodeMaskingMixin, ABC):
"""Base class for edge builders based on distance."""

@abstractmethod
def _compute_edge_index_pyg(self, source_coords: NodeStorage, target_coords: NodeStorage) -> np.ndarray: ...

@abstractmethod
def _compute_adj_matrix_sklearn(self, source_coords: NodeStorage, target_coords: NodeStorage) -> np.ndarray: ...

def compute_edge_index(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> torch.Tensor:
"""Compute the edge indices.

Parameters
----------
source_nodes : NodeStorage
The source nodes.
target_nodes : NodeStorage
The target nodes.

Returns
-------
torch.Tensor of shape (2, num_edges)
Indices of source and target nodes connected by an edge.
"""
source_coords, target_coords = self.get_cartesian_node_coordinates(source_nodes, target_nodes)

if TORCH_CLUSTER_AVAILABLE:
edge_index = self._compute_edge_index_pyg(source_coords, target_coords)
edge_index = self.undo_masking_edge_index(edge_index, source_nodes, target_nodes)
else:
LOGGER.warning(
"The 'torch-cluster' library is not installed. Installing 'torch-cluster' can significantly improve "
"performance for graph creation. You can install it using 'pip install torch-cluster'."
)
adj_matrix = self._compute_adj_matrix_sklearn(source_coords, target_coords)
adj_matrix = self.undo_masking_adj_matrix(adj_matrix, source_nodes, target_nodes)
edge_index = torch.from_numpy(np.stack([adj_matrix.col, adj_matrix.row], axis=0))

return edge_index
145 changes: 106 additions & 39 deletions graphs/src/anemoi/graphs/edges/builders/cutoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from __future__ import annotations

import logging
import warnings
from importlib.util import find_spec

import numpy as np
Expand All @@ -19,10 +18,10 @@
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage
from torch_geometric.nn import radius

from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.edges.builders.base import BaseEdgeBuilder
from anemoi.graphs.edges.builders.masking import NodeMaskingMixin
from anemoi.graphs.edges.builders.base import BaseDistanceEdgeBuilders
from anemoi.graphs.utils import get_grid_reference_distance

TORCH_CLUSTER_AVAILABLE = find_spec("torch_cluster") is not None
Expand All @@ -31,9 +30,11 @@
LOGGER = logging.getLogger(__name__)


class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin):
class CutOffEdges(BaseDistanceEdgeBuilders):
"""Computes cut-off based edges and adds them to the graph.

It uses as reference the target nodes.

Attributes
----------
source_name : str
Expand Down Expand Up @@ -76,7 +77,32 @@ def __init__(
self.cutoff_factor = cutoff_factor
self.max_num_neighbours = max_num_neighbours

def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None = None) -> float:
@staticmethod
def get_reference_distance(nodes: NodeStorage, mask_attr_name: torch.Tensor | None = None) -> float:
"""Compute the reference distance.

Parameters
----------
nodes : NodeStorage
The nodes.
mask_attr_name : str
The mask attribute name.

Returns
-------
float
The nodes reference distance.
"""
if mask_attr_name is not None:
# If masking nodes, we have to recompute the grid reference distance only over the masked nodes
mask = nodes[mask_attr_name]
_grid_reference_distance = get_grid_reference_distance(nodes.x, mask)
else:
_grid_reference_distance = nodes["_grid_reference_distance"]

return _grid_reference_distance

def get_cutoff_radius(self, graph: HeteroData):
"""Compute the cut-off radius.

The cut-off radius is computed as the product of the target nodes
Expand All @@ -86,38 +112,27 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: torch.Tensor | None =
----------
graph : HeteroData
The graph.
mask_attr : torch.Tensor
The mask attribute.

Returns
-------
float
The cut-off radius.
"""
target_nodes = graph[self.target_name]
if mask_attr is not None:
# If masking target nodes, we have to recompute the grid reference distance only over the masked nodes
mask = target_nodes[mask_attr]
target_grid_reference_distance = get_grid_reference_distance(target_nodes.x, mask)
else:
target_grid_reference_distance = target_nodes["_grid_reference_distance"]

radius = target_grid_reference_distance * self.cutoff_factor
return radius
reference_dist = CutOffEdges.get_reference_distance(
graph[self.target_name], mask_attr_name=self.target_mask_attr_name
)
return reference_dist * self.cutoff_factor

def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
"""Prepare node information and get source and target nodes."""
self.radius = self.get_cutoff_radius(graph)
return super().prepare_node_data(graph)

def _compute_edge_index_pyg(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> torch.Tensor:
from torch_cluster.radius import radius

source_coords, target_coords = self.get_cartesian_node_coordinates(source_nodes, target_nodes)

def _compute_edge_index_pyg(self, source_coords: torch.Tensor, target_coords: torch.Tensor) -> torch.Tensor:
edge_index = radius(source_coords, target_coords, r=self.radius, max_num_neighbors=self.max_num_neighbours)
edge_index = torch.flip(edge_index, [0])

return torch.flip(edge_index, [0])
return edge_index

def _crop_to_max_num_neighbours(self, adjmat):
"""Remove neighbors exceeding the maximum allowed limit."""
Expand All @@ -144,19 +159,15 @@ def _crop_to_max_num_neighbours(self, adjmat):
# Define the new sparse matrix
return coo_matrix((adjmat.data[mask], (adjmat.row[mask], adjmat.col[mask])), shape=adjmat.shape)

def _compute_edge_index_sklearn(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> torch.Tensor:
source_coords, target_coords = self.get_cartesian_node_coordinates(source_nodes, target_nodes)
def _compute_adj_matrix_sklearn(self, source_coords: torch.Tensor, target_coords: torch.Tensor) -> torch.Tensor:
nearest_neighbour = NearestNeighbors(metric="euclidean", n_jobs=4)
nearest_neighbour.fit(source_coords.cpu())
adj_matrix = nearest_neighbour.radius_neighbors_graph(
target_coords.cpu(), radius=self.radius, mode="distance"
).tocoo()

adj_matrix = self._crop_to_max_num_neighbours(adj_matrix)
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)
edge_index = torch.from_numpy(np.stack([adj_matrix.col, adj_matrix.row], axis=0))

return edge_index
return adj_matrix

def compute_edge_index(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> torch.Tensor:
"""Get the adjacency matrix for the cut-off method.
Expand All @@ -179,16 +190,72 @@ def compute_edge_index(self, source_nodes: NodeStorage, target_nodes: NodeStorag
self.source_name,
self.target_name,
)
return super().compute_edge_index(source_nodes=source_nodes, target_nodes=target_nodes)

if TORCH_CLUSTER_AVAILABLE:
edge_index = self._compute_edge_index_pyg(source_nodes, target_nodes)
else:
warnings.warn(
"The 'torch-cluster' library is not installed. Installing 'torch-cluster' can significantly improve "
"performance for graph creation. You can install it using 'pip install torch-cluster'.",
UserWarning,
)

edge_index = self._compute_edge_index_sklearn(source_nodes, target_nodes)
class ReversedCutOffEdges(CutOffEdges):
"""Computes cut-off based edges and adds them to the graph.

return edge_index
It uses as reference the source nodes.

Attributes
----------
source_name : str
The name of the source nodes.
target_name : str
The name of the target nodes.
cutoff_factor : float
Factor to multiply the grid reference distance to get the cut-off radius.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.
max_num_neighbours : int
The maximum number of nearest neighbours to consider when building edges.

Methods
-------
register_edges(graph)
Register the edges in the graph.
register_attributes(graph, config)
Register attributes in the edges of the graph.
update_graph(graph, attrs_config)
Update the graph with the edges.
"""

def get_cartesian_node_coordinates(
self, source_nodes: NodeStorage, target_nodes: NodeStorage
) -> tuple[torch.Tensor, torch.Tensor]:
source_coords, target_coords = super().get_cartesian_node_coordinates(source_nodes, target_nodes)
return target_coords, source_coords

def get_cutoff_radius(self, graph: HeteroData):
"""Compute the cut-off radius.

The cut-off radius is computed as the product of the target nodes
reference distance and the cut-off factor.

Parameters
----------
graph : HeteroData
The graph.

Returns
-------
float
The cut-off radius.
"""
reference_dist = CutOffEdges.get_reference_distance(
graph[self.source_name], mask_attr_name=self.source_mask_attr_name
)
return reference_dist * self.cutoff_factor

def undo_masking_adj_matrix(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage):
adj_matrix = adj_matrix.T
return super().undo_masking_adj_matrix(adj_matrix, source_nodes, target_nodes)

def undo_masking_edge_index(
self, edge_index: torch.Tensor, source_nodes: NodeStorage, target_nodes: NodeStorage
) -> torch.Tensor:
edge_index = torch.flip(edge_index, [0])
return super().undo_masking_edge_index(edge_index, source_nodes, target_nodes)
Loading
Loading