20 changes: 11 additions & 9 deletions models/src/anemoi/models/layers/attention.py
@@ -135,8 +135,8 @@ def __init__(
self.projection = linear(embed_dim, embed_dim, bias=True)

if self.qk_norm:
self.q_norm = layer_kernels["QueryNorm"](self.head_dim)
self.k_norm = layer_kernels["KeyNorm"](self.head_dim)
self.q_norm = layer_kernels.QueryNorm(self.head_dim)
self.k_norm = layer_kernels.KeyNorm(self.head_dim)

def set_attention_function(self):
attn_funcs = {
@@ -432,7 +432,9 @@ def __init__(self):
try:
from flash_attn_interface import flash_attn_func
except ImportError:
raise ImportError("Error: Flash-attn v3 not installed. Please build flash-attn/hopper from source to use flash-attn v3")
raise ImportError(
"Error: Flash-attn v3 not installed. Please build flash-attn/hopper from source to use flash-attn v3"
)

self.attention = flash_attn_func

@@ -452,12 +454,12 @@ def forward(
einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
out = self.attention(
query,
key,
value,
causal=False,
window_size=(window_size, window_size),
)[0]
query,
key,
value,
causal=False,
window_size=(window_size, window_size),
)[0]
out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
return out

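The change above switches the per-head q/k normalisation from dict-style lookup (layer_kernels["QueryNorm"]) to attribute access (layer_kernels.QueryNorm). A minimal sketch of how these factories are used, assuming the container returned by load_layer_kernels supports attribute access (DotDict-like) and returns partially-instantiated factories, as the attention code implies:

import torch

from anemoi.models.layers.utils import load_layer_kernels

layer_kernels = load_layer_kernels()        # default kernels, instantiated via hydra
head_dim = 64                               # illustrative value, not from the PR

q_norm = layer_kernels.QueryNorm(head_dim)  # same factory as layer_kernels["QueryNorm"](head_dim)
k_norm = layer_kernels.KeyNorm(head_dim)

query = torch.randn(2, 8, 1024, head_dim)   # (batch, heads, grid, vars), shapes assumed
query = q_norm(query)                       # normalise each head over head_dim before attention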
4 changes: 2 additions & 2 deletions models/src/anemoi/models/layers/mapper.py
@@ -704,7 +704,7 @@ def __init__(
)

self.node_data_extractor = nn.Sequential(
nn.LayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
self.layer_factory.MapperLayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
)
if initialise_data_extractor_zero:
for module in self.node_data_extractor.modules():
@@ -1389,7 +1389,7 @@ def __init__(
)

self.node_data_extractor = nn.Sequential(
nn.LayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
self.layer_factory.MapperLayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
)

def pre_process(self, x, shard_shapes, model_comm_group=None, x_src_is_sharded=False, x_dst_is_sharded=False):
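Both mapper changes replace a hard-coded nn.LayerNorm with the configurable MapperLayerNorm kernel from the layer factory. A minimal sketch of the resulting extractor, assuming self.layer_factory is the object returned by load_layer_kernels:

import torch
from torch import nn

from anemoi.models.layers.utils import load_layer_kernels

layer_factory = load_layer_kernels()       # MapperLayerNorm defaults to torch.nn.LayerNorm
hidden_dim, out_channels_dst = 1024, 80    # illustrative sizes

node_data_extractor = nn.Sequential(
    layer_factory.MapperLayerNorm(hidden_dim),
    nn.Linear(hidden_dim, out_channels_dst),
)
y = node_data_extractor(torch.randn(4, hidden_dim))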
16 changes: 16 additions & 0 deletions models/src/anemoi/models/layers/normalization.py
@@ -31,6 +31,21 @@ def forward(self, x: Tensor) -> Tensor:
return super().forward(x).type_as(x)


class AutocastLayerNormCompile(nn.Module):
"""Compiling nn.LayerNorm / AutocastLayerNorm as standalone module
using AMP usually means using the default eager execution.
Compiling this module keeps the autocat functionality of AutocastLayerNorm while
also allowing compilation - then fusing both AMP casts and the Triton LayerNorm.
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__()
self.norm = nn.LayerNorm(*args, **kwargs)

def forward(self, x: Tensor) -> Tensor:
return self.norm(x).type_as(x)


class ConditionalLayerNorm(nn.Module):
"""Conditional Layer Normalization.

@@ -51,6 +66,7 @@ def __init__(
self.bias = nn.Linear(condition_shape, normalized_shape) # , bias=False)
self.autocast = autocast

# shouldn't the weight here be ones_ and not zeros_?
if w_one_bias_zero_init:
nn.init.zeros_(self.scale.weight)
nn.init.zeros_(self.scale.bias)
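A usage sketch of the new AutocastLayerNormCompile (assumes a CUDA device; sizes and dtype are illustrative). Compiling the wrapper lets the AMP casts be fused with the LayerNorm kernel instead of leaving a bare nn.LayerNorm in eager mode:

import torch

from anemoi.models.layers.normalization import AutocastLayerNormCompile

norm = torch.compile(AutocastLayerNormCompile(1024))  # wraps nn.LayerNorm(1024)

with torch.autocast("cuda", dtype=torch.bfloat16):
    x = torch.randn(8, 1024, device="cuda")
    y = norm(x)  # .type_as(x) keeps the output in x's dtype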
18 changes: 9 additions & 9 deletions models/src/anemoi/models/layers/utils.py
@@ -11,13 +11,11 @@
import logging
from typing import Optional

import torch
from hydra.errors import InstantiationException
from hydra.utils import instantiate
import torch
from torch import nn
from torch import cuda
from torch.utils.checkpoint import checkpoint
from contextlib import contextmanager

from anemoi.utils.config import DotDict

@@ -58,6 +56,7 @@ def load_layer_kernels(kernel_config: Optional[DotDict] = None, instance: bool =
default_kernels = {
"Linear": {"_target_": "torch.nn.Linear"},
"LayerNorm": {"_target_": "torch.nn.LayerNorm"},
"MapperLayerNorm": {"_target_": "torch.nn.LayerNorm"},
"Activation": {"_target_": "torch.nn.GELU"},
"QueryNorm": {
"_target_": "anemoi.models.layers.normalization.AutocastLayerNorm",
@@ -92,20 +91,21 @@ def load_layer_kernels(kernel_config: Optional[DotDict] = None, instance: bool =
layer_kernels[name] = kernel_entry
return layer_kernels
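The new MapperLayerNorm entry defaults to torch.nn.LayerNorm, so existing mappers keep their behaviour, but it can now be overridden independently of the generic LayerNorm kernel. A hedged sketch of such an override, assuming user entries are merged with the defaults above and returned as partially-instantiated factories:

from anemoi.utils.config import DotDict

from anemoi.models.layers.utils import load_layer_kernels

# Point only the mapper normalisation at the compilable wrapper; other kernels keep their defaults.
kernel_config = DotDict(
    {
        "MapperLayerNorm": {
            "_target_": "anemoi.models.layers.normalization.AutocastLayerNormCompile",
        },
    },
)
layer_kernels = load_layer_kernels(kernel_config)
norm = layer_kernels.MapperLayerNorm(1024)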


class ProfilerWrapper(nn.Module):
"""Wrapper for checkpointing a module."""

def __init__(self, module: nn.Module, marker: str) -> None:
super().__init__()
self.module = module
self.marker=marker
self.enabled=True
self.marker = marker
self.enabled = True

def forward(self, *args, **kwargs):
#print(f"{args=}, {kwargs=}")
#tracing_marker=marker.split('- ')[1].split(', input')[0]
with torch.autograd.profiler.record_function("anemoi-"+self.marker):
# print(f"{args=}, {kwargs=}")
# tracing_marker=marker.split('- ')[1].split(', input')[0]
with torch.autograd.profiler.record_function("anemoi-" + self.marker):
out = self.module(*args, **kwargs)
return out

return grad_output # Return unchanged gradients
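ProfilerWrapper usage, as a sketch: wrapping any submodule makes its forward pass show up as an "anemoi-<marker>" range in torch profiler traces (the marker name below is illustrative).

import torch
from torch import nn

from anemoi.models.layers.utils import ProfilerWrapper

block = ProfilerWrapper(nn.Linear(128, 128), marker="mapper-linear")

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
    block(torch.randn(4, 128))

# the wrapped call appears under the "anemoi-mapper-linear" record_function range
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))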
@@ -28,8 +28,8 @@
from anemoi.models.distributed.shapes import apply_shard_shapes
from anemoi.models.distributed.shapes import get_shard_shapes
from anemoi.models.layers.graph import NamedNodesAttributes
from anemoi.models.layers.utils import ProfilerWrapper
from anemoi.models.layers.mapper import GraphTransformerBaseMapper
from anemoi.models.layers.utils import ProfilerWrapper
from anemoi.utils.config import DotDict

LOGGER = logging.getLogger(__name__)
@@ -90,6 +90,7 @@ trainable_parameters:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv

attributes:
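Each model config in this PR adds AutocastLayerNormCompile to the compile: list alongside GraphTransformerConv. How the list is consumed is anemoi's own machinery; purely as an illustration (function name and looping logic assumed, not taken from this PR), selective compilation of the listed module classes could look like:

import importlib

import torch
from torch import nn

def compile_listed_modules(model: nn.Module, compile_cfg: list[dict], options: dict | None = None) -> None:
    """Hypothetical helper: torch.compile every submodule whose class is listed in the config."""
    targets = set()
    for entry in compile_cfg:
        module_path, _, class_name = entry["module"].rpartition(".")
        targets.add(getattr(importlib.import_module(module_path), class_name))

    for _, submodule in model.named_modules():
        if isinstance(submodule, tuple(targets)):
            # nn.Module.compile (torch >= 2.2) compiles the forward in place, keeping the module tree intact
            submodule.compile(**(options or {}))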
@@ -125,6 +125,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
#options: # An example of setting torch.compile options
#dynamic: false
@@ -125,6 +125,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
#options: # An example of setting torch.compile options
#dynamic: false
@@ -122,6 +122,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
- module: anemoi.models.layers.normalization.ConditionalLayerNorm

1 change: 1 addition & 0 deletions training/src/anemoi/training/config/model/transformer.yaml
@@ -99,6 +99,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
#options: # An example of setting torch.compile options
#dynamic: false
@@ -127,6 +127,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
#options: # An example of setting torch.compile options
#dynamic: false
@@ -127,6 +127,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
#options: # An example of setting torch.compile options
#dynamic: false
@@ -124,6 +124,7 @@ attributes:
# non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile
compile:
- module: anemoi.models.layers.normalization.AutocastLayerNormCompile
- module: anemoi.models.layers.conv.GraphTransformerConv
#options:
#dynamic: false
13 changes: 7 additions & 6 deletions training/src/anemoi/training/data/datamodule/singledatamodule.py
@@ -174,14 +174,15 @@ def timeincrement(self) -> int:

@cached_property
def ds_train(self) -> NativeGridDataset:
dataloader_config=self.config.dataloader.training
#remove start and end date to work with cloned dataset
dataloader_config = self.config.dataloader.training
# remove start and end date to work with cloned dataset
import os

if os.getenv("CLONED_DATASET", "0") == "1":
LOGGER.info("changing dataloader config to work with cloned datasets")
dataloader_config.pop("start")
dataloader_config.pop("end")
dataloader_config.pop("frequency")
LOGGER.info("changing dataloader config to work with cloned datasets")
dataloader_config.pop("start")
dataloader_config.pop("end")
dataloader_config.pop("frequency")
return self._get_dataset(
open_dataset(dataloader_config),
label="train",
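The CLONED_DATASET toggle above only strips start/end/frequency when the environment variable is set, so a normal run keeps the configured time window. Purely illustrative sketch (config values are made up):

import os

dataloader_config = {"dataset": "some-dataset.zarr", "start": 2019, "end": 2021, "frequency": "6h"}

if os.getenv("CLONED_DATASET", "0") == "1":
    # cloned datasets are assumed to carry their own time axis, so the window keys are dropped
    for key in ("start", "end", "frequency"):
        dataloader_config.pop(key, None)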
7 changes: 3 additions & 4 deletions training/src/anemoi/training/data/dataset/singledataset.py
@@ -22,7 +22,6 @@
from anemoi.training.data.grid_indices import BaseGridIndices
from anemoi.training.utils.seeding import get_base_seed
from anemoi.training.utils.usable_indices import get_usable_indices
import os

LOGGER = logging.getLogger(__name__)

@@ -206,11 +205,11 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
shard_start = self.sample_comm_group_id * shard_size
shard_end = (self.sample_comm_group_id + 1) * shard_size

if (os.getenv("DONT_SPLIT_DDP", "0") == "1"):
if os.getenv("DONT_SPLIT_DDP", "0") == "1":
LOGGER.info("Not splitting dataset across model instances")
shard_size = len(self.valid_date_indices)
shard_size = len(self.valid_date_indices)
shard_start = 0
shard_end = shard_size
shard_end = shard_size

shard_len = shard_end - shard_start
self.n_samples_per_worker = shard_len // n_workers
@@ -9,27 +9,28 @@

import contextlib
import sys
import torch

import torch
from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor

with contextlib.suppress(ImportError):
import pynvml
with contextlib.suppress(ImportError):
from pyrsmi import rocml


def parse_memory_stats(device=0):
stats=torch.cuda.memory_stats(device=device)
#need to handle empty dict before memory has been allocated
stats = torch.cuda.memory_stats(device=device)
# need to handle empty dict before memory has been allocated
try:
active_mem=stats['active_bytes.all.current']
active_mem = stats["active_bytes.all.current"]
except KeyError:
active_mem=0
active_mem = 0
try:
allocated_mem=stats['allocated_bytes.all.current']
allocated_mem = stats["allocated_bytes.all.current"]
except KeyError:
allocated_mem=0
return active_mem,allocated_mem
allocated_mem = 0
return active_mem, allocated_mem


class GreenGPUMonitor(BaseMetricsMonitor):
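A usage sketch for parse_memory_stats (assumes a CUDA build of torch; before the first allocation torch.cuda.memory_stats() is empty, so both values fall back to 0). The import path is a guess, since the file name is not shown in this diff; the helper itself is defined above:

import torch

from anemoi.training.diagnostics.mlflow.system_metrics import parse_memory_stats  # module path assumed

active, allocated = parse_memory_stats(device=0)   # both 0 in a fresh process
x = torch.empty(1024, 1024, device="cuda")         # allocate ~4 MiB
active, allocated = parse_memory_stats(device=0)
print(f"active={active / 2**20:.1f} MiB, allocated={allocated / 2**20:.1f} MiB")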
28 changes: 14 additions & 14 deletions training/src/anemoi/training/diagnostics/profilers.py
@@ -352,8 +352,12 @@ def handler_fn(prof: pl.profilers.Profiler) -> None:

return handler_fn

global_rank = rank_zero_only.rank #int(os.environ.get("SLURM_PROCID", "0")) # WON'T WORK WHEN RUNNING WITHOUT SLURM
if (self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only and global_rank == 0) or (not self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only):
global_rank = (
rank_zero_only.rank
) # int(os.environ.get("SLURM_PROCID", "0")) # WON'T WORK WHEN RUNNING WITHOUT SLURM
if (self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only and global_rank == 0) or (
not self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only
):
from pytorch_lightning.profilers.pytorch import _KINETO_AVAILABLE

assert (
@@ -365,7 +369,7 @@ def handler_fn(prof: pl.profilers.Profiler) -> None:
)
self.memory_profiler = PyTorchProfiler(
activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU],
with_stack=False, #breaks profiling on aarch64 systems
with_stack=False, # breaks profiling on aarch64 systems
emit_nvtx=False,
export_to_chrome=True,
record_shapes=False,
@@ -579,22 +583,18 @@ def _save_extra_plots(self) -> None:
self.memory_timeline_fname = str(Path(self.dirpath, "memory_timelines.html"))
self.memory_profiler.profiler.export_memory_timeline(self.memory_timeline_fname)

def convert_to_microseconds(self,time_str):
"""
Convert time strings with units (us, ms, s) into microseconds (µs).
"""
def convert_to_microseconds(self, time_str):
"""Convert time strings with units (us, ms, s) into microseconds (µs)."""
if pd.isna(time_str):
return 0
time_str = str(time_str).strip()
if time_str.endswith("us"):
return float(time_str[:-2])
elif time_str.endswith("ms"):
if time_str.endswith("ms"):
return float(time_str[:-2]) * 1000
elif time_str.endswith("s"):
if time_str.endswith("s"):
return float(time_str[:-1]) * 1_000_000
else:
return 0 # Default/fallback

return 0 # Default/fallback
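# Quick sanity check of the unit handling above (illustrative, not part of the diff):
#   convert_to_microseconds("250us") -> 250.0
#   convert_to_microseconds("1.5ms") -> 1500.0
#   convert_to_microseconds("2s")    -> 2_000_000.0
#   convert_to_microseconds(None)    -> 0   (pd.isna catches missing values)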

@rank_zero_only
def get_memory_profiler_df(self) -> pd.DataFrame:
@@ -647,11 +647,11 @@ def get_memory_profiler_df(self) -> pd.DataFrame:
]
table_main_body = "\n".join(table_main_body)
memory_df = pd.read_fwf(StringIO(table_main_body), names=columns, skiprows=2)
#memory_df = memory_df[memory_df["Name"].str.startswith("anemoi-")] #filter to just anemoi markers from nvtx_wrapper
# memory_df = memory_df[memory_df["Name"].str.startswith("anemoi-")] #filter to just anemoi markers from nvtx_wrapper
# currently sorted by 'CUDA total', instead sort by 'Self CUDA'
memory_df["Self CUDA µs"] = memory_df["Self CUDA"].apply(self.convert_to_microseconds)
memory_df = memory_df.sort_values(by="Self CUDA µs", ascending=False)
#memory_df = memory_df.drop(["Self CUDA µs"])
# memory_df = memory_df.drop(["Self CUDA µs"])
flag = ["--" not in row for row in memory_df["Name"]]
memory_df = memory_df[flag]
time_rows = [row for row in table.split("\n")[-3:] if row != ""]
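The hunk above derives a numeric "Self CUDA µs" column from the human-readable "Self CUDA" strings and sorts on it, so the heaviest kernels come first. A standalone pandas sketch of that pattern (the data and the trimmed helper are illustrative):

import pandas as pd

def to_us(time_str: str) -> float:
    """Mirror of convert_to_microseconds above, trimmed to the happy path."""
    if time_str.endswith("us"):
        return float(time_str[:-2])
    if time_str.endswith("ms"):
        return float(time_str[:-2]) * 1000
    return float(time_str[:-1]) * 1_000_000

memory_df = pd.DataFrame({"Name": ["anemoi-mapper", "anemoi-attention"], "Self CUDA": ["1.5ms", "250us"]})
memory_df["Self CUDA µs"] = memory_df["Self CUDA"].apply(to_us)
memory_df = memory_df.sort_values(by="Self CUDA µs", ascending=False)  # heaviest kernels first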