diff --git a/models/src/anemoi/models/layers/mapper/__init__.py b/models/src/anemoi/models/layers/mapper/__init__.py new file mode 100644 index 000000000..f39966cff --- /dev/null +++ b/models/src/anemoi/models/layers/mapper/__init__.py @@ -0,0 +1,24 @@ +# (C) Copyright 2024- ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +from anemoi.models.layers.mapper.dynamic import DynamicGraphTransformerBackwardMapper +from anemoi.models.layers.mapper.dynamic import DynamicGraphTransformerForwardMapper +from anemoi.models.layers.mapper.static import GNNBackwardMapper +from anemoi.models.layers.mapper.static import GNNForwardMapper +from anemoi.models.layers.mapper.static import GraphTransformerBackwardMapper +from anemoi.models.layers.mapper.static import GraphTransformerForwardMapper + +__all__ = [ + "DynamicGraphTransformerBackwardMapper", + "DynamicGraphTransformerForwardMapper", + "GraphTransformerBackwardMapper", + "GraphTransformerForwardMapper", + "GNNBackwardMapper", + "GNNForwardMapper", +] diff --git a/models/src/anemoi/models/layers/mapper/base.py b/models/src/anemoi/models/layers/mapper/base.py new file mode 100644 index 000000000..0870eb694 --- /dev/null +++ b/models/src/anemoi/models/layers/mapper/base.py @@ -0,0 +1,122 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+
+
+import logging
+import os
+from abc import ABC
+from typing import Optional
+
+from torch import Tensor
+from torch import nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
+
+from anemoi.models.distributed.graph import gather_tensor
+from anemoi.models.distributed.graph import shard_tensor
+from anemoi.models.distributed.shapes import change_channels_in_shape
+from anemoi.models.layers.utils import load_layer_kernels
+from anemoi.utils.config import DotDict
+
+LOGGER = logging.getLogger(__name__)
+
+# Number of chunks used in inference (https://github.com/ecmwf/anemoi-core/pull/406)
+NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))
+NUM_CHUNKS_INFERENCE_MAPPER = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS_MAPPER", NUM_CHUNKS_INFERENCE))
+
+
+class BaseMapper(nn.Module, ABC):
+    """Base Mapper from source dimension to destination dimension."""
+
+    def __init__(
+        self,
+        *,
+        in_channels_src: int,
+        in_channels_dst: int,
+        hidden_dim: int,
+        out_channels_dst: Optional[int] = None,
+        cpu_offload: bool = False,
+        layer_kernels: DotDict,
+        **kwargs,
+    ) -> None:
+        """Initialize BaseMapper."""
+        super().__init__()
+
+        self.in_channels_src = in_channels_src
+        self.in_channels_dst = in_channels_dst
+        self.hidden_dim = hidden_dim
+        self.out_channels_dst = out_channels_dst
+        self.layer_factory = load_layer_kernels(layer_kernels)
+        self.activation = self.layer_factory.Activation()
+
+        self.proc = NotImplemented
+
+        self.offload_layers(cpu_offload)
+
+    def offload_layers(self, cpu_offload):
+        if cpu_offload:
+            self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc])
+
+    def pre_process(
+        self, x, shard_shapes, model_comm_group=None, x_src_is_sharded=False, x_dst_is_sharded=False
+    ) -> tuple[Tensor, Tensor, tuple[int], tuple[int]]:
+        """Pre-processing for the Mappers.
+
+        Splits the tuples into src and dst nodes and shapes as the base operation.
+
+        Parameters
+        ----------
+        x : Tuple[Tensor]
+            Data containing source and destination nodes and edges.
+        shard_shapes : Tuple[Tuple[int], Tuple[int]]
+            Shapes of the sharded source and destination nodes.
+        model_comm_group : ProcessGroup
+            Groups in which GPUs work together on one model instance
+
+        Returns
+        -------
+        Tuple[Tensor, Tensor, Tuple[int], Tuple[int]]
+            Source nodes, destination nodes, sharded source node shapes, sharded destination node shapes
+        """
+        shapes_src, shapes_dst = shard_shapes
+        x_src, x_dst = x
+        return x_src, x_dst, shapes_src, shapes_dst
+
+    def post_process(self, x_dst, shapes_dst, model_comm_group=None, keep_x_dst_sharded=False) -> Tensor:
+        """Post-processing for the mapper."""
+        return x_dst
+
+
+class BackwardMapperPostProcessMixin:
+    """Post-processing for Backward Mapper from hidden -> data."""
+
+    def post_process(self, x_dst, shapes_dst, model_comm_group=None, keep_x_dst_sharded=False):
+        x_dst = self.node_data_extractor(x_dst)
+        if not keep_x_dst_sharded:
+            x_dst = gather_tensor(
+                x_dst, 0, change_channels_in_shape(shapes_dst, self.out_channels_dst), model_comm_group
+            )
+        return x_dst
+
+
+class ForwardMapperPreProcessMixin:
+    """Pre-processing for Forward Mapper from data -> hidden."""
+
+    def pre_process(self, x, shard_shapes, model_comm_group=None, x_src_is_sharded=False, x_dst_is_sharded=False):
+        x_src, x_dst, shapes_src, shapes_dst = super().pre_process(
+            x, shard_shapes, model_comm_group, x_src_is_sharded, x_dst_is_sharded
+        )
+        if not x_src_is_sharded:
+            x_src = shard_tensor(x_src, 0, shapes_src, model_comm_group)
+        if not x_dst_is_sharded:
+            x_dst = shard_tensor(x_dst, 0, shapes_dst, model_comm_group)
+        x_src = self.emb_nodes_src(x_src)
+        x_dst = self.emb_nodes_dst(x_dst)
+        shapes_src = change_channels_in_shape(shapes_src, self.hidden_dim)
+        shapes_dst = change_channels_in_shape(shapes_dst, self.hidden_dim)
+        return x_src, x_dst, shapes_src, shapes_dst
diff --git a/models/src/anemoi/models/layers/mapper/dynamic.py b/models/src/anemoi/models/layers/mapper/dynamic.py
new file mode 100644
index 000000000..5d2e56558
--- /dev/null
+++ b/models/src/anemoi/models/layers/mapper/dynamic.py
@@ -0,0 +1,284 @@
+# (C) Copyright 2024- ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+# + +import logging +from typing import Optional + +import torch +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup +from torch_geometric.data import HeteroData +from torch_geometric.typing import PairTensor + +from anemoi.models.distributed.graph import shard_tensor +from anemoi.models.distributed.shapes import change_channels_in_shape +from anemoi.models.distributed.shapes import get_shard_shapes +from anemoi.models.layers.block import GraphTransformerMapperBlock +from anemoi.models.layers.mapper.base import BackwardMapperPostProcessMixin +from anemoi.models.layers.mapper.base import BaseMapper +from anemoi.models.layers.mapper.base import ForwardMapperPreProcessMixin +from anemoi.utils.config import DotDict + +LOGGER = logging.getLogger(__name__) + + +class DynamicGraphTransformerBaseMapper(BaseMapper): + """Dynamic Graph Transformer Base Mapper from hidden -> data or data -> hidden.""" + + def __init__( + self, + in_channels_src: int = 0, + in_channels_dst: int = 0, + hidden_dim: int = 128, + out_channels_dst: Optional[int] = None, + subgraph_edge_attributes: Optional[list] = [], + subgraph_edge_index_name: str = "edge_index", + layer_kernels: DotDict = None, + num_chunks: int = 1, + cpu_offload: bool = False, + num_heads: int = 16, + mlp_hidden_ratio: int = 4, + edge_dim: int = 0, + ) -> None: + """Initialize DynamicGraphTransformerBaseMapper. + + Parameters + ---------- + in_channels_src : int + Input channels of the source node + in_channels_dst : int + Input channels of the destination node + hidden_dim : int + Hidden dimension + num_heads: int + Number of heads to use, default 16 + mlp_hidden_ratio: int + ratio of mlp hidden dimension to embedding dimension, default 4 + subgraph_edge_attributes: list[str] + Names of edge attributes to consider + subgraph_edge_index_name: str + Name of the edge index attribute in the graph. Defaults to "edge_index". 
+ cpu_offload : bool, optional + Whether to offload processing to CPU, by default False + out_channels_dst : Optional[int], optional + Output channels of the destination node, by default None + edge_dim : int, optional + The dimension of the edge attributes + """ + super().__init__( + in_channels_src=in_channels_src, + in_channels_dst=in_channels_dst, + hidden_dim=hidden_dim, + out_channels_dst=out_channels_dst, + num_chunks=num_chunks, + cpu_offload=cpu_offload, + layer_kernels=layer_kernels, + ) + self.edge_attribute_names = subgraph_edge_attributes + self.edge_index_name = subgraph_edge_index_name + + self.proc = GraphTransformerMapperBlock( + in_channels=hidden_dim, + hidden_dim=mlp_hidden_ratio * hidden_dim, + out_channels=hidden_dim, + num_heads=num_heads, + edge_dim=edge_dim, + num_chunks=num_chunks, + layer_kernels=self.layer_factory, + ) + + self.offload_layers(cpu_offload) + + self.emb_nodes_dst = ( + nn.Linear(self.in_channels_dst, self.hidden_dim) + if self.in_channels_dst != self.hidden_dim + else nn.Identity() + ) + + def forward( + self, + x: PairTensor, + subgraph: HeteroData, + batch_size: int, + shard_shapes: tuple[tuple[int], tuple[int]], + model_comm_group: Optional[ProcessGroup] = None, + ) -> PairTensor: + size = (sum(x[0] for x in shard_shapes[0]), sum(x[0] for x in shard_shapes[1])) + edge_index = subgraph[self.edge_index_name].to(torch.int64) + edge_attr = torch.cat([subgraph[attr] for attr in self.edge_attribute_names], axis=1) + + shapes_edge_attr = get_shard_shapes(edge_attr, 0, model_comm_group) + edge_attr = shard_tensor(edge_attr, 0, shapes_edge_attr, model_comm_group) + + x_src, x_dst, shapes_src, shapes_dst = self.pre_process(x, shard_shapes, model_comm_group) + + (x_src, x_dst), edge_attr = self.proc( + (x_src, x_dst), + edge_attr, + edge_index, + (shapes_src, shapes_dst, shapes_edge_attr), + batch_size=batch_size, + model_comm_group=model_comm_group, + size=size, + ) + + x_dst = self.post_process(x_dst, shapes_dst, model_comm_group) + + return x_dst + + +class DynamicGraphTransformerForwardMapper(ForwardMapperPreProcessMixin, DynamicGraphTransformerBaseMapper): + """Dynamic Graph Transformer Mapper from data -> hidden.""" + + def __init__( + self, + in_channels_src: int = 0, + in_channels_dst: int = 0, + hidden_dim: int = 128, + out_channels_dst: Optional[int] = None, + subgraph_edge_attributes: Optional[list] = [], + subgraph_edge_index_name: str = "edge_index", + layer_kernels: DotDict = None, + num_chunks: int = 1, + cpu_offload: bool = False, + num_heads: int = 16, + mlp_hidden_ratio: int = 4, + edge_dim: int = 0, + **kwargs, + ) -> None: + """Initialize DynamicGraphTransformerForwardMapper. + + Parameters + ---------- + in_channels_src : int + Input channels of the source node + in_channels_dst : int + Input channels of the destination node + hidden_dim : int + Hidden dimension + num_heads: int + Number of heads to use, default 16 + mlp_hidden_ratio: int + ratio of mlp hidden dimension to embedding dimension, default 4 + subgraph_edge_attributes: list[str] + Names of edge attributes to consider + subgraph_edge_index_name: str + Name of the edge index attribute in the graph. Defaults to "edge_index". 
+ cpu_offload : bool, optional + Whether to offload processing to CPU, by default False + out_channels_dst : Optional[int], optional + Output channels of the destination node, by default None + edge_dim: int, optional + Dimension of the edge attributes + """ + super().__init__( + in_channels_src, + in_channels_dst, + hidden_dim, + out_channels_dst=out_channels_dst, + subgraph_edge_attributes=subgraph_edge_attributes, + subgraph_edge_index_name=subgraph_edge_index_name, + layer_kernels=layer_kernels, + num_chunks=num_chunks, + cpu_offload=cpu_offload, + num_heads=num_heads, + mlp_hidden_ratio=mlp_hidden_ratio, + edge_dim=edge_dim, + ) + + self.emb_nodes_src = ( + nn.Linear(self.in_channels_src, self.hidden_dim) + if self.in_channels_src != self.hidden_dim + else nn.Identity() + ) + + def forward( + self, + x: PairTensor, + subgraph: HeteroData, + batch_size: int, + shard_shapes: tuple[tuple[int], tuple[int]], + model_comm_group: Optional[ProcessGroup] = None, + ) -> PairTensor: + x_dst = super().forward( + x, subgraph, batch_size=batch_size, shard_shapes=shard_shapes, model_comm_group=model_comm_group + ) + return x[0], x_dst + + +class DynamicGraphTransformerBackwardMapper(BackwardMapperPostProcessMixin, DynamicGraphTransformerBaseMapper): + """Dynamic Graph Transformer Mapper from hidden -> data.""" + + def __init__( + self, + in_channels_src: int = 0, + in_channels_dst: int = 0, + hidden_dim: int = 128, + out_channels_dst: Optional[int] = None, + subgraph_edge_attributes: Optional[list] = [], + subgraph_edge_index_name: str = "edge_index", + layer_kernels: DotDict = None, + num_chunks: int = 1, + cpu_offload: bool = False, + num_heads: int = 16, + mlp_hidden_ratio: int = 4, + edge_dim: int = 0, + **kwargs, + ) -> None: + """Initialize DynamicGraphTransformerBackwardMapper. + + Parameters + ---------- + in_channels_src : int + Input channels of the source node + in_channels_dst : int + Input channels of the destination node + hidden_dim : int + Hidden dimension + num_heads: int + Number of heads to use, default 16 + mlp_hidden_ratio: int + ratio of mlp hidden dimension to embedding dimension, default 4 + subgraph_edge_attributes: list[str] + Names of edge attributes to consider + subgraph_edge_index_name: str + Name of the edge index attribute in the graph. Defaults to "edge_index". 
+ cpu_offload : bool, optional + Whether to offload processing to CPU, by default False + out_channels_dst : Optional[int], optional + Output channels of the destination node, by default None + edge_dim: int, optional + Dimension of the edge attributes + """ + super().__init__( + in_channels_src, + in_channels_dst, + hidden_dim, + out_channels_dst=out_channels_dst, + subgraph_edge_attributes=subgraph_edge_attributes, + subgraph_edge_index_name=subgraph_edge_index_name, + layer_kernels=layer_kernels, + num_chunks=num_chunks, + cpu_offload=cpu_offload, + num_heads=num_heads, + mlp_hidden_ratio=mlp_hidden_ratio, + edge_dim=edge_dim, + ) + + self.node_data_extractor = nn.Sequential( + nn.LayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst) + ) + + def pre_process(self, x, shard_shapes, model_comm_group=None): + x_src, x_dst, shapes_src, shapes_dst = super().pre_process(x, shard_shapes, model_comm_group) + shapes_src = change_channels_in_shape(shapes_src, self.hidden_dim) + x_dst = shard_tensor(x_dst, 0, shapes_dst, model_comm_group) + x_dst = self.emb_nodes_dst(x_dst) + shapes_dst = change_channels_in_shape(shapes_dst, self.hidden_dim) + return x_src, x_dst, shapes_src, shapes_dst diff --git a/models/src/anemoi/models/layers/mapper.py b/models/src/anemoi/models/layers/mapper/static.py similarity index 92% rename from models/src/anemoi/models/layers/mapper.py rename to models/src/anemoi/models/layers/mapper/static.py index cb848c394..743b1a9c9 100644 --- a/models/src/anemoi/models/layers/mapper.py +++ b/models/src/anemoi/models/layers/mapper/static.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024- Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
@@ -10,14 +10,12 @@ import logging import os -from abc import ABC from typing import Optional import numpy as np import torch from torch import Tensor from torch import nn -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper from torch.distributed.distributed_c10d import ProcessGroup from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData @@ -36,109 +34,20 @@ from anemoi.models.layers.block import GraphTransformerMapperBlock from anemoi.models.layers.block import TransformerMapperBlock from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.mapper.base import BackwardMapperPostProcessMixin +from anemoi.models.layers.mapper.base import BaseMapper +from anemoi.models.layers.mapper.base import ForwardMapperPreProcessMixin from anemoi.models.layers.mlp import MLP -from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict LOGGER = logging.getLogger(__name__) + # Number of chunks used in inference (https://github.com/ecmwf/anemoi-core/pull/406) NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1")) NUM_CHUNKS_INFERENCE_MAPPER = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS_MAPPER", NUM_CHUNKS_INFERENCE)) -class BaseMapper(nn.Module, ABC): - """Base Mapper from souce dimension to destination dimension.""" - - def __init__( - self, - *, - in_channels_src: int, - in_channels_dst: int, - hidden_dim: int, - out_channels_dst: Optional[int] = None, - cpu_offload: bool = False, - layer_kernels: DotDict, - **kwargs, - ) -> None: - """Initialize BaseMapper.""" - super().__init__() - - self.in_channels_src = in_channels_src - self.in_channels_dst = in_channels_dst - self.hidden_dim = hidden_dim - self.out_channels_dst = out_channels_dst - self.layer_factory = load_layer_kernels(layer_kernels) - self.activation = self.layer_factory.Activation() - - self.proc = NotImplemented - - self.offload_layers(cpu_offload) - - def offload_layers(self, cpu_offload): - if cpu_offload: - self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc]) - - def pre_process( - self, x, shard_shapes, model_comm_group=None, x_src_is_sharded=False, x_dst_is_sharded=False - ) -> tuple[Tensor, Tensor, tuple[int], tuple[int]]: - """Pre-processing for the Mappers. - - Splits the tuples into src and dst nodes and shapes as the base operation. - - Parameters - ---------- - x : Tuple[Tensor] - Data containing source and destination nodes and edges. - shard_shapes : Tuple[Tuple[int], Tuple[int]] - Shapes of the sharded source and destination nodes. 
- model_comm_group : ProcessGroup - Groups which GPUs work together on one model instance - - Return - ------ - Tuple[Tensor, Tensor, Tuple[int], Tuple[int]] - Source nodes, destination nodes, sharded source node shapes, sharded destination node shapes - """ - shapes_src, shapes_dst = shard_shapes - x_src, x_dst = x - return x_src, x_dst, shapes_src, shapes_dst - - def post_process(self, x_dst, shapes_dst, model_comm_group=None, keep_x_dst_sharded=False) -> Tensor: - """Post-processing for the mapper.""" - return x_dst - - -class BackwardMapperPostProcessMixin: - """Post-processing for Backward Mapper from hidden -> data.""" - - def post_process(self, x_dst, shapes_dst, model_comm_group=None, keep_x_dst_sharded=False): - x_dst = self.node_data_extractor(x_dst) - if not keep_x_dst_sharded: - x_dst = gather_tensor( - x_dst, 0, change_channels_in_shape(shapes_dst, self.out_channels_dst), model_comm_group - ) - return x_dst - - -class ForwardMapperPreProcessMixin: - """Pre-processing for Forward Mapper from data -> hidden.""" - - def pre_process(self, x, shard_shapes, model_comm_group=None, x_src_is_sharded=False, x_dst_is_sharded=False): - x_src, x_dst, shapes_src, shapes_dst = super().pre_process( - x, shard_shapes, model_comm_group, x_src_is_sharded, x_dst_is_sharded - ) - if not x_src_is_sharded: - x_src = shard_tensor(x_src, 0, shapes_src, model_comm_group) - if not x_dst_is_sharded: - x_dst = shard_tensor(x_dst, 0, shapes_dst, model_comm_group) - x_src = self.emb_nodes_src(x_src) - x_dst = self.emb_nodes_dst(x_dst) - shapes_src = change_channels_in_shape(shapes_src, self.hidden_dim) - shapes_dst = change_channels_in_shape(shapes_dst, self.hidden_dim) - return x_src, x_dst, shapes_src, shapes_dst - - class GraphEdgeMixin: def _register_edges( self, sub_graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int diff --git a/models/src/anemoi/models/layers/processor/__init__.py b/models/src/anemoi/models/layers/processor/__init__.py new file mode 100644 index 000000000..2076bc0d6 --- /dev/null +++ b/models/src/anemoi/models/layers/processor/__init__.py @@ -0,0 +1,16 @@ +# (C) Copyright 2024- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from anemoi.models.layers.processor.dynamic import DynamicGraphTransformerProcessor +from anemoi.models.layers.processor.static import GNNProcessor +from anemoi.models.layers.processor.static import GraphTransformerProcessor +from anemoi.models.layers.processor.static import TransformerProcessor + +__all__ = ["TransformerProcessor", "GNNProcessor", "GraphTransformerProcessor", "DynamicGraphTransformerProcessor"] diff --git a/models/src/anemoi/models/layers/processor/base.py b/models/src/anemoi/models/layers/processor/base.py new file mode 100644 index 000000000..450f4e4b9 --- /dev/null +++ b/models/src/anemoi/models/layers/processor/base.py @@ -0,0 +1,74 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from abc import ABC + +from torch import Tensor +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper +from torch.utils.checkpoint import checkpoint + +from anemoi.models.layers.utils import load_layer_kernels +from anemoi.utils.config import DotDict + + +class BaseProcessor(nn.Module, ABC): + """Base Processor.""" + + def __init__( + self, + *, + num_layers: int, + num_channels: int, + num_chunks: int, + cpu_offload: bool = False, + layer_kernels: DotDict, + **kwargs, + ) -> None: + """Initialize BaseProcessor.""" + super().__init__() + + # Each Processor divides the layers into chunks that get assigned to each ProcessorChunk + self.num_chunks = num_chunks + self.num_channels = num_channels + self.chunk_size = num_layers // num_chunks + + self.layer_factory = load_layer_kernels(layer_kernels) + + assert ( + num_layers % num_chunks == 0 + ), f"Number of processor layers ({num_layers}) has to be divisible by the number of processor chunks ({num_chunks})." + + def offload_layers(self, cpu_offload): + if cpu_offload: + self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc]) + + def build_layers(self, processor_chunk_class, *args, **kwargs) -> None: + """Build Layers.""" + self.proc = nn.ModuleList( + [ + processor_chunk_class( + *args, + **kwargs, + ) + for _ in range(self.num_chunks) + ], + ) + + def run_layers(self, data: tuple, *args, **kwargs) -> Tensor: + """Run Layers with checkpoint.""" + for layer in self.proc: + data = checkpoint(layer, *data, *args, **kwargs, use_reentrant=False) + return data + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + """Example forward pass.""" + x = self.run_layers((x,), *args, **kwargs) + return x diff --git a/models/src/anemoi/models/layers/processor/dynamic.py b/models/src/anemoi/models/layers/processor/dynamic.py new file mode 100644 index 000000000..c14341276 --- /dev/null +++ b/models/src/anemoi/models/layers/processor/dynamic.py @@ -0,0 +1,114 @@ +# (C) Copyright 2024- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch.distributed.distributed_c10d import ProcessGroup
+from torch_geometric.data import HeteroData
+
+from anemoi.models.distributed.graph import shard_tensor
+from anemoi.models.distributed.shapes import change_channels_in_shape
+from anemoi.models.distributed.shapes import get_shard_shapes
+from anemoi.models.layers.chunk import GraphTransformerProcessorChunk
+from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.layers.processor.base import BaseProcessor
+from anemoi.utils.config import DotDict
+
+
+class DynamicGraphTransformerProcessor(BaseProcessor):
+    """Dynamic Graph Transformer Processor."""
+
+    def __init__(
+        self,
+        num_layers: int,
+        layer_kernels: DotDict,
+        trainable_size: int = 8,
+        num_channels: int = 128,
+        num_chunks: int = 2,
+        num_heads: int = 16,
+        mlp_hidden_ratio: int = 4,
+        cpu_offload: bool = False,
+        subgraph_edge_index_name: str = "edge_index",
+        subgraph_edge_attributes: Optional[list] = [],
+        edge_dim: int = 0,
+        **kwargs,
+    ) -> None:
+        """Initialize DynamicGraphTransformerProcessor.
+
+        Parameters
+        ----------
+        num_layers : int
+            Number of layers
+        num_channels : int
+            Number of channels
+        num_chunks : int, optional
+            Number of chunks, by default 2
+        num_heads: int
+            Number of heads to use, default 16
+        mlp_hidden_ratio: int
+            Ratio of MLP hidden dimension to embedding dimension, default 4
+        cpu_offload : bool, optional
+            Whether to offload processing to CPU, by default False
+        """
+        super().__init__(
+            num_channels=num_channels,
+            num_layers=num_layers,
+            num_chunks=num_chunks,
+            cpu_offload=cpu_offload,
+            num_heads=num_heads,
+            mlp_hidden_ratio=mlp_hidden_ratio,
+            layer_kernels=layer_kernels,
+        )
+        self.edge_dim = edge_dim
+        self.edge_attribute_names = subgraph_edge_attributes
+        self.edge_index_name = subgraph_edge_index_name
+
+        self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_dim)
+
+        self.build_layers(
+            GraphTransformerProcessorChunk,
+            num_channels=num_channels,
+            num_layers=self.chunk_size,
+            layer_kernels=self.layer_factory,
+            num_heads=num_heads,
+            mlp_hidden_ratio=mlp_hidden_ratio,
+            edge_dim=self.edge_dim,
+        )
+
+        self.offload_layers(cpu_offload)
+
+    def forward(
+        self,
+        x: Tensor,
+        subgraph: HeteroData,
+        batch_size: int,
+        shard_shapes: tuple[tuple[int], tuple[int]],
+        model_comm_group: Optional[ProcessGroup] = None,
+        *args,
+        **kwargs,
+    ) -> Tensor:
+        shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
+        edge_index = subgraph[self.edge_index_name].to(torch.int64)
+        edge_attr = torch.cat([subgraph[attr] for attr in self.edge_attribute_names], axis=1)
+
+        shapes_edge_attr = get_shard_shapes(edge_attr, 0, model_comm_group)
+        edge_attr = shard_tensor(edge_attr, 0, shapes_edge_attr, model_comm_group)
+
+        x, edge_attr = self.run_layers(
+            (x, edge_attr),
+            edge_index,
+            (shape_nodes, shape_nodes, shapes_edge_attr),
+            batch_size,
+            model_comm_group,
+        )
+
+        return x
diff --git a/models/src/anemoi/models/layers/processor.py b/models/src/anemoi/models/layers/processor/static.py
similarity index 85%
rename from models/src/anemoi/models/layers/processor.py
rename to models/src/anemoi/models/layers/processor/static.py
index 5575a3923..754bc7826 100644
--- a/models/src/anemoi/models/layers/processor.py
+++ b/models/src/anemoi/models/layers/processor/static.py
@@ -8,14 +8,10 @@
 # nor does it submit to any jurisdiction.
-from abc import ABC from typing import Optional from torch import Tensor -from torch import nn -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper from torch.distributed.distributed_c10d import ProcessGroup -from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData from anemoi.models.distributed.graph import shard_tensor @@ -26,66 +22,11 @@ from anemoi.models.layers.chunk import GraphTransformerProcessorChunk from anemoi.models.layers.chunk import TransformerProcessorChunk from anemoi.models.layers.graph import TrainableTensor -from anemoi.models.layers.mapper import GraphEdgeMixin -from anemoi.models.layers.utils import load_layer_kernels +from anemoi.models.layers.mapper.static import GraphEdgeMixin +from anemoi.models.layers.processor.base import BaseProcessor from anemoi.utils.config import DotDict -class BaseProcessor(nn.Module, ABC): - """Base Processor.""" - - def __init__( - self, - *, - num_layers: int, - num_channels: int, - num_chunks: int, - cpu_offload: bool = False, - layer_kernels: DotDict, - **kwargs, - ) -> None: - """Initialize BaseProcessor.""" - super().__init__() - - # Each Processor divides the layers into chunks that get assigned to each ProcessorChunk - self.num_chunks = num_chunks - self.num_channels = num_channels - self.chunk_size = num_layers // num_chunks - - self.layer_factory = load_layer_kernels(layer_kernels) - - assert ( - num_layers % num_chunks == 0 - ), f"Number of processor layers ({num_layers}) has to be divisible by the number of processor chunks ({num_chunks})." - - def offload_layers(self, cpu_offload): - if cpu_offload: - self.proc = nn.ModuleList([offload_wrapper(x) for x in self.proc]) - - def build_layers(self, processor_chunk_class, *args, **kwargs) -> None: - """Build Layers.""" - self.proc = nn.ModuleList( - [ - processor_chunk_class( - *args, - **kwargs, - ) - for _ in range(self.num_chunks) - ], - ) - - def run_layers(self, data: tuple, *args, **kwargs) -> Tensor: - """Run Layers with checkpoint.""" - for layer in self.proc: - data = checkpoint(layer, *data, *args, **kwargs, use_reentrant=False) - return data - - def forward(self, x: Tensor, *args, **kwargs) -> Tensor: - """Example forward pass.""" - x = self.run_layers((x,), *args, **kwargs) - return x - - class TransformerProcessor(BaseProcessor): """Transformer Processor.""" diff --git a/models/tests/layers/mapper/test_base_mapper.py b/models/tests/layers/mapper/test_base_mapper.py index 7faf80dc8..10315a3aa 100644 --- a/models/tests/layers/mapper/test_base_mapper.py +++ b/models/tests/layers/mapper/test_base_mapper.py @@ -16,7 +16,7 @@ from torch import nn from torch_geometric.data import HeteroData -from anemoi.models.layers.mapper import BaseMapper +from anemoi.models.layers.mapper.base import BaseMapper from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict diff --git a/models/tests/layers/mapper/test_graphconv_mapper.py b/models/tests/layers/mapper/test_graphconv_mapper.py index 5cb8973c4..ab7409025 100644 --- a/models/tests/layers/mapper/test_graphconv_mapper.py +++ b/models/tests/layers/mapper/test_graphconv_mapper.py @@ -17,8 +17,8 @@ from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import GNNBackwardMapper -from anemoi.models.layers.mapper import GNNBaseMapper from anemoi.models.layers.mapper import GNNForwardMapper +from anemoi.models.layers.mapper.static import GNNBaseMapper from anemoi.models.layers.utils import load_layer_kernels from 
anemoi.utils.config import DotDict diff --git a/models/tests/layers/mapper/test_graphtransformer_mapper.py b/models/tests/layers/mapper/test_graphtransformer_mapper.py index 46bd56a07..28585ed14 100644 --- a/models/tests/layers/mapper/test_graphtransformer_mapper.py +++ b/models/tests/layers/mapper/test_graphtransformer_mapper.py @@ -17,8 +17,8 @@ from torch_geometric.data import HeteroData from anemoi.models.layers.mapper import GraphTransformerBackwardMapper -from anemoi.models.layers.mapper import GraphTransformerBaseMapper from anemoi.models.layers.mapper import GraphTransformerForwardMapper +from anemoi.models.layers.mapper.static import GraphTransformerBaseMapper from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict diff --git a/models/tests/layers/mapper/test_transformer_mapper.py b/models/tests/layers/mapper/test_transformer_mapper.py index 152e18a0c..b482387e0 100644 --- a/models/tests/layers/mapper/test_transformer_mapper.py +++ b/models/tests/layers/mapper/test_transformer_mapper.py @@ -13,7 +13,7 @@ from omegaconf import OmegaConf from torch_geometric.data import HeteroData -from anemoi.models.layers.mapper import TransformerBaseMapper +from anemoi.models.layers.mapper.static import TransformerBaseMapper from anemoi.models.layers.utils import load_layer_kernels diff --git a/models/tests/layers/processor/test_base_processor.py b/models/tests/layers/processor/test_base_processor.py index b8de559a0..2d0cedfe6 100644 --- a/models/tests/layers/processor/test_base_processor.py +++ b/models/tests/layers/processor/test_base_processor.py @@ -13,7 +13,7 @@ import pytest -from anemoi.models.layers.processor import BaseProcessor +from anemoi.models.layers.processor.base import BaseProcessor from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict diff --git a/models/tests/layers/processor/test_graphconv_processor.py b/models/tests/layers/processor/test_graphconv_processor.py index 39e51c7ba..056aa0f80 100644 --- a/models/tests/layers/processor/test_graphconv_processor.py +++ b/models/tests/layers/processor/test_graphconv_processor.py @@ -16,7 +16,7 @@ from torch_geometric.data import HeteroData from anemoi.models.layers.graph import TrainableTensor -from anemoi.models.layers.processor import GNNProcessor +from anemoi.models.layers.processor.static import GNNProcessor from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict diff --git a/models/tests/layers/processor/test_graphtransformer_processor.py b/models/tests/layers/processor/test_graphtransformer_processor.py index eb2adab33..651dadd2e 100644 --- a/models/tests/layers/processor/test_graphtransformer_processor.py +++ b/models/tests/layers/processor/test_graphtransformer_processor.py @@ -16,7 +16,7 @@ from torch_geometric.data import HeteroData from anemoi.models.layers.graph import TrainableTensor -from anemoi.models.layers.processor import GraphTransformerProcessor +from anemoi.models.layers.processor.static import GraphTransformerProcessor from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict diff --git a/models/tests/layers/processor/test_transformer_processor.py b/models/tests/layers/processor/test_transformer_processor.py index 1294f402d..10477e1e2 100644 --- a/models/tests/layers/processor/test_transformer_processor.py +++ b/models/tests/layers/processor/test_transformer_processor.py @@ -14,7 +14,7 @@ import pytest import torch -from 
anemoi.models.layers.processor import TransformerProcessor +from anemoi.models.layers.processor.static import TransformerProcessor from anemoi.models.layers.utils import load_layer_kernels from anemoi.utils.config import DotDict diff --git a/training/src/anemoi/training/diagnostics/callbacks/plot.py b/training/src/anemoi/training/diagnostics/callbacks/plot.py index 384f30fa7..6fa474d04 100644 --- a/training/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/training/src/anemoi/training/diagnostics/callbacks/plot.py @@ -34,7 +34,7 @@ from pytorch_lightning.utilities import rank_zero_only from anemoi.models.layers.graph import NamedNodesAttributes -from anemoi.models.layers.mapper import GraphEdgeMixin +from anemoi.models.layers.mapper.static import GraphEdgeMixin from anemoi.training.diagnostics.plots import argsort_variablename_variablelevel from anemoi.training.diagnostics.plots import get_scatter_frame from anemoi.training.diagnostics.plots import init_plot_settings
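
Note (not part of the patch): the refactor keeps the public import surface stable while moving internals into explicit submodules. The sketch below is illustrative only and assumes nothing beyond the modules touched in this diff; the top-level `anemoi.models.layers.mapper` and `anemoi.models.layers.processor` packages re-export the concrete classes, while base classes, mixins and the new dynamic variants are imported from their submodules.

# Illustrative import sketch (assumes only the modules shown in this diff), not part of the patch.
# Public entry points, re-exported by the new package __init__ modules:
from anemoi.models.layers.mapper import DynamicGraphTransformerForwardMapper
from anemoi.models.layers.mapper import GraphTransformerForwardMapper
from anemoi.models.layers.processor import DynamicGraphTransformerProcessor
from anemoi.models.layers.processor import GraphTransformerProcessor

# Internal building blocks now live in explicit submodules:
from anemoi.models.layers.mapper.base import BaseMapper
from anemoi.models.layers.mapper.static import GraphEdgeMixin
from anemoi.models.layers.processor.base import BaseProcessor
from anemoi.models.layers.processor.static import GNNProcessor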