Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests-hpc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
#SBATCH --qos=ng
#SBATCH --gpus=1
#SBATCH --gres=gpu:1
#SBATCH --mem=30G
#SBATCH --mem=64G
troika_user: ${{ secrets.HPC_CI_INTEGRATION_USER }}
benchmark-tests:
runs-on: hpc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
trainer_callbacks.extend(instantiate(callback, config) for callback in config.diagnostics.callbacks)

# Plotting callbacks
trainer_callbacks.extend(instantiate(callback, config) for callback in config.diagnostics.plot.callbacks)
if config["training"]["model_task"] != "anemoi.training.train.tasks.GraphInterpolator":
trainer_callbacks.extend(instantiate(callback, config) for callback in config.diagnostics.plot.callbacks)
else:
LOGGER.info("Plotting callbacks have been temporarily deactivated for the TimeInterpolator")

# Extend with config enabled callbacks
trainer_callbacks.extend(_get_config_enabled_callbacks(config))
Expand Down
8 changes: 7 additions & 1 deletion training/src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,13 @@ def _plot(
else:
LOGGER.warning("There are no trainable node attributes to plot.")

if len(edge_trainable_modules := self.get_edge_trainable_modules(model)):
from anemoi.models.models import AnemoiModelEncProcDecHierarchical

if isinstance(model, AnemoiModelEncProcDecHierarchical):
LOGGER.warning(
"Edge trainable features are not supported for Hierarchical models, skipping plot generation.",
)
elif len(edge_trainable_modules := self.get_edge_trainable_modules(model)):
fig = plot_graph_edge_features(model, edge_trainable_modules, q_extreme_limit=self.q_extreme_limit)

self._output_figure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def on_validation_batch_end(
trainer,
pl_module,
outputs,
batch[0][:, :, 0, :, :],
batch[:, :, 0, :, :],
batch_idx,
)

Expand Down
42 changes: 42 additions & 0 deletions training/tests/integration/config/test_ensemble_crps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,45 @@ dataloader:
end: 2017-01-08 12:00:00
validation:
start: 2017-01-08 18:00:00

diagnostics:
plot:
callbacks:
# Add plot callbacks here.
- _target_: anemoi.training.diagnostics.callbacks.plot_ens.PlotEnsSample
sample_idx: ${diagnostics.plot.sample_idx}
per_sample : 2
parameters: ${diagnostics.plot.parameters}
every_n_batches: ${diagnostics.plot.frequency.batch}
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
members: null # None for all members, list for specific members

# Deterministic callbacks are also overloaded.
- _target_: anemoi.training.diagnostics.callbacks.plot_ens.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp, cp, tcw]
sfc_wind: [10u, 10v]
every_n_batches: ${diagnostics.plot.frequency.batch}
- _target_: anemoi.training.diagnostics.callbacks.plot_ens.PlotSpectrum
sample_idx: ${diagnostics.plot.sample_idx}
parameters: ${diagnostics.plot.parameters}
every_n_batches: ${diagnostics.plot.frequency.batch}
- _target_: anemoi.training.diagnostics.callbacks.plot_ens.PlotHistogram
sample_idx: ${diagnostics.plot.sample_idx}
parameters: ${diagnostics.plot.parameters}
every_n_batches: ${diagnostics.plot.frequency.batch}
precip_and_related_fields: ["tp", "cp"] # Optional: specify precip fields for special histogram treatment
- _target_: anemoi.training.diagnostics.callbacks.plot_ens.GraphTrainableFeaturesPlot
every_n_epochs: ${diagnostics.plot.frequency.epoch}

# Overloaded PlotSample will return the plots for the first ensemble member
# - _target_: anemoi.training.diagnostics.callbacks.plot_ens.PlotSample
# sample_idx: ${diagnostics.plot.sample_idx}
# per_sample: 6
# parameters: ${diagnostics.plot.parameters}
# every_n_batches: ${diagnostics.plot.frequency.batch}
# accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
# precip_and_related_fields: ["tp", "cp"] # Optional: specify precip fields
38 changes: 38 additions & 0 deletions training/tests/integration/config/test_lam.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,41 @@ data:
- "cos_local_time"
- "sin_julian_day"
- "sin_local_time"

diagnostics:
plot:
# Parameters to plot
parameters:
- tp
callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot
every_n_epochs: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp]
every_n_batches: ${diagnostics.plot.frequency.batch}
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
sample_idx: ${diagnostics.plot.sample_idx}
per_sample : 6
parameters: ${diagnostics.plot.parameters}
every_n_batches: ${diagnostics.plot.frequency.batch}
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
colormaps: ${diagnostics.plot.colormaps}
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# every_n_batches: 100 # Override for batch frequency
# min_delta: 0.01 # Minimum distance between two consecutive points
sample_idx: ${diagnostics.plot.sample_idx}
every_n_batches: ${diagnostics.plot.frequency.batch}
parameters:
- tp
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram
sample_idx: ${diagnostics.plot.sample_idx}
every_n_batches: ${diagnostics.plot.frequency.batch}
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
parameters:
- tp
39 changes: 39 additions & 0 deletions training/tests/integration/config/test_stretched.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,42 @@ training:
scalers:
node_weights:
weight_frac_of_total: 0.25


diagnostics:
plot:
# Parameters to plot
parameters:
- tp
callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot
every_n_epochs: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp]
every_n_batches: ${diagnostics.plot.frequency.batch}
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
sample_idx: ${diagnostics.plot.sample_idx}
per_sample : 6
parameters: ${diagnostics.plot.parameters}
every_n_batches: ${diagnostics.plot.frequency.batch}
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
colormaps: ${diagnostics.plot.colormaps}
# - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# # every_n_batches: 100 # Override for batch frequency
# # min_delta: 0.01 # Minimum distance between two consecutive points
# sample_idx: ${diagnostics.plot.sample_idx}
# every_n_batches: ${diagnostics.plot.frequency.batch}
# parameters:
# - tp
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram
sample_idx: ${diagnostics.plot.sample_idx}
every_n_batches: ${diagnostics.plot.frequency.batch}
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
parameters:
- tp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dataloader:
diagnostics:
plot:
callbacks: []
asynchronous: False # Whether to plot asynchronously
log:
wandb:
enabled: False
Expand Down
Loading