models/src/anemoi/models/layers/graph_providers.py (460 additions, 0 deletions)

Large diffs are not rendered by default.

models/src/anemoi/models/layers/mapper.py (258 additions, 307 deletions)

Large diffs are not rendered by default.

models/src/anemoi/models/layers/processor.py (22 additions, 54 deletions)
@@ -16,7 +16,7 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.distributed_c10d import ProcessGroup
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData
from torch_geometric.typing import Adj

from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.khop_edges import sort_edges_1hop_sharding
@@ -26,8 +26,6 @@
from anemoi.models.layers.block import GraphTransformerProcessorBlock
from anemoi.models.layers.block import PointWiseMLPProcessorBlock
from anemoi.models.layers.block import TransformerProcessorBlock
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.layers.mapper import GraphEdgeMixin
from anemoi.models.layers.utils import load_layer_kernels
from anemoi.utils.config import DotDict

@@ -260,6 +258,8 @@ def forward(
x: Tensor,
batch_size: int,
shard_shapes: list[list[int]],
edge_attr: Optional[Tensor] = None,
edge_index: Optional[Adj] = None,
model_comm_group: Optional[ProcessGroup] = None,
*args,
**kwargs,
@@ -275,7 +275,7 @@
return x


class GNNProcessor(GraphEdgeMixin, BaseProcessor):
class GNNProcessor(BaseProcessor):
"""GNN Processor."""

def __init__(
@@ -285,11 +285,7 @@ def __init__(
num_layers: int,
num_chunks: int,
mlp_extra_layers: int,
trainable_size: int,
src_grid_size: int,
dst_grid_size: int,
sub_graph: HeteroData,
sub_graph_edge_attributes: list[str],
edge_dim: int,
cpu_offload: bool = False,
layer_kernels: DotDict,
**kwargs,
@@ -304,18 +300,10 @@
Number of channels
num_chunks: int
Number of chunks in processor
mlp_extra_layers : int, optional
mlp_extra_layers : int
Number of extra layers in MLP
trainable_size : int
Size of trainable tensor
src_grid_size : int
Source grid size
dst_grid_size : int
Destination grid size
sub_graph : HeteroData
Graph for sub graph in GNN
sub_graph_edge_attributes : list[str]
Sub graph edge attributes
edge_dim : int
Edge feature dimension
cpu_offload : bool
Whether to offload processing to CPU, by default False
layer_kernels : DotDict
@@ -332,11 +320,7 @@ def __init__(
layer_kernels=layer_kernels,
)

self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)

self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0])

kwargs = {
kwargs_build = {
"mlp_extra_layers": mlp_extra_layers,
"layer_kernels": self.layer_factory,
"edge_dim": None,
@@ -347,15 +331,15 @@ def __init__(
in_channels=num_channels,
out_channels=num_channels,
num_chunks=1,
**kwargs,
**kwargs_build,
)

kwargs["edge_dim"] = self.edge_dim # Edge dim for first layer
kwargs_build["edge_dim"] = edge_dim # Edge dim for first layer
self.proc[0] = GraphConvProcessorBlock(
in_channels=num_channels,
out_channels=num_channels,
num_chunks=1,
**kwargs,
**kwargs_build,
)

self.offload_layers(cpu_offload)
@@ -365,13 +349,14 @@ def forward(
x: Tensor,
batch_size: int,
shard_shapes: list[list[int]],
edge_attr: Tensor,
edge_index: Adj,
model_comm_group: Optional[ProcessGroup] = None,
*args,
**kwargs,
) -> Tensor:
shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
edge_attr = self.trainable(self.edge_attr, batch_size)
edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size)

target_nodes = sum(x[0] for x in shape_nodes)
edge_attr, edge_index, shapes_edge_attr, shapes_edge_idx = sort_edges_1hop_sharding(
target_nodes,
@@ -389,7 +374,7 @@ def forward(
return x


class GraphTransformerProcessor(GraphEdgeMixin, BaseProcessor):
class GraphTransformerProcessor(BaseProcessor):
"""Processor."""

def __init__(
@@ -400,11 +385,7 @@
num_chunks: int,
num_heads: int,
mlp_hidden_ratio: int,
trainable_size: int,
src_grid_size: int,
dst_grid_size: int,
sub_graph: HeteroData,
sub_graph_edge_attributes: list[str],
edge_dim: int,
qk_norm: bool = False,
cpu_offload: bool = False,
layer_kernels: DotDict,
@@ -424,16 +405,8 @@
Number of heads in transformer
mlp_hidden_ratio: int
Ratio of mlp hidden dimension to embedding dimension
trainable_size : int
Size of trainable tensor
src_grid_size : int
Source grid size
dst_grid_size : int
Destination grid size
sub_graph : HeteroData
Graph for sub graph in GNN
sub_graph_edge_attributes : list[str]
Sub graph edge attributes
edge_dim : int
Edge feature dimension
qk_norm: bool, optional
Normalize query and key, by default False
cpu_offload : bool, optional
@@ -452,19 +425,15 @@
layer_kernels=layer_kernels,
)

self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)

self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0])

self.build_layers(
GraphTransformerProcessorBlock,
in_channels=num_channels,
hidden_dim=(mlp_hidden_ratio * num_channels),
out_channels=num_channels,
num_heads=num_heads,
edge_dim=self.edge_dim,
layer_kernels=self.layer_factory,
qk_norm=qk_norm,
edge_dim=edge_dim,
)

self.offload_layers(cpu_offload)
@@ -474,16 +443,15 @@ def forward(
x: Tensor,
batch_size: int,
shard_shapes: list[list[int]],
edge_attr: Tensor,
edge_index: Adj,
model_comm_group: Optional[ProcessGroup] = None,
*args,
**kwargs,
) -> Tensor:
size = sum(x[0] for x in shard_shapes)

shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
edge_attr = self.trainable(self.edge_attr, batch_size)

edge_index = self._expand_edges(self.edge_index_base, self.edge_inc, batch_size)

shapes_edge_attr = get_shard_shapes(edge_attr, 0, model_comm_group)
edge_attr = shard_tensor(edge_attr, 0, shapes_edge_attr, model_comm_group)
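The processor changes above drop GraphEdgeMixin and TrainableTensor: edge_attr and edge_index are no longer registered and expanded inside each processor, but are passed into forward() by the caller. The sketch below illustrates that dependency-injection pattern in miniature. ToyGraphProvider and ToyProcessor are hypothetical stand-ins written for this note only, not the anemoi classes, and the message-passing step is a deliberately simplified placeholder.

```python
import torch
import torch.nn as nn


class ToyGraphProvider:
    """Hypothetical stand-in for a graph provider: it owns the edge tensors."""

    def __init__(self, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> None:
        self.edge_index = edge_index
        self.edge_attr = edge_attr

    def get_edges(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        return self.edge_attr.to(device), self.edge_index.to(device)


class ToyProcessor(nn.Module):
    """Hypothetical stand-in for a processor after this PR: it only consumes edges."""

    def __init__(self, num_channels: int, edge_dim: int) -> None:
        super().__init__()
        self.message_mlp = nn.Linear(num_channels + edge_dim, num_channels)

    def forward(self, x: torch.Tensor, edge_attr: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # edge-conditioned messages from source nodes, aggregated onto target nodes
        messages = self.message_mlp(torch.cat([x[src], edge_attr], dim=-1))
        aggregated = torch.zeros_like(x)
        aggregated.index_add_(0, dst, messages)
        return x + aggregated


num_nodes, num_edges, num_channels, edge_dim = 8, 32, 16, 4
provider = ToyGraphProvider(
    edge_index=torch.randint(0, num_nodes, (2, num_edges)),
    edge_attr=torch.randn(num_edges, edge_dim),
)
processor = ToyProcessor(num_channels, edge_dim)

x = torch.randn(num_nodes, num_channels)
edge_attr, edge_index = provider.get_edges(device=x.device)
out = processor(x, edge_attr=edge_attr, edge_index=edge_index)  # the caller injects the edges
```

The processor keeps no graph state of its own, so the same module can be driven by any provider that hands back (edge_attr, edge_index).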
models/src/anemoi/models/layers/residual.py (8 additions, 7 deletions)
@@ -20,7 +20,8 @@
from anemoi.models.distributed.graph import gather_channels
from anemoi.models.distributed.graph import shard_channels
from anemoi.models.distributed.shapes import apply_shard_shapes
from anemoi.models.layers.sparse_projector import build_sparse_projector
from anemoi.models.layers.graph_providers import ProjectionGraphProvider
from anemoi.models.layers.sparse_projector import SparseProjector


class BaseResidualConnection(nn.Module, ABC):
@@ -134,24 +135,24 @@ def __init__(
edge_weight_attribute,
)

self.project_down = build_sparse_projector(
self.provider_down = ProjectionGraphProvider(
graph=graph,
edges_name=down_edges,
edge_weight_attribute=edge_weight_attribute,
src_node_weight_attribute=src_node_weight_attribute,
file_path=truncation_down_file_path,
autocast=autocast,
)

self.project_up = build_sparse_projector(
self.provider_up = ProjectionGraphProvider(
graph=graph,
edges_name=up_edges,
edge_weight_attribute=edge_weight_attribute,
src_node_weight_attribute=src_node_weight_attribute,
file_path=truncation_up_file_path,
autocast=autocast,
)

self.projector = SparseProjector(autocast=autocast)

def _get_edges_name(
self,
graph,
@@ -187,8 +188,8 @@ def forward(self, x: torch.Tensor, grid_shard_shapes=None, model_comm_group=None

x = einops.rearrange(x, "batch ensemble grid features -> (batch ensemble) grid features")
x = self._to_channel_shards(x, shard_shapes, model_comm_group)
x = self.project_down(x)
x = self.project_up(x)
x = self.projector(x, self.provider_down.get_edges(device=x.device))
x = self.projector(x, self.provider_up.get_edges(device=x.device))
x = self._to_grid_shards(x, shard_shapes, model_comm_group)
x = einops.rearrange(x, "(batch ensemble) grid features -> batch ensemble grid features", batch=batch_size)
