diff --git a/models/src/anemoi/models/layers/attention.py b/models/src/anemoi/models/layers/attention.py
index e7e285242..2f0e55efb 100644
--- a/models/src/anemoi/models/layers/attention.py
+++ b/models/src/anemoi/models/layers/attention.py
@@ -135,8 +135,8 @@ def __init__(
         self.projection = linear(embed_dim, embed_dim, bias=True)
 
         if self.qk_norm:
-            self.q_norm = layer_kernels["QueryNorm"](self.head_dim)
-            self.k_norm = layer_kernels["KeyNorm"](self.head_dim)
+            self.q_norm = layer_kernels.QueryNorm(self.head_dim)
+            self.k_norm = layer_kernels.KeyNorm(self.head_dim)
 
     def set_attention_function(self):
         attn_funcs = {
@@ -432,7 +432,9 @@ def __init__(self):
         try:
             from flash_attn_interface import flash_attn_func
         except ImportError:
-            raise ImportError("Error: Flash-attn v3 not installed. Please build flash-attn/hopper from source to use flash-attn v3")
+            raise ImportError(
+                "Error: Flash-attn v3 not installed. Please build flash-attn/hopper from source to use flash-attn v3"
+            )
 
         self.attention = flash_attn_func
 
@@ -452,12 +454,12 @@ def forward(
             einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
         )
         out = self.attention(
-            query,
-            key,
-            value,
-            causal=False,
-            window_size=(window_size, window_size),
-        )[0]
+            query,
+            key,
+            value,
+            causal=False,
+            window_size=(window_size, window_size),
+        )[0]
 
         out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
         return out
diff --git a/models/src/anemoi/models/layers/mapper.py b/models/src/anemoi/models/layers/mapper.py
index e07b586da..97410211d 100644
--- a/models/src/anemoi/models/layers/mapper.py
+++ b/models/src/anemoi/models/layers/mapper.py
@@ -704,7 +704,7 @@ def __init__(
         )
 
         self.node_data_extractor = nn.Sequential(
-            nn.LayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
+            self.layer_factory.MapperLayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
         )
         if initialise_data_extractor_zero:
             for module in self.node_data_extractor.modules():
@@ -1389,7 +1389,7 @@ def __init__(
         )
 
         self.node_data_extractor = nn.Sequential(
-            nn.LayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
+            self.layer_factory.MapperLayerNorm(self.hidden_dim), nn.Linear(self.hidden_dim, self.out_channels_dst)
         )
 
     def pre_process(self, x, shard_shapes, model_comm_group=None, x_src_is_sharded=False, x_dst_is_sharded=False):
diff --git a/models/src/anemoi/models/layers/normalization.py b/models/src/anemoi/models/layers/normalization.py
index b188d35db..29d8cd840 100644
--- a/models/src/anemoi/models/layers/normalization.py
+++ b/models/src/anemoi/models/layers/normalization.py
@@ -31,6 +31,21 @@ def forward(self, x: Tensor) -> Tensor:
         return super().forward(x).type_as(x)
 
 
+class AutocastLayerNormCompile(nn.Module):
+    """Wrap nn.LayerNorm so it can be compiled as a standalone module under AMP.
+    Compiling nn.LayerNorm / AutocastLayerNorm directly usually falls back to eager execution.
+    This wrapper keeps the autocast behaviour of AutocastLayerNorm while still allowing
+    compilation, so both the AMP casts and the Triton LayerNorm kernel can be fused.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.norm(x).type_as(x)
+
+
 class ConditionalLayerNorm(nn.Module):
     """Conditional Layer Normalization.
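Illustrative usage of the new wrapper (editor's sketch, not part of the patch; the tensor shapes and the direct torch.compile call are assumptions, whereas in the training configs further down the module is selected through the compile: list):

import torch
from anemoi.models.layers.normalization import AutocastLayerNormCompile

# Compile the wrapper as a standalone module so Inductor can fuse the AMP cast
# in forward() with the generated (Triton) LayerNorm kernel.
norm = torch.compile(AutocastLayerNormCompile(1024).cuda())

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    x = torch.randn(4, 40320, 1024, device="cuda", dtype=torch.bfloat16)
    y = norm(x)  # LayerNorm runs in float32 under autocast; type_as casts back to bf16
    assert y.dtype == x.dtype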
@@ -51,6 +66,7 @@ def __init__(
         self.bias = nn.Linear(condition_shape, normalized_shape)  # , bias=False)
         self.autocast = autocast
 
+        # shouldn't weight here be ones_ and not zeros_?
         if w_one_bias_zero_init:
             nn.init.zeros_(self.scale.weight)
             nn.init.zeros_(self.scale.bias)
diff --git a/models/src/anemoi/models/layers/utils.py b/models/src/anemoi/models/layers/utils.py
index 417967e22..f4390b122 100644
--- a/models/src/anemoi/models/layers/utils.py
+++ b/models/src/anemoi/models/layers/utils.py
@@ -11,13 +11,11 @@
 import logging
 from typing import Optional
 
+import torch
 from hydra.errors import InstantiationException
 from hydra.utils import instantiate
-import torch
 from torch import nn
-from torch import cuda
 from torch.utils.checkpoint import checkpoint
-from contextlib import contextmanager
 
 from anemoi.utils.config import DotDict
@@ -58,6 +56,7 @@ def load_layer_kernels(kernel_config: Optional[DotDict] = None, instance: bool =
     default_kernels = {
         "Linear": {"_target_": "torch.nn.Linear"},
         "LayerNorm": {"_target_": "torch.nn.LayerNorm"},
+        "MapperLayerNorm": {"_target_": "torch.nn.LayerNorm"},
         "Activation": {"_target_": "torch.nn.GELU"},
         "QueryNorm": {
             "_target_": "anemoi.models.layers.normalization.AutocastLayerNorm",
@@ -92,20 +91,21 @@ def load_layer_kernels(kernel_config: Optional[DotDict] = None, instance: bool =
             layer_kernels[name] = kernel_entry
     return layer_kernels
 
+
 class ProfilerWrapper(nn.Module):
     """Wrap a module and record a profiler range around its forward pass."""
 
     def __init__(self, module: nn.Module, marker: str) -> None:
         super().__init__()
         self.module = module
-        self.marker=marker
-        self.enabled=True
+        self.marker = marker
+        self.enabled = True
 
     def forward(self, *args, **kwargs):
-        #print(f"{args=}, {kwargs=}")
-        #tracing_marker=marker.split('- ')[1].split(', input')[0]
-        with torch.autograd.profiler.record_function("anemoi-"+self.marker):
+        # print(f"{args=}, {kwargs=}")
+        # tracing_marker=marker.split('- ')[1].split(', input')[0]
+        with torch.autograd.profiler.record_function("anemoi-" + self.marker):
             out = self.module(*args, **kwargs)
         return out
-    
+
     return grad_output  # Return unchanged gradients
diff --git a/models/src/anemoi/models/models/encoder_processor_decoder.py b/models/src/anemoi/models/models/encoder_processor_decoder.py
index e1d586818..c34a17348 100644
--- a/models/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/models/src/anemoi/models/models/encoder_processor_decoder.py
@@ -28,8 +28,8 @@
 from anemoi.models.distributed.shapes import apply_shard_shapes
 from anemoi.models.distributed.shapes import get_shard_shapes
 from anemoi.models.layers.graph import NamedNodesAttributes
-from anemoi.models.layers.utils import ProfilerWrapper
 from anemoi.models.layers.mapper import GraphTransformerBaseMapper
+from anemoi.models.layers.utils import ProfilerWrapper
 from anemoi.utils.config import DotDict
 
 LOGGER = logging.getLogger(__name__)
diff --git a/training/src/anemoi/training/config/model/graphtransformer.yaml b/training/src/anemoi/training/config/model/graphtransformer.yaml
index 7e1c8de15..a25f5918c 100644
--- a/training/src/anemoi/training/config/model/graphtransformer.yaml
+++ b/training/src/anemoi/training/config/model/graphtransformer.yaml
@@ -90,6 +90,7 @@ trainable_parameters:
 # non-compiled ("eager") execution for those lines.
# options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv attributes: diff --git a/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml b/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml index d848cb74b..71738681e 100644 --- a/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer_diffusion.yaml @@ -125,6 +125,7 @@ attributes: # non-compiled ("eager") execution for those lines. # options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv #options: # An example of setting torch.compile options #dynamic: false diff --git a/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml b/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml index e92dbaa3e..53c7d223c 100644 --- a/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer_diffusiontend.yaml @@ -125,6 +125,7 @@ attributes: # non-compiled ("eager") execution for those lines. # options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv #options: # An example of setting torch.compile options #dynamic: false diff --git a/training/src/anemoi/training/config/model/graphtransformer_ens.yaml b/training/src/anemoi/training/config/model/graphtransformer_ens.yaml index a092f47f6..6f91bf0ea 100644 --- a/training/src/anemoi/training/config/model/graphtransformer_ens.yaml +++ b/training/src/anemoi/training/config/model/graphtransformer_ens.yaml @@ -122,6 +122,7 @@ attributes: # non-compiled ("eager") execution for those lines. # options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv - module: anemoi.models.layers.normalization.ConditionalLayerNorm diff --git a/training/src/anemoi/training/config/model/transformer.yaml b/training/src/anemoi/training/config/model/transformer.yaml index 679e641b0..db109daeb 100644 --- a/training/src/anemoi/training/config/model/transformer.yaml +++ b/training/src/anemoi/training/config/model/transformer.yaml @@ -99,6 +99,7 @@ attributes: # non-compiled ("eager") execution for those lines. # options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv #options: # An example of setting torch.compile options #dynamic: false diff --git a/training/src/anemoi/training/config/model/transformer_diffusion.yaml b/training/src/anemoi/training/config/model/transformer_diffusion.yaml index d72f17fec..68509c2cb 100644 --- a/training/src/anemoi/training/config/model/transformer_diffusion.yaml +++ b/training/src/anemoi/training/config/model/transformer_diffusion.yaml @@ -127,6 +127,7 @@ attributes: # non-compiled ("eager") execution for those lines. 
# options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv #options: # An example of setting torch.compile options #dynamic: false diff --git a/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml b/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml index 247fb7252..19e1beed6 100644 --- a/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml +++ b/training/src/anemoi/training/config/model/transformer_diffusiontend.yaml @@ -127,6 +127,7 @@ attributes: # non-compiled ("eager") execution for those lines. # options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv #options: # An example of setting torch.compile options #dynamic: false diff --git a/training/src/anemoi/training/config/model/transformer_ens.yaml b/training/src/anemoi/training/config/model/transformer_ens.yaml index 3d4f8daf9..4a1d93bd7 100644 --- a/training/src/anemoi/training/config/model/transformer_ens.yaml +++ b/training/src/anemoi/training/config/model/transformer_ens.yaml @@ -124,6 +124,7 @@ attributes: # non-compiled ("eager") execution for those lines. # options (dict): a dict of further options which can be passed to torch.compile compile: + - module: anemoi.models.layers.normalization.AutocastLayerNormCompile - module: anemoi.models.layers.conv.GraphTransformerConv #options: #dynamic: false diff --git a/training/src/anemoi/training/data/datamodule/singledatamodule.py b/training/src/anemoi/training/data/datamodule/singledatamodule.py index 40ceff958..845322f2c 100644 --- a/training/src/anemoi/training/data/datamodule/singledatamodule.py +++ b/training/src/anemoi/training/data/datamodule/singledatamodule.py @@ -174,14 +174,15 @@ def timeincrement(self) -> int: @cached_property def ds_train(self) -> NativeGridDataset: - dataloader_config=self.config.dataloader.training - #remove start and end date to work with cloned dataset + dataloader_config = self.config.dataloader.training + # remove start and end date to work with cloned dataset import os + if os.getenv("CLONED_DATASET", "0") == "1": - LOGGER.info("changing dataloader config to work with cloned datasets") - dataloader_config.pop("start") - dataloader_config.pop("end") - dataloader_config.pop("frequency") + LOGGER.info("changing dataloader config to work with cloned datasets") + dataloader_config.pop("start") + dataloader_config.pop("end") + dataloader_config.pop("frequency") return self._get_dataset( open_dataset(dataloader_config), label="train", diff --git a/training/src/anemoi/training/data/dataset/singledataset.py b/training/src/anemoi/training/data/dataset/singledataset.py index 1f5b5f2f1..298bdcf19 100644 --- a/training/src/anemoi/training/data/dataset/singledataset.py +++ b/training/src/anemoi/training/data/dataset/singledataset.py @@ -22,7 +22,6 @@ from anemoi.training.data.grid_indices import BaseGridIndices from anemoi.training.utils.seeding import get_base_seed from anemoi.training.utils.usable_indices import get_usable_indices -import os LOGGER = logging.getLogger(__name__) @@ -206,11 +205,11 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None: shard_start = self.sample_comm_group_id * shard_size shard_end = (self.sample_comm_group_id + 1) * shard_size - if 
(os.getenv("DONT_SPLIT_DDP", "0") == "1"): + if os.getenv("DONT_SPLIT_DDP", "0") == "1": LOGGER.info("Not spliting dataset across model instances") - shard_size = len(self.valid_date_indices) + shard_size = len(self.valid_date_indices) shard_start = 0 - shard_end = shard_size + shard_end = shard_size shard_len = shard_end - shard_start self.n_samples_per_worker = shard_len // n_workers diff --git a/training/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py b/training/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py index 810e9cfed..4b3fa7872 100644 --- a/training/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py +++ b/training/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py @@ -9,8 +9,8 @@ import contextlib import sys -import torch +import torch from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor with contextlib.suppress(ImportError): @@ -18,18 +18,19 @@ with contextlib.suppress(ImportError): from pyrsmi import rocml + def parse_memory_stats(device=0): - stats=torch.cuda.memory_stats(device=device) - #need to handle empty dict before memory has been allocated + stats = torch.cuda.memory_stats(device=device) + # need to handle empty dict before memory has been allocated try: - active_mem=stats['active_bytes.all.current'] + active_mem = stats["active_bytes.all.current"] except KeyError: - active_mem=0 + active_mem = 0 try: - allocated_mem=stats['allocated_bytes.all.current'] + allocated_mem = stats["allocated_bytes.all.current"] except KeyError: - allocated_mem=0 - return active_mem,allocated_mem + allocated_mem = 0 + return active_mem, allocated_mem class GreenGPUMonitor(BaseMetricsMonitor): diff --git a/training/src/anemoi/training/diagnostics/profilers.py b/training/src/anemoi/training/diagnostics/profilers.py index 88c1c0065..d423fcac8 100644 --- a/training/src/anemoi/training/diagnostics/profilers.py +++ b/training/src/anemoi/training/diagnostics/profilers.py @@ -352,8 +352,12 @@ def handler_fn(prof: pl.profilers.Profiler) -> None: return handler_fn - global_rank = rank_zero_only.rank #int(os.environ.get("SLURM_PROCID", "0")) # WON'T WORK WHEN RUNNING WITHOUT SLURM - if (self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only and global_rank == 0) or (not self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only): + global_rank = ( + rank_zero_only.rank + ) # int(os.environ.get("SLURM_PROCID", "0")) # WON'T WORK WHEN RUNNING WITHOUT SLURM + if (self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only and global_rank == 0) or ( + not self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only + ): from pytorch_lightning.profilers.pytorch import _KINETO_AVAILABLE assert ( @@ -365,7 +369,7 @@ def handler_fn(prof: pl.profilers.Profiler) -> None: ) self.memory_profiler = PyTorchProfiler( activities=[torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.CPU], - with_stack=False, #breaks profiling on aarch64 systems + with_stack=False, # breaks profiling on aarch64 systems emit_nvtx=False, export_to_chrome=True, record_shapes=False, @@ -579,22 +583,18 @@ def _save_extra_plots(self) -> None: self.memory_timeline_fname = str(Path(self.dirpath, "memory_timelines.html")) self.memory_profiler.profiler.export_memory_timeline(self.memory_timeline_fname) - def convert_to_microseconds(self,time_str): - """ - Convert time strings with units (us, ms, s) into microseconds (µs). 
- """ + def convert_to_microseconds(self, time_str): + """Convert time strings with units (us, ms, s) into microseconds (µs).""" if pd.isna(time_str): return 0 time_str = str(time_str).strip() if time_str.endswith("us"): return float(time_str[:-2]) - elif time_str.endswith("ms"): + if time_str.endswith("ms"): return float(time_str[:-2]) * 1000 - elif time_str.endswith("s"): + if time_str.endswith("s"): return float(time_str[:-1]) * 1_000_000 - else: - return 0 # Default/fallback - + return 0 # Default/fallback @rank_zero_only def get_memory_profiler_df(self) -> pd.DataFrame: @@ -647,11 +647,11 @@ def get_memory_profiler_df(self) -> pd.DataFrame: ] table_main_body = "\n".join(table_main_body) memory_df = pd.read_fwf(StringIO(table_main_body), names=columns, skiprows=2) - #memory_df = memory_df[memory_df["Name"].str.startswith("anemoi-")] #filter to just anemoi markers from nvtx_wrapper + # memory_df = memory_df[memory_df["Name"].str.startswith("anemoi-")] #filter to just anemoi markers from nvtx_wrapper # currently sorted by 'CUDA total', instead sort by 'Self CUDA' memory_df["Self CUDA µs"] = memory_df["Self CUDA"].apply(self.convert_to_microseconds) memory_df = memory_df.sort_values(by="Self CUDA µs", ascending=False) - #memory_df = memory_df.drop(["Self CUDA µs"]) + # memory_df = memory_df.drop(["Self CUDA µs"]) flag = ["--" not in row for row in memory_df["Name"]] memory_df = memory_df[flag] time_rows = [row for row in table.split("\n")[-3:] if row != ""] diff --git a/training/src/anemoi/training/diagnostics/trace_analyser.py b/training/src/anemoi/training/diagnostics/trace_analyser.py index cdaaa7560..bea141842 100644 --- a/training/src/anemoi/training/diagnostics/trace_analyser.py +++ b/training/src/anemoi/training/diagnostics/trace_analyser.py @@ -1,50 +1,49 @@ import json +import logging +import os from collections import defaultdict -import matplotlib.pyplot as plt + import matplotlib.cm as cm +import matplotlib.pyplot as plt import numpy as np import torch -from typing import List, Tuple -import re -import os -import logging from rich.console import Console + LOGGER = logging.getLogger(__name__) console = Console(record=True, width=200) -filename="/ec/res4/scratch/naco/aifs/outputs/raps/train/4N_16gpn_transformer_512c_o1280.jobname-619791/train-outputs/profiler/e5c4bd0471a2494f9183bdf493fafc7a/ac6-306.bullx_621045.None.1743008762865142923.pt.trace.json" +filename = "/ec/res4/scratch/naco/aifs/outputs/raps/train/4N_16gpn_transformer_512c_o1280.jobname-619791/train-outputs/profiler/e5c4bd0471a2494f9183bdf493fafc7a/ac6-306.bullx_621045.None.1743008762865142923.pt.trace.json" + def find_first_trace_file(dirpath: str) -> str | None: - """ - Search the given directory (non-recursively) for the first file ending with '.pt.trace.json'. + """Search the given directory (non-recursively) for the first file ending with '.pt.trace.json'. Args: dirpath (str): Directory to search. - Returns: + Returns + ------- str | None: Full path to the first matching file, or None if not found. 
""" for filename in sorted(os.listdir(dirpath)): - if filename.endswith('.pt.trace.json'): + if filename.endswith(".pt.trace.json"): return os.path.join(dirpath, filename) return None + def analyze_anemoi_durations(json_file_path): allowed_names = {"anemoi-encoder", "anemoi-decoder", "anemoi-processor"} - #durations = dict() + # durations = dict() durations = defaultdict(list) - with open(json_file_path, 'r') as f: + with open(json_file_path) as f: data = json.load(f) # Some trace files have events inside 'traceEvents' - events = data.get('traceEvents', data) + events = data.get("traceEvents", data) for event in events: - if ( - event.get("cat") == "gpu_user_annotation" and - event.get("name") in allowed_names - ): + if event.get("cat") == "gpu_user_annotation" and event.get("name") in allowed_names: durations[event["name"]].append(event.get("dur", 0)) # Compute total and average durations @@ -52,22 +51,19 @@ def analyze_anemoi_durations(json_file_path): for name, dur_list in durations.items(): total = sum(dur_list) avg = total / len(dur_list) if dur_list else 0 - summary[name] = { - "total_duration": total, - "average_duration": avg, - "count": len(dur_list) - } + summary[name] = {"total_duration": total, "average_duration": avg, "count": len(dur_list)} return summary + def list_unique_kernel_names(json_file_path): unique_names = set() - with open(json_file_path, 'r') as f: + with open(json_file_path) as f: data = json.load(f) # Handle traceEvents wrapping - events = data.get('traceEvents', data) + events = data.get("traceEvents", data) for event in events: if event.get("cat") == "kernel": @@ -79,43 +75,45 @@ def list_unique_kernel_names(json_file_path): for name in sorted(unique_names): console.print(f"- {name}") -import json + + def extract_unique_kernel_function_names(json_file_path): unique_funcs = set() - with open(json_file_path, 'r') as f: + with open(json_file_path) as f: data = json.load(f) events = data.get("traceEvents", data) for event in events: - if event.get("cat") == "kernel" or event.get("cat") == "gpu_memcpy" or event.get("cat") == "gpu_memset": + if event.get("cat") == "kernel" or event.get("cat") == "gpu_memcpy" or event.get("cat") == "gpu_memset": name = event.get("name", "") # Step 1: Remove leading "void " if present - if name.startswith("void "): - name = name[len("void "):] + name = name.removeprefix("void ") # Step 2: Truncate at the first "<", if present name = name.split("<")[0] unique_funcs.add(name) - #print("Unique kernel function names:\n") - #for func in sorted(unique_funcs): + # print("Unique kernel function names:\n") + # for func in sorted(unique_funcs): # print(f"- {func}") return unique_funcs + def sum_kernel_durations(kernel_durations): - # Sort by total duration (descending) + # Sort by total duration (descending) sorted_durations = sorted(kernel_durations.items(), key=lambda x: x[1], reverse=True) - #print("Total durations per kernel (sorted by duration):") - #for kernel_name, total_dur in sorted_durations: + # print("Total durations per kernel (sorted by duration):") + # for kernel_name, total_dur in sorted_durations: # print(f"- {kernel_name}: {total_dur/1000:.3f}s") return sorted_durations + def plot_kernel_durations_pie(sorted_durations, top_n=10): # Split top N and 'Other' top_kernels = sorted_durations[:top_n] @@ -129,16 +127,16 @@ def plot_kernel_durations_pie(sorted_durations, top_n=10): durations.append(other_total) # Generate distinct colors using colormap - cmap = cm.get_cmap('tab20', len(labels)) + cmap = cm.get_cmap("tab20", 
len(labels)) colors = [cmap(i) for i in range(len(labels))] # Generate hatching for special kernels hatches = [] for name in labels: if "nccl" in name: - hatches.append('////') # Diagonal lines for comms + hatches.append("////") # Diagonal lines for comms elif "Memcpy" in name: - hatches.append('....') # Dots for memcpys + hatches.append("....") # Dots for memcpys else: hatches.append(None) @@ -147,234 +145,277 @@ def plot_kernel_durations_pie(sorted_durations, top_n=10): # Create pie chart fig, ax = plt.subplots(figsize=(10, 8)) - #wedges, texts, autotexts = ax.pie( - wedges, texts= ax.pie( + # wedges, texts, autotexts = ax.pie( + wedges, texts = ax.pie( durations, labels=None, # disable inline labels - #autopct='%1.1f%%', + # autopct='%1.1f%%', startangle=140, colors=colors, - #explode=explode, - wedgeprops=dict(width=0.4, edgecolor='black', linewidth=2), - pctdistance=0.75 + # explode=explode, + wedgeprops=dict(width=0.4, edgecolor="black", linewidth=2), + pctdistance=0.75, ) # Apply hatches - for wedge, hatch in zip(wedges, hatches): + for wedge, hatch in zip(wedges, hatches, strict=False): if hatch: wedge.set_hatch(hatch) # Add legend outside the plot - ax.legend( - wedges, - labels, - title="Kernel Names", - loc="center left", - bbox_to_anchor=(1, 0.5), - fontsize='small' - ) + ax.legend(wedges, labels, title="Kernel Names", loc="center left", bbox_to_anchor=(1, 0.5), fontsize="small") ax.set_title(f"Kernel Duration Distribution (Top {top_n} Shown)", fontsize=14) plt.tight_layout() plt.savefig("kernels.png") - #plt.show() + # plt.show() + def classify_kernel(name: str) -> str: name_lower = name.lower() - if 'memcpy' in name_lower or 'memset' in name_lower: - return 'memory' - elif 'nccl' in name_lower: - return 'comms' - #elif 'cublas' in name_lower or 'cutlass' in name_lower or 'gemm' in name_lower or 'kernel' in name_lower or 'flash' in name_lower: + if "memcpy" in name_lower or "memset" in name_lower: + return "memory" + if "nccl" in name_lower: + return "comms" + # elif 'cublas' in name_lower or 'cutlass' in name_lower or 'gemm' in name_lower or 'kernel' in name_lower or 'flash' in name_lower: # return 'compute' - else: - return 'compute' # Default fallback + return "compute" # Default fallback + + +def print_kernel_table( + data: list[tuple[str, float]], kernel_counts, kernel_weighted_occupancies, top_n: int = 10, num_iterations=20, +): + total_duration_us = sum(duration for _, duration in data) / num_iterations -def print_kernel_table(data: List[Tuple[str, float]], kernel_counts, kernel_weighted_occupancies, top_n: int = 10, num_iterations=20): - total_duration_us = sum(duration for _, duration in data) / num_iterations - # Sort by duration descending data_sorted = sorted(data, key=lambda x: x[1], reverse=True) - + top_kernels = data_sorted[:top_n] remaining = data_sorted[top_n:] # Prepare table rows rows = [] - - #for (name, duration), count, occupancy in zip(top_kernels, kernel_counts, kernel_weighted_occupancies): - total_count=0 + + # for (name, duration), count, occupancy in zip(top_kernels, kernel_counts, kernel_weighted_occupancies): + total_count = 0 for name, duration_us in top_kernels: - count = kernel_counts[name]/num_iterations - total_count+=count - occupancy = kernel_weighted_occupancies[name]/duration_us * 100 + count = kernel_counts[name] / num_iterations + total_count += count + occupancy = kernel_weighted_occupancies[name] / duration_us * 100 category = classify_kernel(name) - percent = ((duration_us/num_iterations) / total_duration_us) * 100 - 
rows.append((name, category, duration_us/1000000/num_iterations, percent, count, occupancy)) - + percent = ((duration_us / num_iterations) / total_duration_us) * 100 + rows.append((name, category, duration_us / 1000000 / num_iterations, percent, count, occupancy)) + # Aggregate "other" if remaining: - other_duration_us = sum(d for _, d in remaining)/num_iterations + other_duration_us = sum(d for _, d in remaining) / num_iterations other_percent = (other_duration_us / total_duration_us) * 100 - other_count=sum(kernel_counts[name] for name, _ in remaining)/num_iterations - total_count+=other_count - other_occupancy=0.0 - rows.append(("other", "-", other_duration_us/1000000/num_iterations, other_percent, other_count, other_occupancy)) - + other_count = sum(kernel_counts[name] for name, _ in remaining) / num_iterations + total_count += other_count + other_occupancy = 0.0 + rows.append( + ("other", "-", other_duration_us / 1000000 / num_iterations, other_percent, other_count, other_occupancy), + ) + # Add total row - rows.append(("total (kernels on different streams can overlap)", "-", total_duration_us /1000000, 100.0, total_count, 0.0)) + rows.append( + ("total (kernels on different streams can overlap)", "-", total_duration_us / 1000000, 100.0, total_count, 0.0), + ) # Print formatted table - console.print(f"{'Kernel':60} {'Category':10} {'Duration (s)':>15} {'% Time':>10} {'Count':>10} {'% Occupancy':>10}") + console.print( + f"{'Kernel':60} {'Category':10} {'Duration (s)':>15} {'% Time':>10} {'Count':>10} {'% Occupancy':>10}", + ) console.print("-" * 100) for name, category, duration, percent, count, occupancy in rows: console.print(f"{name[:58]:60} {category:10} {duration:15.2f} {percent:10.2f} {count:10} {occupancy:10.2f}") + def count_iterations(json_file_path): - with open(json_file_path, 'r') as f: + with open(json_file_path) as f: data = json.load(f) events = data.get("traceEvents", data) for event in events: - iteration_count=0 + iteration_count = 0 if event.get("cat") == "python_function" and ("transfer_batch_to_device" in event.get("name")): - iteration_count+=1 + iteration_count += 1 return iteration_count + def compute_av_time_per_iter_and_dl_stalls(iteration_durations_us, dataloading_stall_durations_us): - #print(f"{iteration_durations_us=}, {dataloading_stall_durations_us=}") - dataloading_stall_durations_s=np.array(dataloading_stall_durations_us)/1000000 - iteration_durations_s=np.array(iteration_durations_us)/1000000 - #iteration_durations = [iteration_start_times[i+1] - iteration_start_times[i] for i in range(len(iteration_start_times) - 1)] - dataloading_stall_percentages= dataloading_stall_durations_s / iteration_durations_s * 100 - #print(f"{iteration_durations_s=} {dataloading_stall_durations_s=} {dataloading_stall_percentages=}%") - #using median here to minimise the impact of long tails in the distribution - av_iteration_duration_s=np.median(iteration_durations_s) - av_throughput=1/av_iteration_duration_s - av_dataloading_stall_duration_s=np.median(dataloading_stall_durations_s) - av_dataloading_stall_percentage=np.median(dataloading_stall_percentages) - console.print(f"Each training iteration took an average of {av_iteration_duration_s:.2f}s ({av_throughput:.2f} iterations per second)") - console.print(f"An average of {av_dataloading_stall_duration_s:.2f}s ({av_dataloading_stall_percentage:.2f}%) of each iteration was spent idling while loading data") + # print(f"{iteration_durations_us=}, {dataloading_stall_durations_us=}") + dataloading_stall_durations_s = 
np.array(dataloading_stall_durations_us) / 1000000 + iteration_durations_s = np.array(iteration_durations_us) / 1000000 + # iteration_durations = [iteration_start_times[i+1] - iteration_start_times[i] for i in range(len(iteration_start_times) - 1)] + dataloading_stall_percentages = dataloading_stall_durations_s / iteration_durations_s * 100 + # print(f"{iteration_durations_s=} {dataloading_stall_durations_s=} {dataloading_stall_percentages=}%") + # using median here to minimise the impact of long tails in the distribution + av_iteration_duration_s = np.median(iteration_durations_s) + av_throughput = 1 / av_iteration_duration_s + av_dataloading_stall_duration_s = np.median(dataloading_stall_durations_s) + av_dataloading_stall_percentage = np.median(dataloading_stall_percentages) + console.print( + f"Each training iteration took an average of {av_iteration_duration_s:.2f}s ({av_throughput:.2f} iterations per second)", + ) + console.print( + f"An average of {av_dataloading_stall_duration_s:.2f}s ({av_dataloading_stall_percentage:.2f}%) of each iteration was spent idling while loading data", + ) if av_dataloading_stall_percentage > 5.0: - console.print(f"Warning! Dataloading stall times are high. You can try increase the number of dataloader workers. If you are limited by CPU memory, you can try decrease prefetch factor to 1 to further increase the number of workers") + console.print( + "Warning! Dataloading stall times are high. You can try increase the number of dataloader workers. If you are limited by CPU memory, you can try decrease prefetch factor to 1 to further increase the number of workers", + ) return iteration_durations_s, dataloading_stall_durations_s, dataloading_stall_percentages - + + def analyse_HtoD_memcpy(batch_sizes_GB, batch_transfer_bw_GBs, batch_transfer_durations_us): # "ph": "X", "cat": "gpu_memcpy", "name": "Memcpy HtoD (Pinned -> Device)", "pid": 0, "tid": 7, "ts": 7376383457086.139, "dur": 19558.546, "args": { "External id": 77569, "device": 0, "context": 1, "stream": 7, "correlation": 436228, "bytes": 499925760, "memory bandwidth (GB/s)": 25.560476734824768} - - batch_transfer_durations_us=np.array(batch_transfer_durations_us) - - #print(f"{batch_sizes_GB=} {batch_transfer_durations_us/1000000=} {batch_transfer_bw_GBs=}") - av_batch_size_GB=np.mean(batch_sizes_GB) - av_batch_transfer_bw_GBs=np.mean(batch_transfer_bw_GBs) - av_batch_transfer_durations_s=np.mean(batch_transfer_durations_us)/1000000 - console.print(f"{av_batch_size_GB=:.2f}GB, {av_batch_transfer_durations_s=:.2f}s, ({av_batch_transfer_bw_GBs=:.2f}GB/s)") + + batch_transfer_durations_us = np.array(batch_transfer_durations_us) + + # print(f"{batch_sizes_GB=} {batch_transfer_durations_us/1000000=} {batch_transfer_bw_GBs=}") + av_batch_size_GB = np.mean(batch_sizes_GB) + av_batch_transfer_bw_GBs = np.mean(batch_transfer_bw_GBs) + av_batch_transfer_durations_s = np.mean(batch_transfer_durations_us) / 1000000 + console.print( + f"{av_batch_size_GB=:.2f}GB, {av_batch_transfer_durations_s=:.2f}s, ({av_batch_transfer_bw_GBs=:.2f}GB/s)", + ) return av_batch_size_GB, av_batch_transfer_bw_GBs, av_batch_transfer_durations_s - - + + def parse_json_trace_file(json_file_path): - """ - This function iterates once over a json trace file and returns all the information we will later analyse - """ - with open(json_file_path, 'r') as f: + """This function iterates once over a json trace file and returns all the information we will later analyse""" + with open(json_file_path) as f: data = json.load(f) events = 
data.get("traceEvents", data) - - #HtoD memcpy analysis - batch_sizes_GB=[] - batch_transfer_bw_GBs=[] - batch_transfer_durations_us=[] - - #Av. iter time and dataloading stall analysis - dataloading_stall_durations_us=[] - iteration_durations_us=[] - - #kernel analysis - #unique_kernel_names = set() + + # HtoD memcpy analysis + batch_sizes_GB = [] + batch_transfer_bw_GBs = [] + batch_transfer_durations_us = [] + + # Av. iter time and dataloading stall analysis + dataloading_stall_durations_us = [] + iteration_durations_us = [] + + # kernel analysis + # unique_kernel_names = set() kernel_durations = defaultdict(float) - kernel_counts= defaultdict(int) + kernel_counts = defaultdict(int) kernel_weighted_occupancies = defaultdict(float) - - main_stream=7 #assumption - gpu_idle_time=0 - prev_kernel_end_time=0 - - iteration_count=0 - + + main_stream = 7 # assumption + gpu_idle_time = 0 + prev_kernel_end_time = 0 + + iteration_count = 0 + for event in events: - if event.get("cat") == "kernel" or event.get("cat") == "gpu_memcpy" or event.get("cat") == "gpu_memset": - #get unique kernel names + if event.get("cat") == "kernel" or event.get("cat") == "gpu_memcpy" or event.get("cat") == "gpu_memset": + # get unique kernel names name = event.get("name", "") # Step 1: Remove leading "void " if present - if name.startswith("void "): - name = name[len("void "):] + name = name.removeprefix("void ") # Step 2: Truncate at the first "<", if present name = name.split("<")[0] - #unique_kernel_names.add(name) - - #get durations for each kernel + # unique_kernel_names.add(name) + + # get durations for each kernel kernel_duration = event.get("dur", 0) - kernel_occupancy_pct= event.get("args").get('est. achieved occupancy %', 0) / 100 + kernel_occupancy_pct = event.get("args").get("est. achieved occupancy %", 0) / 100 kernel_durations[name] += kernel_duration kernel_counts[name] += 1 kernel_weighted_occupancies[name] += kernel_occupancy_pct * kernel_duration - + # analyse iter time and dataloading stall time if event.get("cat") == "user_annotation" and ("train_dataloader_next" in event.get("name")): - dataloading_stall_durations_us.append(event.get("dur")) + dataloading_stall_durations_us.append(event.get("dur")) if event.get("cat") == "user_annotation" and ("run_training_batch" in event.get("name")): - iteration_durations_us.append(event.get("dur")) - - #if event.get("cat") == "python_function" and ("transfer_batch_to_device" in event.get("name")): + iteration_durations_us.append(event.get("dur")) + + # if event.get("cat") == "python_function" and ("transfer_batch_to_device" in event.get("name")): if event.get("cat") == "user_annotation" and ("transfer_batch_to_device" in event.get("name")): - iteration_count+=1 - + iteration_count += 1 + # analyse HtoD memcpy if event.get("cat") == "gpu_memcpy" and ("Memcpy HtoD (Pinned" in event.get("name")): - batch_transfer_durations_us.append(event.get("dur")) - batch_sizes_GB.append(event.get("args")['bytes']/1000/1000/1000) #use 1000 not 1024 as it matches the numbers reported by the tracer - batch_transfer_bw_GBs.append(event.get("args")['memory bandwidth (GB/s)']) - - #should i include memcpy here? + batch_transfer_durations_us.append(event.get("dur")) + batch_sizes_GB.append( + event.get("args")["bytes"] / 1000 / 1000 / 1000, + ) # use 1000 not 1024 as it matches the numbers reported by the tracer + batch_transfer_bw_GBs.append(event.get("args")["memory bandwidth (GB/s)"]) + + # should i include memcpy here? 
if event.get("cat") == "kernel" and event.get("args").get("stream") == main_stream: - kernel_end_time=event.get("ts")+ event.get("dur") - kernel_start_time=event.get("ts") + kernel_end_time = event.get("ts") + event.get("dur") + kernel_start_time = event.get("ts") if prev_kernel_end_time != 0: diff = kernel_start_time - prev_kernel_end_time if diff > 0: gpu_idle_time += diff prev_kernel_end_time = kernel_end_time - + print(f"gpu_idle_time = {gpu_idle_time/1000000}s") print(f"gpu_idle_time per iteration = {gpu_idle_time/1000000/iteration_count}s") - return batch_sizes_GB, batch_transfer_bw_GBs, batch_transfer_durations_us, dataloading_stall_durations_us, iteration_durations_us, kernel_durations, kernel_counts, kernel_weighted_occupancies, iteration_count - + return ( + batch_sizes_GB, + batch_transfer_bw_GBs, + batch_transfer_durations_us, + dataloading_stall_durations_us, + iteration_durations_us, + kernel_durations, + kernel_counts, + kernel_weighted_occupancies, + iteration_count, + ) + + def analyse_gpu_memory_usage(): - max_available_memory_GB=torch.cuda.get_device_properties().total_memory/1024/1024/1024 #todo should guard this incase the syntax is different on AMD - max_reserved_memory_GB=torch.cuda.max_memory_reserved() /1024/1024/1024 - max_allocated_memory_GB=torch.cuda.max_memory_allocated() /1024/1024/1024 + max_available_memory_GB = ( + torch.cuda.get_device_properties().total_memory / 1024 / 1024 / 1024 + ) # TODO should guard this incase the syntax is different on AMD + max_reserved_memory_GB = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024 + max_allocated_memory_GB = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 console.print(f"{max_available_memory_GB=:.2f}, {max_reserved_memory_GB=:.2f}, {max_allocated_memory_GB=:.2f}") - - max_reserved_but_unused_memory_GB=max_reserved_memory_GB - max_allocated_memory_GB + + max_reserved_but_unused_memory_GB = max_reserved_memory_GB - max_allocated_memory_GB if max_reserved_but_unused_memory_GB > 2: - console.print(f"Warning! you have {max_reserved_but_unused_memory_GB}GB of memory reserved by pytorch but not actively allocated. This memory fragmentation can result in avoidable Out-Of-Memory errors. You can try 'export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True' to reduce this memory fragmentation.") - - max_allocated_memory_percentage= max_allocated_memory_GB / max_available_memory_GB * 100 + console.print( + f"Warning! you have {max_reserved_but_unused_memory_GB}GB of memory reserved by pytorch but not actively allocated. This memory fragmentation can result in avoidable Out-Of-Memory errors. You can try 'export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True' to reduce this memory fragmentation.", + ) + + max_allocated_memory_percentage = max_allocated_memory_GB / max_available_memory_GB * 100 console.print(f"Peak (allocated) memory usage was {max_allocated_memory_percentage:.2f}% of total device memory") if max_allocated_memory_percentage < 50.0: - console.print(f"Warning! Your peak device memory usage is low. You could try increasing the batch size or reducing the number of GPUs") -#analyze_anemoi_durations(filename) -#unique_funcs=extract_unique_kernel_function_names(filename) -#durations =sum_kernel_durations(filename, unique_funcs) -#print(durations) + console.print( + "Warning! Your peak device memory usage is low. 
You could try increasing the batch size or reducing the number of GPUs", + ) + + +# analyze_anemoi_durations(filename) +# unique_funcs=extract_unique_kernel_function_names(filename) +# durations =sum_kernel_durations(filename, unique_funcs) +# print(durations) def analyse_trace(dirpath): - filename=find_first_trace_file(dirpath) + filename = find_first_trace_file(dirpath) console.print(f"Analysing {filename}") - batch_sizes_GB, batch_transfer_bw_GBs, batch_transfer_durations_us, dataloading_stall_durations_us, iteration_durations_us, kernel_durations , kernel_counts, kernel_weighted_occupancies, iteration_count= parse_json_trace_file(filename) - kernel_durations=sum_kernel_durations(kernel_durations) - print_kernel_table(kernel_durations, kernel_counts, kernel_weighted_occupancies, top_n=10, num_iterations=iteration_count) + ( + batch_sizes_GB, + batch_transfer_bw_GBs, + batch_transfer_durations_us, + dataloading_stall_durations_us, + iteration_durations_us, + kernel_durations, + kernel_counts, + kernel_weighted_occupancies, + iteration_count, + ) = parse_json_trace_file(filename) + kernel_durations = sum_kernel_durations(kernel_durations) + print_kernel_table( + kernel_durations, kernel_counts, kernel_weighted_occupancies, top_n=10, num_iterations=iteration_count, + ) console.print("\n") compute_av_time_per_iter_and_dl_stalls(iteration_durations_us, dataloading_stall_durations_us) console.print("\n") @@ -382,21 +423,34 @@ def analyse_trace(dirpath): console.print("\n") analyse_gpu_memory_usage() -if __name__ == '__main__': - #filename="/ec/res4/scratch/naco/aifs/outputs/raps/train/4N_16gpn_transformer_512c_o1280.jobname-619791/train-outputs/profiler/e5c4bd0471a2494f9183bdf493fafc7a/ac6-306.bullx_621045.None.1743008762865142923.pt.trace.json" # old o2560 512c with multi gpu - #filename="/ec/res4/hpcperm/naco/aifs/tests/profiler_update/outputs/profiler/c78cfd5d829340549407acb1e5b36426/ac6-302.bullx_3183374.None.1751544480627283549.pt.trace.json" #20 steps o96 1 gpu - filename="/ec/res4/hpcperm/naco/aifs/tests/profiler_update/outputs/profiler/9a57351d759c428b92964af00c75b767/ac6-301.bullx_3890366.None.1751551147735771561.pt.trace.json" #20 steps o96 4 gpus - batch_sizes_GB, batch_transfer_bw_GBs, batch_transfer_durations_us, dataloading_stall_durations_us, iteration_durations_us, kernel_durations, kernel_counts, kernel_weighted_occupancies, iteration_count = parse_json_trace_file(filename) - #print(kernel_durations) - kernel_durations=sum_kernel_durations(kernel_durations) - print_kernel_table(kernel_durations, kernel_counts, kernel_weighted_occupancies, top_n=10, num_iterations=iteration_count) - #print(durations) - #plot_kernel_durations_pie(durations, top_n=20) + +if __name__ == "__main__": + # filename="/ec/res4/scratch/naco/aifs/outputs/raps/train/4N_16gpn_transformer_512c_o1280.jobname-619791/train-outputs/profiler/e5c4bd0471a2494f9183bdf493fafc7a/ac6-306.bullx_621045.None.1743008762865142923.pt.trace.json" # old o2560 512c with multi gpu + # filename="/ec/res4/hpcperm/naco/aifs/tests/profiler_update/outputs/profiler/c78cfd5d829340549407acb1e5b36426/ac6-302.bullx_3183374.None.1751544480627283549.pt.trace.json" #20 steps o96 1 gpu + filename = "/ec/res4/hpcperm/naco/aifs/tests/profiler_update/outputs/profiler/9a57351d759c428b92964af00c75b767/ac6-301.bullx_3890366.None.1751551147735771561.pt.trace.json" # 20 steps o96 4 gpus + ( + batch_sizes_GB, + batch_transfer_bw_GBs, + batch_transfer_durations_us, + dataloading_stall_durations_us, + iteration_durations_us, + kernel_durations, + 
kernel_counts, + kernel_weighted_occupancies, + iteration_count, + ) = parse_json_trace_file(filename) + # print(kernel_durations) + kernel_durations = sum_kernel_durations(kernel_durations) + print_kernel_table( + kernel_durations, kernel_counts, kernel_weighted_occupancies, top_n=10, num_iterations=iteration_count, + ) + # print(durations) + # plot_kernel_durations_pie(durations, top_n=20) print("\n") compute_av_time_per_iter_and_dl_stalls(iteration_durations_us, dataloading_stall_durations_us) print("\n") analyse_HtoD_memcpy(batch_sizes_GB, batch_transfer_bw_GBs, batch_transfer_durations_us) print("\n") - x = torch.empty(1024, 1024, 1024, device='cuda', dtype=torch.float32) + x = torch.empty(1024, 1024, 1024, device="cuda", dtype=torch.float32) analyse_gpu_memory_usage() - #print(torch.cuda.get_device_properties()) + # print(torch.cuda.get_device_properties()) diff --git a/training/src/anemoi/training/train/profiler.py b/training/src/anemoi/training/train/profiler.py index 3b0058f79..5db96792b 100644 --- a/training/src/anemoi/training/train/profiler.py +++ b/training/src/anemoi/training/train/profiler.py @@ -27,8 +27,8 @@ from anemoi.training.data.datamodule import AnemoiDatasetsDataModule from anemoi.training.diagnostics.profilers import BenchmarkProfiler from anemoi.training.diagnostics.profilers import ProfilerProgressBar -from anemoi.training.train.train import AnemoiTrainer from anemoi.training.diagnostics.trace_analyser import analyse_trace +from anemoi.training.train.train import AnemoiTrainer LOGGER = logging.getLogger(__name__) console = Console(record=True, width=200) diff --git a/training/src/anemoi/training/utils/compile.py b/training/src/anemoi/training/utils/compile.py index e7b9cd33b..c365f4d60 100644 --- a/training/src/anemoi/training/utils/compile.py +++ b/training/src/anemoi/training/utils/compile.py @@ -55,6 +55,13 @@ def _meets_library_versions_for_compile() -> bool: msg += "Please upgrade these libraries to enable compilation." LOGGER.warning(msg) + # Dynamo has a limit on the number of recompilations before falling back to eager mode. + # Typically, this is to allow dynamic shapes. In a context of multiple different but static + # shapes (e.g. different graphs sizes in different chunks) it is useful to increase this limit. + import torch._dynamo as dynamo + + dynamo.config.recompile_limit = 32 + return version_req and has_triton
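Editor's sketch (not part of the patch) of why the recompile limit matters: with static shapes, each distinct input shape triggers a new Dynamo compilation, and once the limit is exceeded the frame falls back to eager execution, which is the situation the comment above describes for differently sized graph chunks. The toy function below is hypothetical; only the recompile_limit setting comes from the patch.

import torch
import torch._dynamo as dynamo

dynamo.config.recompile_limit = 32  # as set in _meets_library_versions_for_compile above

@torch.compile(dynamic=False)
def edge_scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

# Every new length is a new static shape, so edge_scale recompiles; past the limit
# Dynamo would stop compiling this frame and run it eagerly.
for n_edges in (1000, 2000, 3000, 4000):
    edge_scale(torch.randn(n_edges))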