From 6b297ef015cf75fd791d3d1b76b467d8c3d6e5de Mon Sep 17 00:00:00 2001 From: Myles Stapelberg Date: Sat, 17 May 2025 16:30:29 -0400 Subject: [PATCH 1/7] updated ASE_to_VV to test FrechetCellFilter as well and compare directly to ASE --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 684 ++++++++++++++---- 1 file changed, 551 insertions(+), 133 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 1c85eb54..8a3124b3 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -10,15 +10,20 @@ import os import time +from typing import Literal import numpy as np import torch from ase.build import bulk +from ase.optimize import FIRE as ASEFIRE +from ase.filters import FrechetCellFilter from mace.calculators.foundations_models import mace_mp +from mace.calculators.foundations_models import mace_mp as mace_mp_calculator_for_ase +import matplotlib.pyplot as plt import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import fire +from torch_sim.optimizers import fire, frechet_cell_fire, GDState from torch_sim.state import SimState @@ -38,6 +43,7 @@ # Number of steps to run max_iterations = 10 if os.getenv("CI") else 500 supercell_scale = (1, 1, 1) if os.getenv("CI") else (3, 2, 2) +ase_max_optimizer_steps = max_iterations * 10 # Max steps for each individual ASE optimization run # Set random seed for reproducibility rng = np.random.default_rng(seed=0) @@ -113,151 +119,563 @@ def run_optimization( - initial_state: SimState, md_flavor: str, force_tol: float = 0.05 + initial_state: SimState, + optimizer_type: Literal["torch_sim", "ase"], + # For torch_sim: + ts_md_flavor: Literal["vv_fire", "ase_fire"] | None = None, + ts_use_frechet: bool = False, # To decide between fire() and frechet_cell_fire() + # For ASE: + ase_use_frechet_filter: bool = False, + # Common: + force_tol: float = 0.05, ) -> tuple[torch.Tensor, SimState]: - """Runs FIRE optimization and returns convergence steps.""" - print(f"\n--- Running optimization with MD Flavor: {md_flavor} ---") - start_time = time.perf_counter() - - # Re-initialize state and optimizer for this run - init_fn, update_fn = fire( - model=model, - md_flavor=md_flavor, - ) - fire_state = init_fn(initial_state.clone()) # Use a clone to start fresh - - batcher = ts.InFlightAutoBatcher( - model=model, - memory_scales_with="n_atoms", - max_memory_scaler=1000, - max_iterations=max_iterations, # Increased max iterations - return_indices=True, # Ensure indices are returned - ) - - batcher.load_states(fire_state) - - total_structures = fire_state.n_batches - # Initialize convergence steps tensor (-1 means not converged yet) - convergence_steps = torch.full( - (total_structures,), -1, dtype=torch.long, device=device - ) - convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) - - converged_tensor_global = torch.zeros( - total_structures, dtype=torch.bool, device=device - ) - global_step = 0 - all_converged_states = [] # Initialize list to store completed states - convergence_tensor_for_batcher = None # Initialize convergence tensor for batcher - - # Keep track of the last valid state for final collection - last_active_state = fire_state - - while True: # Loop until batcher indicates completion - # Get the next batch, passing the convergence status - result = batcher.next_batch(last_active_state, convergence_tensor_for_batcher) - - fire_state, 
converged_states_from_batcher, current_indices_list = result - all_converged_states.extend( - converged_states_from_batcher - ) # Add newly completed states - - if fire_state is None: # No more active states - print("All structures converged or batcher reached max iterations.") - break - - last_active_state = fire_state # Store the current active state - - # Get the original indices of the current active batch as a tensor - current_indices = torch.tensor( - current_indices_list, dtype=torch.long, device=device + """Runs optimization and returns convergence steps and final state.""" + if optimizer_type == "torch_sim": + assert ts_md_flavor is not None, "ts_md_flavor must be provided for torch_sim" + print( + f"\n--- Running Torch-Sim optimization: flavor={ts_md_flavor}, " + f"frechet_cell_opt={ts_use_frechet}, force_tol={force_tol} ---" ) - - # Optimize the current batch - steps_this_round = 10 - for _ in range(steps_this_round): - fire_state = update_fn(fire_state) - global_step += steps_this_round # Increment global step count - - # Check convergence *within the active batch* - convergence_tensor_for_batcher = convergence_fn(fire_state, None) - - # Update global convergence status and steps - # Identify structures in this batch that just converged - newly_converged_mask_local = convergence_tensor_for_batcher & ( - convergence_steps[current_indices] == -1 + start_time = time.perf_counter() + + if ts_use_frechet: + # Uses frechet_cell_fire for combined cell and position optimization + init_fn_opt, update_fn_opt = frechet_cell_fire( + model=model, md_flavor=ts_md_flavor + ) + else: + # Uses fire for position-only optimization + init_fn_opt, update_fn_opt = fire(model=model, md_flavor=ts_md_flavor) + + opt_state = init_fn_opt(initial_state.clone()) + + batcher = ts.InFlightAutoBatcher( + model=model, # The MaceModel wrapper + memory_scales_with="n_atoms", + max_memory_scaler=1000, + max_iterations=max_iterations, + return_indices=True, ) - converged_indices_global = current_indices[newly_converged_mask_local] - - if converged_indices_global.numel() > 0: - # Mark convergence step - convergence_steps[converged_indices_global] = global_step - converged_tensor_global[converged_indices_global] = True - converged_indices = converged_indices_global.tolist() - - total_converged = converged_tensor_global.sum().item() / total_structures - print(f"{global_step=}: {converged_indices=}, {total_converged=:.2%}") - - # Optional: Print progress - if global_step % 50 == 0: # Reduced frequency - total_converged = converged_tensor_global.sum().item() / total_structures - active_structures = fire_state.n_batches if fire_state else 0 - print(f"{global_step=}: {active_structures=}, {total_converged=:.2%}") + batcher.load_states(opt_state) - # After the loop, collect any remaining states that were active in the last batch - # result[1] contains states completed *before* the last next_batch call. - # We need the states that were active *in* the last batch returned by next_batch - # If fire_state was the last active state, we might need to add it if batcher didn't - # mark it complete. However, restore_original_order should handle all collected states - # correctly. 
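+        # For reference: the convergence function used below maps a batched
+        # state to one boolean per structure. A minimal sketch of the force
+        # criterion (illustrative only, not the torch_sim implementation;
+        # assumes `state.forces` is (n_atoms, 3) and `state.batch` maps each
+        # atom to its structure index):
+        #
+        #     def force_converged(state, force_tol: float = 0.05) -> torch.Tensor:
+        #         fmax = torch.zeros(state.n_batches, device=state.forces.device)
+        #         fnorm = state.forces.norm(dim=1)  # per-atom force magnitude
+        #         fmax.scatter_reduce_(0, state.batch, fnorm, reduce="amax")
+        #         return fmax < force_tol  # (n_batches,) bool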
+ total_structures = opt_state.n_batches + convergence_steps = torch.full( + (total_structures,), -1, dtype=torch.long, device=device + ) + convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + converged_tensor_global = torch.zeros( + total_structures, dtype=torch.bool, device=device + ) + global_step = 0 + all_converged_states = [] + convergence_tensor_for_batcher = None + last_active_state = opt_state + + while True: + result = batcher.next_batch( + last_active_state, convergence_tensor_for_batcher + ) + opt_state, converged_states_from_batcher, current_indices_list = result + all_converged_states.extend(converged_states_from_batcher) + + if opt_state is None: + print("All structures converged or batcher reached max iterations.") + break + + last_active_state = opt_state + current_indices = torch.tensor( + current_indices_list, dtype=torch.long, device=device + ) + + steps_this_round = 10 + for _ in range(steps_this_round): + opt_state = update_fn_opt(opt_state) + global_step += steps_this_round + + convergence_tensor_for_batcher = convergence_fn(opt_state, None) + newly_converged_mask_local = convergence_tensor_for_batcher & ( + convergence_steps[current_indices] == -1 + ) + converged_indices_global = current_indices[newly_converged_mask_local] + + if converged_indices_global.numel() > 0: + convergence_steps[converged_indices_global] = global_step + converged_tensor_global[converged_indices_global] = True + total_converged_frac = converged_tensor_global.sum().item() / total_structures + print( + f"{global_step=}: Converged indices {converged_indices_global.tolist()}, " + f"Total converged: {total_converged_frac:.2%}" + ) + + if global_step % 50 == 0: + total_converged_frac = converged_tensor_global.sum().item() / total_structures + active_structures = opt_state.n_batches if opt_state else 0 + print( + f"{global_step=}: Active structures: {active_structures}, " + f"Total converged: {total_converged_frac:.2%}" + ) + + final_states_list = batcher.restore_original_order(all_converged_states) + final_state_concatenated = ts.concatenate_states(final_states_list) + end_time = time.perf_counter() + print( + f"Finished Torch-Sim ({ts_md_flavor}, frechet={ts_use_frechet}) in " + f"{end_time - start_time:.2f} seconds." 
+ ) + return convergence_steps, final_state_concatenated - # Restore original order and concatenate - final_states_list = batcher.restore_original_order(all_converged_states) - final_state_concatenated = ts.concatenate_states(final_states_list) + elif optimizer_type == "ase": + print( + f"\n--- Running ASE optimization: frechet_filter={ase_use_frechet_filter}, " + f"force_tol={force_tol} ---" + ) + start_time = time.perf_counter() + + individual_initial_states = initial_state.split() + num_structures = len(individual_initial_states) + final_ase_atoms_list = [] + convergence_steps_list = [] + + for i, single_sim_state in enumerate(individual_initial_states): + print(f"Optimizing structure {i+1}/{num_structures} with ASE...") + ase_atoms_orig = ts.io.state_to_atoms(single_sim_state)[0] + + ase_calc_instance = mace_mp_calculator_for_ase( + model=MaceUrls.mace_mpa_medium, + device=device, + default_dtype=str(dtype).split('.')[-1], + ) + ase_atoms_orig.calc = ase_calc_instance + + optim_target_atoms = ase_atoms_orig + if ase_use_frechet_filter: + print(f"Applying FrechetCellFilter to structure {i+1}") + optim_target_atoms = FrechetCellFilter(ase_atoms_orig) + + dyn = ASEFIRE(optim_target_atoms, trajectory=None, logfile=None) + + try: + dyn.run(fmax=force_tol, steps=ase_max_optimizer_steps) + if dyn.converged(): + convergence_steps_list.append(dyn.nsteps) + print(f"ASE structure {i+1} converged in {dyn.nsteps} steps.") + else: + print( + f"ASE optimization for structure {i+1} did not converge within " + f"{ase_max_optimizer_steps} steps. Steps taken: {dyn.nsteps}." + ) + convergence_steps_list.append(-1) + except Exception as e: + print(f"ASE optimization failed for structure {i+1}: {e}") + convergence_steps_list.append(-1) + + final_ase_atoms_list.append(optim_target_atoms.atoms if ase_use_frechet_filter else ase_atoms_orig) + + # Convert list of final ASE atoms objects back to a base SimState first + # to easily get positions, cell, etc. + # However, ts.io.atoms_to_state might not preserve all attributes needed for GDState directly. + # It's better to extract all required components directly from final_ase_atoms_list. 
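+        # Convention note: ASE stores the cell with lattice vectors as rows
+        # (cartesian = scaled @ cell), while SimState uses column vectors
+        # (cartesian = cell @ frac), hence the `.array.T` below. A minimal
+        # sketch of the round trip (illustrative only; `atoms` is any ASE
+        # Atoms object):
+        #
+        #     row_cell = atoms.get_cell().array        # (3, 3), rows = vectors
+        #     col_cell = torch.tensor(row_cell.T)      # column-vector convention
+        #     assert (row_cell[0] == col_cell[:, 0].numpy()).all()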
+
+        all_positions = []
+        all_masses = []
+        all_atomic_numbers = []
+        all_cells = []
+        all_batches_for_gd = []
+        final_energies_ase = []
+        final_forces_ase_tensors = []  # per-structure force tensors
+
+        current_atom_offset = 0
+        for batch_idx, ats_final in enumerate(final_ase_atoms_list):
+            all_positions.append(torch.tensor(ats_final.get_positions(), device=device, dtype=dtype))
+            all_masses.append(torch.tensor(ats_final.get_masses(), device=device, dtype=dtype))
+            all_atomic_numbers.append(torch.tensor(ats_final.get_atomic_numbers(), device=device, dtype=torch.long))
+            # ASE cell is row-vector, SimState expects column-vector
+            all_cells.append(torch.tensor(ats_final.get_cell().array.T, device=device, dtype=dtype))
+
+            num_atoms_in_current = len(ats_final)
+            all_batches_for_gd.append(torch.full((num_atoms_in_current,), batch_idx, device=device, dtype=torch.long))
+            current_atom_offset += num_atoms_in_current
+
+            try:
+                if ats_final.calc is None:
+                    print(f"Re-attaching ASE calculator for final energy/forces for structure {batch_idx}.")
+                    temp_calc = mace_mp_calculator_for_ase(
+                        model=MaceUrls.mace_mpa_medium, device=device, default_dtype=str(dtype).split('.')[-1]
+                    )
+                    ats_final.calc = temp_calc
+                final_energies_ase.append(ats_final.get_potential_energy())
+                final_forces_ase_tensors.append(torch.tensor(ats_final.get_forces(), device=device, dtype=dtype))
+            except Exception as e:
+                print(f"Could not get final energy/forces for ASE structure {batch_idx}: {e}")
+                final_energies_ase.append(float('nan'))
+                # GDState requires a forces tensor for every structure, so on
+                # failure substitute zeros of the matching shape. A more robust
+                # script would skip the failed structure entirely.
+                if all_positions and len(all_positions[-1]) > 0:
+                    final_forces_ase_tensors.append(torch.zeros_like(all_positions[-1]))
+                else:  # no positions recorded; fall back to an empty (0, 3) tensor
+                    final_forces_ase_tensors.append(torch.empty((0, 3), device=device, dtype=dtype))
+
+        if not all_positions:  # all optimizations failed early
+            print("Warning: No successful ASE structures to form GDState.")
+            return torch.tensor(convergence_steps_list, dtype=torch.long, device=device), None
+
+        # Concatenate all parts
+        concatenated_positions = torch.cat(all_positions, dim=0)
+        concatenated_masses = torch.cat(all_masses, dim=0)
+        concatenated_atomic_numbers = torch.cat(all_atomic_numbers, dim=0)
+        concatenated_cells = torch.stack(all_cells, dim=0)  # cells are (n_batch, 3, 3)
+        concatenated_batch_indices = torch.cat(all_batches_for_gd, dim=0)
+
+        concatenated_energies = torch.tensor(final_energies_ase, device=device, dtype=dtype)
+        concatenated_forces = torch.cat(final_forces_ase_tensors, dim=0)
+
+        # Check for NaN energies which might cause issues
+        if torch.isnan(concatenated_energies).any():
+            print("Warning: NaN values found in final ASE energies. 
GDState energy tensor will contain NaNs.") + # Consider replacing NaNs if GDState or subsequent ops can't handle them: + # concatenated_energies = torch.nan_to_num(concatenated_energies, nan=0.0) # Example replacement + + # Create GDState instance + # pbc is global, taken from initial_state + final_state_as_gd = GDState( + positions=concatenated_positions, + masses=concatenated_masses, + cell=concatenated_cells, + pbc=initial_state.pbc, # Assuming pbc is constant and global + atomic_numbers=concatenated_atomic_numbers, + batch=concatenated_batch_indices, + energy=concatenated_energies, + forces=concatenated_forces, + ) + + convergence_steps = torch.tensor(convergence_steps_list, dtype=torch.long, device=device) - end_time = time.perf_counter() - print(f"Finished {md_flavor} in {end_time - start_time:.2f} seconds.") - # Return both convergence steps and the final state object - return convergence_steps, final_state_concatenated + end_time = time.perf_counter() + print( + f"Finished ASE optimization (frechet_filter={ase_use_frechet_filter}) " + f"in {end_time - start_time:.2f} seconds." + ) + return convergence_steps, final_state_as_gd + else: + raise ValueError(f"Unknown optimizer_type: {optimizer_type}") # --- Main Script --- force_tol = 0.05 -# Run with ase_fire -ase_steps, ase_final_state = run_optimization( - state.clone(), "ase_fire", force_tol=force_tol -) -# Run with vv_fire -vv_steps, vv_final_state = run_optimization(state.clone(), "vv_fire", force_tol=force_tol) +# Configurations to test +configs_to_run = [ + { + "name": "torch-sim VV-FIRE (PosOnly)", + "type": "torch_sim", "ts_md_flavor": "vv_fire", "ts_use_frechet": False, + }, + { + "name": "torch-sim ASE-FIRE (PosOnly)", + "type": "torch_sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": False, + }, + { + "name": "torch-sim VV-FIRE (Frechet Cell)", + "type": "torch_sim", "ts_md_flavor": "vv_fire", "ts_use_frechet": True, + }, + { + "name": "torch-sim ASE-FIRE (Frechet Cell)", + "type": "torch_sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": True, + }, + { + "name": "ASE FIRE (Native, CellOpt)", # Will optimize cell if stress is available + "type": "ase", "ase_use_frechet_filter": False, + }, + { + "name": "ASE FIRE (Native Frechet Filter, CellOpt)", + "type": "ase", "ase_use_frechet_filter": True, + }, +] + +results_all = {} + +for config_run in configs_to_run: + print(f"\n\nStarting configuration: {config_run['name']}") + # Get relevant params, providing defaults where necessary for the run_optimization call + optimizer_type_val = config_run["type"] + ts_md_flavor_val = config_run.get("ts_md_flavor") # Will be None for ASE type, handled by assert + ts_use_frechet_val = config_run.get("ts_use_frechet", False) + ase_use_frechet_filter_val = config_run.get("ase_use_frechet_filter", False) + + steps, final_state_opt = run_optimization( + initial_state=state.clone(), # Use a fresh clone for each run + optimizer_type=optimizer_type_val, + ts_md_flavor=ts_md_flavor_val, + ts_use_frechet=ts_use_frechet_val, + ase_use_frechet_filter=ase_use_frechet_filter_val, + force_tol=force_tol, + ) + results_all[config_run["name"]] = {"steps": steps, "final_state": final_state_opt} -print("\n--- Comparison ---") + +print("\n\n--- Overall Comparison ---") print(f"{force_tol=:.2f} eV/Å") +print(f"Initial energies: {[f'{e.item():.3f}' for e in initial_energies]} eV") -# Calculate Mean Position Displacements -ase_final_states_list = ase_final_state.split() -vv_final_states_list = vv_final_state.split() -mean_displacements = [] -for ase_state, 
vv_state in zip(ase_final_states_list, vv_final_states_list, strict=True): - ase_pos = ase_state.positions - ase_state.positions.mean(dim=0) - vv_pos = vv_state.positions - vv_state.positions.mean(dim=0) - displacement = torch.norm(ase_pos - vv_pos, dim=1) - mean_disp = torch.mean(displacement).item() - mean_displacements.append(mean_disp) +for name, result_data in results_all.items(): + final_state_res = result_data["final_state"] + steps_res = result_data["steps"] + print(f"\nResults for: {name}") + if final_state_res is not None and hasattr(final_state_res, 'energy') and final_state_res.energy is not None: + energy_str = [f'{e.item():.3f}' for e in final_state_res.energy] + print(f" Final energies: {energy_str} eV") + else: + print(f" Final energies: Not available or state is None") + print(f" Convergence steps: {steps_res.tolist()}") + + not_converged_indices = torch.where(steps_res == -1)[0].tolist() + if not_converged_indices: + print(f" Did not converge for structure indices: {not_converged_indices}") + +# Mean Displacement Comparisons +comparison_pairs = [ + ("torch-sim ASE-FIRE (PosOnly)", "ASE FIRE (Native, CellOpt)"), # Note: one is pos-only, other cell-opt + ("torch-sim ASE-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"), + ("torch-sim VV-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"), + ("torch-sim VV-FIRE (PosOnly)", "torch-sim ASE-FIRE (PosOnly)"), # Original comparison +] + +for name1, name2 in comparison_pairs: + if name1 in results_all and name2 in results_all: + state1 = results_all[name1]["final_state"] + state2 = results_all[name2]["final_state"] + + if state1 is None or state2 is None: + print(f"\nCannot compare {name1} and {name2}, one or both states are None.") + continue + + state1_list = state1.split() + + state2_list = state2.split() + + if len(state1_list) != len(state2_list): + print(f"\nCannot compare {name1} and {name2}, different number of structures.") + continue + + mean_displacements = [] + for s1, s2 in zip(state1_list, state2_list, strict=True): + if s1.n_atoms == 0 or s2.n_atoms == 0 : # Handle empty states if they occur + mean_displacements.append(float('nan')) + continue + pos1_centered = s1.positions - s1.positions.mean(dim=0, keepdim=True) + pos2_centered = s2.positions - s2.positions.mean(dim=0, keepdim=True) + if pos1_centered.shape != pos2_centered.shape: + print(f"Warning: Shape mismatch for {name1} vs {name2} in structure. 
Skipping displacement calc.") + mean_displacements.append(float('nan')) + continue + displacement = torch.norm(pos1_centered - pos2_centered, dim=1) + mean_disp = torch.mean(displacement).item() + mean_displacements.append(mean_disp) + + print(f"\nMean Disp ({name1} vs {name2}): {[f'{d:.4f}' for d in mean_displacements]} Å") + else: + print(f"\nSkipping displacement comparison for ({name1} vs {name2}), one or both results missing.") + + +# --- Plotting Results --- + +# Names for the 5 structures for plotting labels +structure_names = [ats.get_chemical_formula() for ats in atoms_list] +# Make them more concise if needed: +structure_names = ["Si_bulk", "Cu_bulk", "Fe_bulk", "Si_vac", "Cu_vac"] # Example concise names +num_structures_plot = len(structure_names) + + +# --- Plot 1: Convergence Steps (Multi-bar per structure) --- +plot_methods_fig1 = list(results_all.keys()) +num_methods_fig1 = len(plot_methods_fig1) +steps_data_fig1 = np.zeros((num_structures_plot, num_methods_fig1)) # rows: structures, cols: methods + +for method_idx, method_name in enumerate(plot_methods_fig1): + result_data = results_all[method_name] + if result_data["final_state"] is None: + steps_data_fig1[:, method_idx] = np.nan # Mark all as NaN for this method + print(f"Plot1: Skipping steps for {method_name} as final_state is None.") + continue + + steps_tensor = result_data["steps"].cpu().numpy() + # Replace -1 (not converged) with a high value for plotting, or handle differently + # For now, let's use a penalty (e.g., max_iterations_overall + buffer) + # or keep as -1 and let user interpret. For bar plot, positive values are better. + # Let's use np.nan for non-converged for now, and plot what did converge. + # Or, use ase_max_optimizer_steps as a cap if not converged. + # Let's use the actual steps, and if -1, plot it as a very high bar or special marker. + # For a bar chart, a common approach is to cap it or show it differently. + # We will cap at ase_max_optimizer_steps + a bit if -1 + penalty_steps = ase_max_optimizer_steps + 100 + steps_plot_values = np.where(steps_tensor == -1, penalty_steps, steps_tensor) + + if len(steps_plot_values) == num_structures_plot: + steps_data_fig1[:, method_idx] = steps_plot_values + else: + print(f"Warning: Mismatch in number of structures for steps in {method_name}. 
Expected {num_structures_plot}, got {len(steps_plot_values)}") + steps_data_fig1[:, method_idx] = np.nan + + +fig1, ax1 = plt.subplots(figsize=(15, 8)) +x_fig1 = np.arange(num_structures_plot) # x locations for the groups +width_fig1 = 0.8 / num_methods_fig1 # width of the bars + +rects_all_fig1 = [] +for i in range(num_methods_fig1): + # Offset each bar in the group + rects = ax1.bar(x_fig1 - 0.4 + (i + 0.5) * width_fig1, steps_data_fig1[:, i], width_fig1, label=plot_methods_fig1[i]) + rects_all_fig1.append(rects) + # Add text for -1 (non-converged) if we plot them as penalty + for bar_idx, bar_val in enumerate(steps_data_fig1[:, i]): + original_step_val = results_all[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] + if original_step_val == -1: + ax1.text(rects[bar_idx].get_x() + rects[bar_idx].get_width() / 2., + rects[bar_idx].get_height() - 10, # Position slightly below top of penalty bar + 'NC', ha='center', va='top', color='white', fontsize=7, weight='bold') + + +ax1.set_ylabel('Convergence Steps (NC = Not Converged, shown at penalty)') +ax1.set_xlabel('Structure') +ax1.set_title('Convergence Steps per Structure and Method') +ax1.set_xticks(x_fig1) +ax1.set_xticklabels(structure_names, rotation=45, ha="right") +ax1.legend(title="Optimization Method", bbox_to_anchor=(1.05, 1), loc='upper left') +ax1.grid(True, axis='y', linestyle='--', alpha=0.7) +plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust rect to make space for legend + + +# --- Plot 2: Average Final Energy Difference from Baselines --- +baseline_ase_pos_only = "ASE FIRE (Native, CellOpt)" +baseline_ase_frechet = "ASE FIRE (Native Frechet Filter, CellOpt)" +avg_energy_diffs_fig2 = [] +plot_names_fig2 = [] + +# Ensure baselines exist and have data +baseline_pos_only_data = results_all.get(baseline_ase_pos_only) +baseline_frechet_data = results_all.get(baseline_ase_frechet) + +for name, result_data in results_all.items(): + if result_data["final_state"] is None: + print(f"Plot2: Skipping energy diff for {name} as final_state is None.") + continue + + plot_names_fig2.append(name) + current_energies = result_data["final_state"].energy.cpu().numpy() + + chosen_baseline_energies = None + is_baseline_self = False + if "torch-sim" in name: + if "PosOnly" in name: + if baseline_pos_only_data and baseline_pos_only_data["final_state"] is not None: + chosen_baseline_energies = baseline_pos_only_data["final_state"].energy.cpu().numpy() + elif "Frechet Cell" in name: + if baseline_frechet_data and baseline_frechet_data["final_state"] is not None: + chosen_baseline_energies = baseline_frechet_data["final_state"].energy.cpu().numpy() + elif name == baseline_ase_pos_only or name == baseline_ase_frechet: + avg_energy_diffs_fig2.append(0.0) # Difference to self is 0 + is_baseline_self = True # Flag to handle text for baseline bars + # continue # Continue was here, but we need to plot the baseline bar itself + + if not is_baseline_self: # Only calculate diff if not a baseline comparing to itself + if chosen_baseline_energies is not None: + if current_energies.shape == chosen_baseline_energies.shape: + energy_diff = np.mean(current_energies - chosen_baseline_energies) + avg_energy_diffs_fig2.append(energy_diff) + else: + avg_energy_diffs_fig2.append(np.nan) + print(f"Shape mismatch for energy comparison: {name} vs its baseline") + else: + # If no appropriate baseline, or baseline data is missing + print(f"Plot2: No appropriate baseline for {name} or baseline data missing. 
Setting energy diff to NaN.") + avg_energy_diffs_fig2.append(np.nan) + +fig2, ax2 = plt.subplots(figsize=(12, 7)) +bars_fig2 = ax2.bar(plot_names_fig2, avg_energy_diffs_fig2, color='lightcoral') # Store the bars +ax2.set_ylabel('Avg. Final Energy Diff. from Corresponding ASE Baseline (eV)') +ax2.set_xlabel('Optimization Method') +ax2.set_title('Average Final Energy Difference from ASE Baselines') +ax2.axhline(0, color='black', linewidth=0.8, linestyle='--') + +# Add text labels on top of bars for Figure 2 +for bar in bars_fig2: + yval = bar.get_height() + if not np.isnan(yval): # Only add text if not NaN + # Adjust text position based on whether the bar is positive or negative + text_y_offset = 0.001 if yval >= 0 else -0.005 # Small offset for visibility + va_align = 'bottom' if yval >=0 else 'top' + ax2.text(bar.get_x() + bar.get_width()/2.0, yval + text_y_offset, + f"{yval:.3f}", ha='center', va=va_align, fontsize=8, color='black') + +plt.xticks(rotation=45, ha="right") +plt.tight_layout() + + +# --- Plot 3: Average Mean Displacement from ASE Counterparts --- +disp_plot_data_fig3 = {} +comparison_pairs_plot3 = [ + ("torch-sim ASE-FIRE (PosOnly)", baseline_ase_pos_only, "TS ASE PosOnly vs ASE Native"), + ("torch-sim VV-FIRE (PosOnly)", baseline_ase_pos_only, "TS VV PosOnly vs ASE Native"), + ("torch-sim ASE-FIRE (Frechet Cell)", baseline_ase_frechet, "TS ASE Frechet vs ASE Frechet"), + ("torch-sim VV-FIRE (Frechet Cell)", baseline_ase_frechet, "TS VV Frechet vs ASE Frechet"), +] + +for ts_method_name, ase_method_name, plot_label in comparison_pairs_plot3: + if ts_method_name in results_all and ase_method_name in results_all: + state1_data = results_all[ts_method_name] + state2_data = results_all[ase_method_name] + + if state1_data["final_state"] is None or state2_data["final_state"] is None: + print(f"Plot3: Skipping displacement for {plot_label} due to missing state data.") + disp_plot_data_fig3[plot_label] = np.nan + continue + + state1_list = state1_data["final_state"].split() + state2_list = state2_data["final_state"].split() + + if len(state1_list) != len(state2_list): + print(f"Plot3: Structure count mismatch for {plot_label}.") + disp_plot_data_fig3[plot_label] = np.nan + continue + + mean_displacements_current_pair = [] + for s1, s2 in zip(state1_list, state2_list, strict=True): + if s1.n_atoms == 0 or s2.n_atoms == 0 or s1.n_atoms != s2.n_atoms: + mean_displacements_current_pair.append(np.nan) + continue + pos1_centered = s1.positions - s1.positions.mean(dim=0, keepdim=True) + pos2_centered = s2.positions - s2.positions.mean(dim=0, keepdim=True) + displacement = torch.norm(pos1_centered - pos2_centered, dim=1) + mean_disp = torch.mean(displacement).item() + mean_displacements_current_pair.append(mean_disp) + + if mean_displacements_current_pair: + avg_disp = np.nanmean(mean_displacements_current_pair) + disp_plot_data_fig3[plot_label] = avg_disp + else: + disp_plot_data_fig3[plot_label] = np.nan + else: + print(f"Plot3: Missing data for {ts_method_name} or {ase_method_name}.") + disp_plot_data_fig3[plot_label] = np.nan -print(f"Initial energies: {[f'{e.item():.3f}' for e in initial_energies]} eV") -print(f"Final ASE energies: {[f'{e.item():.3f}' for e in ase_final_state.energy]} eV") -print(f"Final VV energies: {[f'{e.item():.3f}' for e in vv_final_state.energy]} eV") -print(f"Mean Disp (ASE-VV): {[f'{d:.4f}' for d in mean_displacements]} Å") -print(f"Convergence steps (ASE FIRE): {ase_steps.tolist()}") -print(f"Convergence steps (VV FIRE): {vv_steps.tolist()}") - -# Identify 
structures that didn't converge -ase_not_converged = torch.where(ase_steps == -1)[0].tolist() -vv_not_converged = torch.where(vv_steps == -1)[0].tolist() - -if ase_not_converged: - print(f"ASE FIRE did not converge for indices: {ase_not_converged}") -if vv_not_converged: - print(f"VV FIRE did not converge for indices: {vv_not_converged}") +if disp_plot_data_fig3: + disp_methods_fig3 = list(disp_plot_data_fig3.keys()) + disp_values_fig3 = list(disp_plot_data_fig3.values()) + + fig3, ax3 = plt.subplots(figsize=(12, 8)) + ax3.bar(disp_methods_fig3, disp_values_fig3, color='mediumseagreen') + ax3.set_ylabel('Avg. Mean Atomic Displacement (Å) to ASE Counterpart') + ax3.set_xlabel('Comparison Pair') + ax3.set_title('Mean Displacement of Torch-Sim Methods to ASE Counterparts') + plt.xticks(rotation=45, ha="right") + plt.tight_layout() +else: + print("No displacement data to plot for Figure 3.") + +plt.show() From 631175fda830c28fa701fff06a1f2697ad3ad253 Mon Sep 17 00:00:00 2001 From: Myles Stapelberg Date: Sat, 17 May 2025 16:49:18 -0400 Subject: [PATCH 2/7] updated script to compute all 6 structures --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 219 +++++++++++------- 1 file changed, 132 insertions(+), 87 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 8a3124b3..b08e591a 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -94,7 +94,7 @@ # Create a list of our atomic systems -atoms_list = [si_dc, cu_dc, fe_dc, si_dc_vac, cu_dc_vac] +atoms_list = [si_dc, cu_dc, fe_dc, si_dc_vac, cu_dc_vac, fe_dc_vac] # Print structure information print(f"Silicon atoms: {len(si_dc)}") @@ -484,70 +484,66 @@ def run_optimization( # --- Plotting Results --- -# Names for the 5 structures for plotting labels -structure_names = [ats.get_chemical_formula() for ats in atoms_list] +# Names for the structures for plotting labels +original_structure_formulas = [ats.get_chemical_formula() for ats in atoms_list] # Make them more concise if needed: -structure_names = ["Si_bulk", "Cu_bulk", "Fe_bulk", "Si_vac", "Cu_vac"] # Example concise names +structure_names = ["Si_bulk", "Cu_bulk", "Fe_bulk", "Si_vac", "Cu_vac", "Fe_vac"] # Updated for 6 structures +if len(structure_names) != len(atoms_list): + print(f"Warning: Mismatch between custom structure_names ({len(structure_names)}) and atoms_list ({len(atoms_list)}). 
Using custom names.") num_structures_plot = len(structure_names) # --- Plot 1: Convergence Steps (Multi-bar per structure) --- plot_methods_fig1 = list(results_all.keys()) num_methods_fig1 = len(plot_methods_fig1) -steps_data_fig1 = np.zeros((num_structures_plot, num_methods_fig1)) # rows: structures, cols: methods +# Initialize with NaNs, so if a method fails completely, its bars are missing or clearly marked +steps_data_fig1 = np.full((num_structures_plot, num_methods_fig1), np.nan) for method_idx, method_name in enumerate(plot_methods_fig1): result_data = results_all[method_name] - if result_data["final_state"] is None: - steps_data_fig1[:, method_idx] = np.nan # Mark all as NaN for this method - print(f"Plot1: Skipping steps for {method_name} as final_state is None.") + if result_data["final_state"] is None or result_data["steps"] is None: + # steps_data_fig1[:, method_idx] = np.nan # Already initialized with NaN + print(f"Plot1: Skipping steps for {method_name} as final_state or steps is None.") continue steps_tensor = result_data["steps"].cpu().numpy() - # Replace -1 (not converged) with a high value for plotting, or handle differently - # For now, let's use a penalty (e.g., max_iterations_overall + buffer) - # or keep as -1 and let user interpret. For bar plot, positive values are better. - # Let's use np.nan for non-converged for now, and plot what did converge. - # Or, use ase_max_optimizer_steps as a cap if not converged. - # Let's use the actual steps, and if -1, plot it as a very high bar or special marker. - # For a bar chart, a common approach is to cap it or show it differently. - # We will cap at ase_max_optimizer_steps + a bit if -1 penalty_steps = ase_max_optimizer_steps + 100 steps_plot_values = np.where(steps_tensor == -1, penalty_steps, steps_tensor) if len(steps_plot_values) == num_structures_plot: steps_data_fig1[:, method_idx] = steps_plot_values - else: - print(f"Warning: Mismatch in number of structures for steps in {method_name}. Expected {num_structures_plot}, got {len(steps_plot_values)}") - steps_data_fig1[:, method_idx] = np.nan - - -fig1, ax1 = plt.subplots(figsize=(15, 8)) -x_fig1 = np.arange(num_structures_plot) # x locations for the groups -width_fig1 = 0.8 / num_methods_fig1 # width of the bars + elif len(steps_plot_values) > num_structures_plot: + print(f"Warning: More step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Truncating.") + steps_data_fig1[:, method_idx] = steps_plot_values[:num_structures_plot] + elif len(steps_plot_values) < num_structures_plot: + print(f"Warning: Fewer step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. 
Padding with NaN.") + steps_data_fig1[:len(steps_plot_values), method_idx] = steps_plot_values + # The rest will remain NaN due to initialization + +fig1, ax1 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend +x_fig1 = np.arange(num_structures_plot) +width_fig1 = 0.8 / num_methods_fig1 rects_all_fig1 = [] for i in range(num_methods_fig1): - # Offset each bar in the group rects = ax1.bar(x_fig1 - 0.4 + (i + 0.5) * width_fig1, steps_data_fig1[:, i], width_fig1, label=plot_methods_fig1[i]) rects_all_fig1.append(rects) - # Add text for -1 (non-converged) if we plot them as penalty for bar_idx, bar_val in enumerate(steps_data_fig1[:, i]): - original_step_val = results_all[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] - if original_step_val == -1: - ax1.text(rects[bar_idx].get_x() + rects[bar_idx].get_width() / 2., - rects[bar_idx].get_height() - 10, # Position slightly below top of penalty bar - 'NC', ha='center', va='top', color='white', fontsize=7, weight='bold') - + if bar_idx < len(results_all[plot_methods_fig1[i]]["steps"]): # Check bounds + original_step_val = results_all[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] + if original_step_val == -1 and not np.isnan(bar_val): # Check if it was a penalty bar + ax1.text(rects[bar_idx].get_x() + rects[bar_idx].get_width() / 2., + rects[bar_idx].get_height() - 10, + 'NC', ha='center', va='top', color='white', fontsize=7, weight='bold') ax1.set_ylabel('Convergence Steps (NC = Not Converged, shown at penalty)') ax1.set_xlabel('Structure') ax1.set_title('Convergence Steps per Structure and Method') ax1.set_xticks(x_fig1) ax1.set_xticklabels(structure_names, rotation=45, ha="right") -ax1.legend(title="Optimization Method", bbox_to_anchor=(1.05, 1), loc='upper left') +ax1.legend(title="Optimization Method", bbox_to_anchor=(1.02, 1), loc='upper left') # Adjusted legend position ax1.grid(True, axis='y', linestyle='--', alpha=0.7) -plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust rect to make space for legend +plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend # --- Plot 2: Average Final Energy Difference from Baselines --- @@ -556,58 +552,101 @@ def run_optimization( avg_energy_diffs_fig2 = [] plot_names_fig2 = [] -# Ensure baselines exist and have data baseline_pos_only_data = results_all.get(baseline_ase_pos_only) baseline_frechet_data = results_all.get(baseline_ase_frechet) for name, result_data in results_all.items(): - if result_data["final_state"] is None: - print(f"Plot2: Skipping energy diff for {name} as final_state is None.") + if result_data["final_state"] is None or result_data["final_state"].energy is None: + print(f"Plot2: Skipping energy diff for {name} as final_state or energy is None.") + if name not in plot_names_fig2: plot_names_fig2.append(name) # Keep name for consistent bar count + avg_energy_diffs_fig2.append(np.nan) # Add NaN if data missing continue - plot_names_fig2.append(name) + # Ensure name is added if not already by a skip + if name not in plot_names_fig2: plot_names_fig2.append(name) + current_energies = result_data["final_state"].energy.cpu().numpy() chosen_baseline_energies = None is_baseline_self = False - if "torch-sim" in name: + processed_current_name = False + + if name == baseline_ase_pos_only or name == baseline_ase_frechet: + avg_energy_diffs_fig2.append(0.0) + is_baseline_self = True + processed_current_name = True + elif "torch-sim" in name: if "PosOnly" in name: - if baseline_pos_only_data and baseline_pos_only_data["final_state"] is not None: + if 
baseline_pos_only_data and baseline_pos_only_data["final_state"] is not None and baseline_pos_only_data["final_state"].energy is not None: chosen_baseline_energies = baseline_pos_only_data["final_state"].energy.cpu().numpy() elif "Frechet Cell" in name: - if baseline_frechet_data and baseline_frechet_data["final_state"] is not None: + if baseline_frechet_data and baseline_frechet_data["final_state"] is not None and baseline_frechet_data["final_state"].energy is not None: chosen_baseline_energies = baseline_frechet_data["final_state"].energy.cpu().numpy() - elif name == baseline_ase_pos_only or name == baseline_ase_frechet: - avg_energy_diffs_fig2.append(0.0) # Difference to self is 0 - is_baseline_self = True # Flag to handle text for baseline bars - # continue # Continue was here, but we need to plot the baseline bar itself - if not is_baseline_self: # Only calculate diff if not a baseline comparing to itself + if not is_baseline_self and not processed_current_name: if chosen_baseline_energies is not None: if current_energies.shape == chosen_baseline_energies.shape: energy_diff = np.mean(current_energies - chosen_baseline_energies) avg_energy_diffs_fig2.append(energy_diff) else: avg_energy_diffs_fig2.append(np.nan) - print(f"Shape mismatch for energy comparison: {name} vs its baseline") + print(f"Plot2: Shape mismatch for energy comparison: {name} vs its baseline. " + f"{current_energies.shape} vs {chosen_baseline_energies.shape}") else: - # If no appropriate baseline, or baseline data is missing print(f"Plot2: No appropriate baseline for {name} or baseline data missing. Setting energy diff to NaN.") avg_energy_diffs_fig2.append(np.nan) + elif not processed_current_name and name not in [n for n,v in zip(plot_names_fig2, avg_energy_diffs_fig2) if not np.isnan(v)] : # Handle cases not covered + print(f"Plot2: Fallback for {name}, setting energy diff to NaN.") + avg_energy_diffs_fig2.append(np.nan) + + +# Ensure plot_names_fig2 and avg_energy_diffs_fig2 have the same length +# This can happen if a name was added to plot_names_fig2 but its energy_diff calculation failed or was skipped. +# A more robust way is to build them in parallel. 
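+# A separate fragility is the substring matching used above to pick each
+# method's baseline; an explicit name-to-baseline mapping would fail loudly
+# on unknown entries instead. A sketch of that alternative (hypothetical
+# helper, not used below):
+#
+#     baseline_for = {
+#         "torch-sim VV-FIRE (PosOnly)": baseline_ase_pos_only,
+#         "torch-sim ASE-FIRE (PosOnly)": baseline_ase_pos_only,
+#         "torch-sim VV-FIRE (Frechet Cell)": baseline_ase_frechet,
+#         "torch-sim ASE-FIRE (Frechet Cell)": baseline_ase_frechet,
+#     }
+#     baseline_name = baseline_for.get(name)  # None -> report NaN diff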
+final_plot_names_fig2 = [] +final_avg_energy_diffs_fig2 = [] +all_method_names_sorted = sorted(list(results_all.keys())) # Use a fixed order + +for name in all_method_names_sorted: + result_data = results_all[name] + final_plot_names_fig2.append(name) + if result_data["final_state"] is None or result_data["final_state"].energy is None: + final_avg_energy_diffs_fig2.append(np.nan) + continue + + current_energies = result_data["final_state"].energy.cpu().numpy() + energy_to_append = np.nan # Default to NaN + + if name == baseline_ase_pos_only or name == baseline_ase_frechet: + energy_to_append = 0.0 + elif "torch-sim" in name: + baseline_to_use_energies = None + if "PosOnly" in name: + if baseline_pos_only_data and baseline_pos_only_data["final_state"] is not None and baseline_pos_only_data["final_state"].energy is not None: + baseline_to_use_energies = baseline_pos_only_data["final_state"].energy.cpu().numpy() + elif "Frechet Cell" in name: + if baseline_frechet_data and baseline_frechet_data["final_state"] is not None and baseline_frechet_data["final_state"].energy is not None: + baseline_to_use_energies = baseline_frechet_data["final_state"].energy.cpu().numpy() + + if baseline_to_use_energies is not None: + if current_energies.shape == baseline_to_use_energies.shape: + energy_to_append = np.mean(current_energies - baseline_to_use_energies) + else: + print(f"Plot2: Shape mismatch for {name} ({current_energies.shape}) vs baseline ({baseline_to_use_energies.shape}).") + final_avg_energy_diffs_fig2.append(energy_to_append) + fig2, ax2 = plt.subplots(figsize=(12, 7)) -bars_fig2 = ax2.bar(plot_names_fig2, avg_energy_diffs_fig2, color='lightcoral') # Store the bars +bars_fig2 = ax2.bar(final_plot_names_fig2, final_avg_energy_diffs_fig2, color='lightcoral') ax2.set_ylabel('Avg. Final Energy Diff. 
from Corresponding ASE Baseline (eV)') ax2.set_xlabel('Optimization Method') ax2.set_title('Average Final Energy Difference from ASE Baselines') ax2.axhline(0, color='black', linewidth=0.8, linestyle='--') -# Add text labels on top of bars for Figure 2 for bar in bars_fig2: yval = bar.get_height() - if not np.isnan(yval): # Only add text if not NaN - # Adjust text position based on whether the bar is positive or negative - text_y_offset = 0.001 if yval >= 0 else -0.005 # Small offset for visibility + if not np.isnan(yval): + text_y_offset = 0.001 if yval >= 0 else -0.005 va_align = 'bottom' if yval >=0 else 'top' ax2.text(bar.get_x() + bar.get_width()/2.0, yval + text_y_offset, f"{yval:.3f}", ha='center', va=va_align, fontsize=8, color='black') @@ -616,66 +655,72 @@ def run_optimization( plt.tight_layout() -# --- Plot 3: Average Mean Displacement from ASE Counterparts --- -disp_plot_data_fig3 = {} -comparison_pairs_plot3 = [ +# --- Plot 3: Mean Displacement from ASE Counterparts (Multi-bar per structure) --- +comparison_pairs_plot3_defs = [ # (ts_method_name, ase_method_name, short_label_for_legend) ("torch-sim ASE-FIRE (PosOnly)", baseline_ase_pos_only, "TS ASE PosOnly vs ASE Native"), ("torch-sim VV-FIRE (PosOnly)", baseline_ase_pos_only, "TS VV PosOnly vs ASE Native"), ("torch-sim ASE-FIRE (Frechet Cell)", baseline_ase_frechet, "TS ASE Frechet vs ASE Frechet"), ("torch-sim VV-FIRE (Frechet Cell)", baseline_ase_frechet, "TS VV Frechet vs ASE Frechet"), ] +num_comparison_pairs_plot3 = len(comparison_pairs_plot3_defs) +# rows: structures, cols: comparison_pair +disp_data_fig3 = np.full((num_structures_plot, num_comparison_pairs_plot3), np.nan) +legend_labels_fig3 = [] -for ts_method_name, ase_method_name, plot_label in comparison_pairs_plot3: +for pair_idx, (ts_method_name, ase_method_name, plot_label) in enumerate(comparison_pairs_plot3_defs): + legend_labels_fig3.append(plot_label) if ts_method_name in results_all and ase_method_name in results_all: state1_data = results_all[ts_method_name] state2_data = results_all[ase_method_name] if state1_data["final_state"] is None or state2_data["final_state"] is None: print(f"Plot3: Skipping displacement for {plot_label} due to missing state data.") - disp_plot_data_fig3[plot_label] = np.nan + # Data remains NaN continue state1_list = state1_data["final_state"].split() state2_list = state2_data["final_state"].split() - if len(state1_list) != len(state2_list): - print(f"Plot3: Structure count mismatch for {plot_label}.") - disp_plot_data_fig3[plot_label] = np.nan + if len(state1_list) != len(state2_list) or len(state1_list) != num_structures_plot : + print(f"Plot3: Structure count mismatch for {plot_label}. 
Expected {num_structures_plot}, got S1:{len(state1_list)}, S2:{len(state2_list)}") + # Data remains NaN continue - mean_displacements_current_pair = [] - for s1, s2 in zip(state1_list, state2_list, strict=True): + mean_displacements_for_this_pair = [] + for s_idx, (s1, s2) in enumerate(zip(state1_list, state2_list, strict=True)): if s1.n_atoms == 0 or s2.n_atoms == 0 or s1.n_atoms != s2.n_atoms: - mean_displacements_current_pair.append(np.nan) + mean_displacements_for_this_pair.append(np.nan) continue pos1_centered = s1.positions - s1.positions.mean(dim=0, keepdim=True) pos2_centered = s2.positions - s2.positions.mean(dim=0, keepdim=True) displacement = torch.norm(pos1_centered - pos2_centered, dim=1) mean_disp = torch.mean(displacement).item() - mean_displacements_current_pair.append(mean_disp) + mean_displacements_for_this_pair.append(mean_disp) - if mean_displacements_current_pair: - avg_disp = np.nanmean(mean_displacements_current_pair) - disp_plot_data_fig3[plot_label] = avg_disp - else: - disp_plot_data_fig3[plot_label] = np.nan + if len(mean_displacements_for_this_pair) == num_structures_plot: + disp_data_fig3[:, pair_idx] = np.array(mean_displacements_for_this_pair) + else: # Should not happen if previous checks pass + print(f"Plot3: Inner loop displacement calculation mismatch for {plot_label}") + else: - print(f"Plot3: Missing data for {ts_method_name} or {ase_method_name}.") - disp_plot_data_fig3[plot_label] = np.nan - - -if disp_plot_data_fig3: - disp_methods_fig3 = list(disp_plot_data_fig3.keys()) - disp_values_fig3 = list(disp_plot_data_fig3.values()) - - fig3, ax3 = plt.subplots(figsize=(12, 8)) - ax3.bar(disp_methods_fig3, disp_values_fig3, color='mediumseagreen') - ax3.set_ylabel('Avg. Mean Atomic Displacement (Å) to ASE Counterpart') - ax3.set_xlabel('Comparison Pair') - ax3.set_title('Mean Displacement of Torch-Sim Methods to ASE Counterparts') - plt.xticks(rotation=45, ha="right") - plt.tight_layout() -else: - print("No displacement data to plot for Figure 3.") + print(f"Plot3: Missing data for methods in pair: {plot_label}.") + # Data remains NaN + +fig3, ax3 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend +x_fig3 = np.arange(num_structures_plot) +width_fig3 = 0.8 / num_comparison_pairs_plot3 + +for i in range(num_comparison_pairs_plot3): + ax3.bar(x_fig3 - 0.4 + (i + 0.5) * width_fig3, disp_data_fig3[:, i], width_fig3, label=legend_labels_fig3[i]) + +ax3.set_ylabel('Mean Atomic Displacement (Å) to ASE Counterpart') +ax3.set_xlabel('Structure') +ax3.set_title('Mean Displacement of Torch-Sim Methods to ASE Counterparts (per Structure)') +ax3.set_xticks(x_fig3) +ax3.set_xticklabels(structure_names, rotation=45, ha="right") +ax3.legend(title="Comparison Pair", bbox_to_anchor=(1.02, 1), loc='upper left') # Adjusted legend +ax3.grid(True, axis='y', linestyle='--', alpha=0.7) +plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend + plt.show() From bf29852b235d24e537a9df1c986f7cc56b81b045 Mon Sep 17 00:00:00 2001 From: Myles Stapelberg Date: Sat, 17 May 2025 17:32:36 -0400 Subject: [PATCH 3/7] fixed labels for ASE Native --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index b08e591a..5166d804 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -17,6 +17,7 @@ from 
ase.build import bulk from ase.optimize import FIRE as ASEFIRE from ase.filters import FrechetCellFilter +from ase.cell import Cell from mace.calculators.foundations_models import mace_mp from mace.calculators.foundations_models import mace_mp as mace_mp_calculator_for_ase import matplotlib.pyplot as plt @@ -138,6 +139,13 @@ def run_optimization( ) start_time = time.perf_counter() + print("Initial cell parameters (Torch-Sim):") + for k_idx in range(initial_state.n_batches): + cell_tensor_k = initial_state.cell[k_idx].cpu().numpy() + ase_cell_k = Cell(cell_tensor_k) + params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) + print(f" Structure {k_idx+1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]") + if ts_use_frechet: # Uses frechet_cell_fire for combined cell and position optimization init_fn_opt, update_fn_opt = frechet_cell_fire( @@ -217,6 +225,17 @@ def run_optimization( final_states_list = batcher.restore_original_order(all_converged_states) final_state_concatenated = ts.concatenate_states(final_states_list) + + if final_state_concatenated is not None and hasattr(final_state_concatenated, 'cell'): + print("Final cell parameters (Torch-Sim):") + for k_idx in range(final_state_concatenated.n_batches): + cell_tensor_k = final_state_concatenated.cell[k_idx].cpu().numpy() + ase_cell_k = Cell(cell_tensor_k) + params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) + print(f" Structure {k_idx+1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]") + else: + print("Final cell parameters (Torch-Sim): Not available (final_state_concatenated is None or has no cell).") + end_time = time.perf_counter() print( f"Finished Torch-Sim ({ts_md_flavor}, frechet={ts_use_frechet}) in " @@ -240,6 +259,10 @@ def run_optimization( print(f"Optimizing structure {i+1}/{num_structures} with ASE...") ase_atoms_orig = ts.io.state_to_atoms(single_sim_state)[0] + initial_cell_ase = ase_atoms_orig.get_cell() + initial_params_str = ", ".join([f"{p:.2f}" for p in initial_cell_ase.cellpar()]) + print(f" Initial cell (ASE Structure {i+1}): Volume={initial_cell_ase.volume:.2f} ų, Params=[{initial_params_str}]") + ase_calc_instance = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, device=device, @@ -269,7 +292,12 @@ def run_optimization( print(f"ASE optimization failed for structure {i+1}: {e}") convergence_steps_list.append(-1) - final_ase_atoms_list.append(optim_target_atoms.atoms if ase_use_frechet_filter else ase_atoms_orig) + final_ats_for_print = optim_target_atoms.atoms if ase_use_frechet_filter else ase_atoms_orig + final_cell_ase = final_ats_for_print.get_cell() + final_params_str = ", ".join([f"{p:.2f}" for p in final_cell_ase.cellpar()]) + print(f" Final cell (ASE Structure {i+1}): Volume={final_cell_ase.volume:.2f} ų, Params=[{final_params_str}]") + + final_ase_atoms_list.append(final_ats_for_print) # Convert list of final ASE atoms objects back to a base SimState first # to easily get positions, cell, etc. 
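+        # For reference, `Cell.cellpar()` (used by the printing added above)
+        # converts a 3x3 cell matrix into the six lattice parameters
+        # (a, b, c, alpha, beta, gamma). A quick self-contained check:
+        #
+        #     from ase.cell import Cell
+        #     cell = Cell([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]])
+        #     a, b, c, alpha, beta, gamma = cell.cellpar()  # 3, 3, 3, 90, 90, 90
+        #     vol = cell.volume  # 27.0 A^3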
@@ -388,7 +416,7 @@ def run_optimization(
         "type": "torch_sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": True,
     },
     {
-        "name": "ASE FIRE (Native, CellOpt)", # Will optimize cell if stress is available
+        "name": "ASE FIRE (Native, PosOnly)", # corrected name: without a cell filter, ASE FIRE only optimizes positions
         "type": "ase", "ase_use_frechet_filter": False,
     },
     {
@@ -439,10 +467,10 @@ def run_optimization(
 
 # Mean Displacement Comparisons
 comparison_pairs = [
-    ("torch-sim ASE-FIRE (PosOnly)", "ASE FIRE (Native, CellOpt)"), # Note: one is pos-only, other cell-opt
+    ("torch-sim ASE-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"),
    ("torch-sim ASE-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"),
    ("torch-sim VV-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"),
-    ("torch-sim VV-FIRE (PosOnly)", "torch-sim ASE-FIRE (PosOnly)"), # Original comparison
+    ("torch-sim VV-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"),
 ]
 
 for name1, name2 in comparison_pairs:

From 4324670681752a413e23a1231afa1f8bac75e5ad Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Tue, 20 May 2025 21:34:50 -0400
Subject: [PATCH 4/7] clean: remove Atoms comment

---
 torch_sim/optimizers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py
index a4bbfa3c..0d94b4a5 100644
--- a/torch_sim/optimizers.py
+++ b/torch_sim/optimizers.py
@@ -1524,7 +1524,6 @@ def _ase_fire_step(  # noqa: C901, PLR0915
         dr_cell = cell_dt * state.cell_velocities
 
         # 6. 
Clamp to max_step - # Atoms dr_norm_atom = torch.norm(dr_atom, dim=1, keepdim=True) mask_atom_max_step = dr_norm_atom > max_step dr_atom = torch.where( From 62293175a07a430f0f3166ba1c2c779bf02bece9 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 21 May 2025 13:10:12 -0400 Subject: [PATCH 5/7] split run_optimization into run_optimization_ts and run_optimization_ase https://github.com/Radical-AI/torch-sim/pull/200#discussion_r2094218636 --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 938 +++++++++++------- 1 file changed, 559 insertions(+), 379 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 5166d804..8ae4b78b 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -12,19 +12,19 @@ import time from typing import Literal +import matplotlib.pyplot as plt import numpy as np import torch from ase.build import bulk -from ase.optimize import FIRE as ASEFIRE -from ase.filters import FrechetCellFilter from ase.cell import Cell +from ase.filters import FrechetCellFilter +from ase.optimize import FIRE as ASEFIRE from mace.calculators.foundations_models import mace_mp from mace.calculators.foundations_models import mace_mp as mace_mp_calculator_for_ase -import matplotlib.pyplot as plt import torch_sim as ts from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.optimizers import fire, frechet_cell_fire, GDState +from torch_sim.optimizers import GDState, fire, frechet_cell_fire from torch_sim.state import SimState @@ -44,7 +44,8 @@ # Number of steps to run max_iterations = 10 if os.getenv("CI") else 500 supercell_scale = (1, 1, 1) if os.getenv("CI") else (3, 2, 2) -ase_max_optimizer_steps = max_iterations * 10 # Max steps for each individual ASE optimization run +# Max steps for each individual ASE optimization run +ase_max_optimizer_steps = max_iterations * 10 # Set random seed for reproducibility rng = np.random.default_rng(seed=0) @@ -119,279 +120,314 @@ initial_energies = model(state)["energy"] -def run_optimization( +def run_optimization_ts( + *, initial_state: SimState, - optimizer_type: Literal["torch_sim", "ase"], - # For torch_sim: - ts_md_flavor: Literal["vv_fire", "ase_fire"] | None = None, - ts_use_frechet: bool = False, # To decide between fire() and frechet_cell_fire() - # For ASE: - ase_use_frechet_filter: bool = False, - # Common: - force_tol: float = 0.05, -) -> tuple[torch.Tensor, SimState]: - """Runs optimization and returns convergence steps and final state.""" - if optimizer_type == "torch_sim": - assert ts_md_flavor is not None, "ts_md_flavor must be provided for torch_sim" + ts_md_flavor: Literal["vv_fire", "ase_fire"], + ts_use_frechet: bool, + force_tol: float, + max_iterations_ts: int, +) -> tuple[torch.Tensor, SimState | None]: + """Runs Torch-Sim optimization and returns convergence steps and final state.""" + print( + f"\n--- Running Torch-Sim optimization: flavor={ts_md_flavor}, " + f"frechet_cell_opt={ts_use_frechet}, force_tol={force_tol} ---" + ) + start_time = time.perf_counter() + + print("Initial cell parameters (Torch-Sim):") + for k_idx in range(initial_state.n_batches): + cell_tensor_k = initial_state.cell[k_idx].cpu().numpy() + ase_cell_k = Cell(cell_tensor_k) + params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) print( - f"\n--- Running Torch-Sim optimization: flavor={ts_md_flavor}, " - f"frechet_cell_opt={ts_use_frechet}, force_tol={force_tol} ---" 
+ f" Structure {k_idx + 1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]" ) - start_time = time.perf_counter() - print("Initial cell parameters (Torch-Sim):") - for k_idx in range(initial_state.n_batches): - cell_tensor_k = initial_state.cell[k_idx].cpu().numpy() - ase_cell_k = Cell(cell_tensor_k) - params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) - print(f" Structure {k_idx+1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]") + if ts_use_frechet: + # Uses frechet_cell_fire for combined cell and position optimization + init_fn_opt, update_fn_opt = frechet_cell_fire( + model=model, md_flavor=ts_md_flavor + ) + else: + # Uses fire for position-only optimization + init_fn_opt, update_fn_opt = fire(model=model, md_flavor=ts_md_flavor) - if ts_use_frechet: - # Uses frechet_cell_fire for combined cell and position optimization - init_fn_opt, update_fn_opt = frechet_cell_fire( - model=model, md_flavor=ts_md_flavor - ) - else: - # Uses fire for position-only optimization - init_fn_opt, update_fn_opt = fire(model=model, md_flavor=ts_md_flavor) + opt_state = init_fn_opt(initial_state.clone()) - opt_state = init_fn_opt(initial_state.clone()) + batcher = ts.InFlightAutoBatcher( + model=model, # The MaceModel wrapper + memory_scales_with="n_atoms", + max_memory_scaler=1000, + max_iterations=max_iterations_ts, # Use the passed max_iterations + return_indices=True, + ) + batcher.load_states(opt_state) - batcher = ts.InFlightAutoBatcher( - model=model, # The MaceModel wrapper - memory_scales_with="n_atoms", - max_memory_scaler=1000, - max_iterations=max_iterations, - return_indices=True, - ) - batcher.load_states(opt_state) + total_structures = opt_state.n_batches + convergence_steps = torch.full( + (total_structures,), -1, dtype=torch.long, device=device + ) + convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) + converged_tensor_global = torch.zeros( + total_structures, dtype=torch.bool, device=device + ) + global_step = 0 + all_converged_states = [] + convergence_tensor_for_batcher = None + last_active_state = opt_state + + while True: + result = batcher.next_batch(last_active_state, convergence_tensor_for_batcher) + opt_state, converged_states_from_batcher, current_indices_list = result + all_converged_states.extend(converged_states_from_batcher) + + if opt_state is None: + print("All structures converged or batcher reached max iterations.") + break - total_structures = opt_state.n_batches - convergence_steps = torch.full( - (total_structures,), -1, dtype=torch.long, device=device - ) - convergence_fn = ts.generate_force_convergence_fn(force_tol=force_tol) - converged_tensor_global = torch.zeros( - total_structures, dtype=torch.bool, device=device - ) - global_step = 0 - all_converged_states = [] - convergence_tensor_for_batcher = None last_active_state = opt_state + current_indices = torch.tensor( + current_indices_list, dtype=torch.long, device=device + ) + + steps_this_round = 10 + for _ in range(steps_this_round): + opt_state = update_fn_opt(opt_state) + global_step += steps_this_round - while True: - result = batcher.next_batch( - last_active_state, convergence_tensor_for_batcher + convergence_tensor_for_batcher = convergence_fn(opt_state, None) + newly_converged_mask_local = convergence_tensor_for_batcher & ( + convergence_steps[current_indices] == -1 + ) + converged_indices_global = current_indices[newly_converged_mask_local] + + if converged_indices_global.numel() > 0: + convergence_steps[converged_indices_global] = global_step + 
converged_tensor_global[converged_indices_global] = True + total_converged_frac = converged_tensor_global.sum().item() / total_structures + print( + f"{global_step=}: Converged indices {converged_indices_global.tolist()}, " + f"Total converged: {total_converged_frac:.2%}" ) - opt_state, converged_states_from_batcher, current_indices_list = result - all_converged_states.extend(converged_states_from_batcher) - - if opt_state is None: - print("All structures converged or batcher reached max iterations.") - break - - last_active_state = opt_state - current_indices = torch.tensor( - current_indices_list, dtype=torch.long, device=device + + if global_step % 50 == 0: + total_converged_frac = converged_tensor_global.sum().item() / total_structures + active_structures = opt_state.n_batches if opt_state else 0 + print( + f"{global_step=}: Active structures: {active_structures}, " + f"Total converged: {total_converged_frac:.2%}" ) - steps_this_round = 10 - for _ in range(steps_this_round): - opt_state = update_fn_opt(opt_state) - global_step += steps_this_round + final_states_list = batcher.restore_original_order(all_converged_states) + final_state_concatenated = ts.concatenate_states(final_states_list) - convergence_tensor_for_batcher = convergence_fn(opt_state, None) - newly_converged_mask_local = convergence_tensor_for_batcher & ( - convergence_steps[current_indices] == -1 + if final_state_concatenated is not None and hasattr(final_state_concatenated, "cell"): + print("Final cell parameters (Torch-Sim):") + for k_idx in range(final_state_concatenated.n_batches): + cell_tensor_k = final_state_concatenated.cell[k_idx].cpu().numpy() + ase_cell_k = Cell(cell_tensor_k) + params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) + print( + f" Structure {k_idx + 1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]" ) - converged_indices_global = current_indices[newly_converged_mask_local] + else: + print( + "Final cell parameters (Torch-Sim): Not available (final_state_concatenated is None or has no cell)." + ) - if converged_indices_global.numel() > 0: - convergence_steps[converged_indices_global] = global_step - converged_tensor_global[converged_indices_global] = True - total_converged_frac = converged_tensor_global.sum().item() / total_structures - print( - f"{global_step=}: Converged indices {converged_indices_global.tolist()}, " - f"Total converged: {total_converged_frac:.2%}" - ) + end_time = time.perf_counter() + print( + f"Finished Torch-Sim ({ts_md_flavor}, frechet={ts_use_frechet}) in " + f"{end_time - start_time:.2f} seconds." 
+ ) + return convergence_steps, final_state_concatenated + + +def run_optimization_ase( + *, + initial_state: SimState, + ase_use_frechet_filter: bool, + force_tol: float, + max_steps_ase: int, +) -> tuple[torch.Tensor, GDState | None]: + """Runs ASE optimization and returns convergence steps and final state.""" + print( + f"\n--- Running ASE optimization: frechet_filter={ase_use_frechet_filter}, " + f"force_tol={force_tol} ---" + ) + start_time = time.perf_counter() + + individual_initial_states = initial_state.split() + num_structures = len(individual_initial_states) + final_ase_atoms_list = [] + convergence_steps_list = [] + + for i, single_sim_state in enumerate(individual_initial_states): + print(f"Optimizing structure {i + 1}/{num_structures} with ASE...") + ase_atoms_orig = ts.io.state_to_atoms(single_sim_state)[0] - if global_step % 50 == 0: - total_converged_frac = converged_tensor_global.sum().item() / total_structures - active_structures = opt_state.n_batches if opt_state else 0 + initial_cell_ase = ase_atoms_orig.get_cell() + initial_params_str = ", ".join([f"{p:.2f}" for p in initial_cell_ase.cellpar()]) + print( + f" Initial cell (ASE Structure {i + 1}): Volume={initial_cell_ase.volume:.2f} ų, Params=[{initial_params_str}]" + ) + + ase_calc_instance = mace_mp_calculator_for_ase( + model=MaceUrls.mace_mpa_medium, + device=device, + default_dtype=str(dtype).split(".")[-1], + ) + ase_atoms_orig.calc = ase_calc_instance + + optim_target_atoms = ase_atoms_orig + if ase_use_frechet_filter: + print(f"Applying FrechetCellFilter to structure {i + 1}") + optim_target_atoms = FrechetCellFilter(ase_atoms_orig) + + dyn = ASEFIRE(optim_target_atoms, trajectory=None, logfile=None) + + try: + dyn.run(fmax=force_tol, steps=max_steps_ase) # Use passed max_steps_ase + if dyn.converged(): + convergence_steps_list.append(dyn.nsteps) + print(f"ASE structure {i + 1} converged in {dyn.nsteps} steps.") + else: print( - f"{global_step=}: Active structures: {active_structures}, " - f"Total converged: {total_converged_frac:.2%}" + f"ASE optimization for structure {i + 1} did not converge within " + f"{max_steps_ase} steps. Steps taken: {dyn.nsteps}." ) - - final_states_list = batcher.restore_original_order(all_converged_states) - final_state_concatenated = ts.concatenate_states(final_states_list) - - if final_state_concatenated is not None and hasattr(final_state_concatenated, 'cell'): - print("Final cell parameters (Torch-Sim):") - for k_idx in range(final_state_concatenated.n_batches): - cell_tensor_k = final_state_concatenated.cell[k_idx].cpu().numpy() - ase_cell_k = Cell(cell_tensor_k) - params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) - print(f" Structure {k_idx+1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]") - else: - print("Final cell parameters (Torch-Sim): Not available (final_state_concatenated is None or has no cell).") + convergence_steps_list.append(-1) + except Exception as e: + print(f"ASE optimization failed for structure {i + 1}: {e}") + convergence_steps_list.append(-1) - end_time = time.perf_counter() + final_ats_for_print = ( + optim_target_atoms.atoms if ase_use_frechet_filter else ase_atoms_orig + ) + final_cell_ase = final_ats_for_print.get_cell() + final_params_str = ", ".join([f"{p:.2f}" for p in final_cell_ase.cellpar()]) print( - f"Finished Torch-Sim ({ts_md_flavor}, frechet={ts_use_frechet}) in " - f"{end_time - start_time:.2f} seconds." 
+ f" Final cell (ASE Structure {i + 1}): Volume={final_cell_ase.volume:.2f} ų, Params=[{final_params_str}]" ) - return convergence_steps, final_state_concatenated - elif optimizer_type == "ase": - print( - f"\n--- Running ASE optimization: frechet_filter={ase_use_frechet_filter}, " - f"force_tol={force_tol} ---" + final_ase_atoms_list.append(final_ats_for_print) + + # Convert list of final ASE atoms objects back to a base SimState first + # to easily get positions, cell, etc. + # However, ts.io.atoms_to_state might not preserve all attributes needed for GDState directly. + # It's better to extract all required components directly from final_ase_atoms_list. + + all_positions = [] + all_masses = [] + all_atomic_numbers = [] + all_cells = [] + all_batches_for_gd = [] + final_energies_ase = [] + final_forces_ase_tensors = [] # List to store force tensors + + current_atom_offset = 0 + for batch_idx, ats_final in enumerate(final_ase_atoms_list): + all_positions.append( + torch.tensor(ats_final.get_positions(), device=device, dtype=dtype) + ) + all_masses.append( + torch.tensor(ats_final.get_masses(), device=device, dtype=dtype) + ) + all_atomic_numbers.append( + torch.tensor(ats_final.get_atomic_numbers(), device=device, dtype=torch.long) + ) + # ASE cell is row-vector, SimState expects column-vector + all_cells.append( + torch.tensor(ats_final.get_cell().array.T, device=device, dtype=dtype) ) - start_time = time.perf_counter() - - individual_initial_states = initial_state.split() - num_structures = len(individual_initial_states) - final_ase_atoms_list = [] - convergence_steps_list = [] - - for i, single_sim_state in enumerate(individual_initial_states): - print(f"Optimizing structure {i+1}/{num_structures} with ASE...") - ase_atoms_orig = ts.io.state_to_atoms(single_sim_state)[0] - - initial_cell_ase = ase_atoms_orig.get_cell() - initial_params_str = ", ".join([f"{p:.2f}" for p in initial_cell_ase.cellpar()]) - print(f" Initial cell (ASE Structure {i+1}): Volume={initial_cell_ase.volume:.2f} ų, Params=[{initial_params_str}]") - - ase_calc_instance = mace_mp_calculator_for_ase( - model=MaceUrls.mace_mpa_medium, - device=device, - default_dtype=str(dtype).split('.')[-1], - ) - ase_atoms_orig.calc = ase_calc_instance - - optim_target_atoms = ase_atoms_orig - if ase_use_frechet_filter: - print(f"Applying FrechetCellFilter to structure {i+1}") - optim_target_atoms = FrechetCellFilter(ase_atoms_orig) - - dyn = ASEFIRE(optim_target_atoms, trajectory=None, logfile=None) - - try: - dyn.run(fmax=force_tol, steps=ase_max_optimizer_steps) - if dyn.converged(): - convergence_steps_list.append(dyn.nsteps) - print(f"ASE structure {i+1} converged in {dyn.nsteps} steps.") - else: - print( - f"ASE optimization for structure {i+1} did not converge within " - f"{ase_max_optimizer_steps} steps. Steps taken: {dyn.nsteps}." - ) - convergence_steps_list.append(-1) - except Exception as e: - print(f"ASE optimization failed for structure {i+1}: {e}") - convergence_steps_list.append(-1) - final_ats_for_print = optim_target_atoms.atoms if ase_use_frechet_filter else ase_atoms_orig - final_cell_ase = final_ats_for_print.get_cell() - final_params_str = ", ".join([f"{p:.2f}" for p in final_cell_ase.cellpar()]) - print(f" Final cell (ASE Structure {i+1}): Volume={final_cell_ase.volume:.2f} ų, Params=[{final_params_str}]") - - final_ase_atoms_list.append(final_ats_for_print) - - # Convert list of final ASE atoms objects back to a base SimState first - # to easily get positions, cell, etc. 
- # However, ts.io.atoms_to_state might not preserve all attributes needed for GDState directly. - # It's better to extract all required components directly from final_ase_atoms_list. - - all_positions = [] - all_masses = [] - all_atomic_numbers = [] - all_cells = [] - all_batches_for_gd = [] - final_energies_ase = [] - final_forces_ase_tensors = [] # List to store force tensors - - current_atom_offset = 0 - for batch_idx, ats_final in enumerate(final_ase_atoms_list): - all_positions.append(torch.tensor(ats_final.get_positions(), device=device, dtype=dtype)) - all_masses.append(torch.tensor(ats_final.get_masses(), device=device, dtype=dtype)) - all_atomic_numbers.append(torch.tensor(ats_final.get_atomic_numbers(), device=device, dtype=torch.long)) - # ASE cell is row-vector, SimState expects column-vector - all_cells.append(torch.tensor(ats_final.get_cell().array.T, device=device, dtype=dtype)) - - num_atoms_in_current = len(ats_final) - all_batches_for_gd.append(torch.full((num_atoms_in_current,), batch_idx, device=device, dtype=torch.long)) - current_atom_offset += num_atoms_in_current - - try: - if ats_final.calc is None: - print(f"Re-attaching ASE calculator for final energy/forces for structure {batch_idx}.") - temp_calc = mace_mp_calculator_for_ase( - model=MaceUrls.mace_mpa_medium, device=device, default_dtype=str(dtype).split('.')[-1] - ) - ats_final.calc = temp_calc - final_energies_ase.append(ats_final.get_potential_energy()) - final_forces_ase_tensors.append(torch.tensor(ats_final.get_forces(), device=device, dtype=dtype)) - except Exception as e: - print(f"Could not get final energy/forces for an ASE structure {batch_idx}: {e}") - final_energies_ase.append(float('nan')) - # Append a zero tensor of appropriate shape if forces fail, or handle error - # For GDState, forces are required. If any structure fails, GDState creation might fail. - # We need to ensure all_positions, etc. are also correctly populated even on failure. - # For now, let's assume if energy fails, forces might also, and GDState might be problematic. - # A robust solution would be to skip failed structures or return None. - # For now, let's make forces a zero tensor of expected shape if it fails. - if all_positions and len(all_positions[-1]) > 0: - final_forces_ase_tensors.append(torch.zeros_like(all_positions[-1])) - else: # Cannot determine shape, this path is problematic - final_forces_ase_tensors.append(torch.empty((0,3), device=device, dtype=dtype)) - - - if not all_positions: # If all optimizations failed early - print("Warning: No successful ASE structures to form GDState.") - return torch.tensor(convergence_steps_list, dtype=torch.long, device=device), None - - - # Concatenate all parts - concatenated_positions = torch.cat(all_positions, dim=0) - concatenated_masses = torch.cat(all_masses, dim=0) - concatenated_atomic_numbers = torch.cat(all_atomic_numbers, dim=0) - concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (N_batch, 3, 3) - concatenated_batch_indices = torch.cat(all_batches_for_gd, dim=0) - - concatenated_energies = torch.tensor(final_energies_ase, device=device, dtype=dtype) - concatenated_forces = torch.cat(final_forces_ase_tensors, dim=0) - - # Check for NaN energies which might cause issues - if torch.isnan(concatenated_energies).any(): - print("Warning: NaN values found in final ASE energies. 
GDState energy tensor will contain NaNs.") - # Consider replacing NaNs if GDState or subsequent ops can't handle them: - # concatenated_energies = torch.nan_to_num(concatenated_energies, nan=0.0) # Example replacement - - # Create GDState instance - # pbc is global, taken from initial_state - final_state_as_gd = GDState( - positions=concatenated_positions, - masses=concatenated_masses, - cell=concatenated_cells, - pbc=initial_state.pbc, # Assuming pbc is constant and global - atomic_numbers=concatenated_atomic_numbers, - batch=concatenated_batch_indices, - energy=concatenated_energies, - forces=concatenated_forces, + num_atoms_in_current = len(ats_final) + all_batches_for_gd.append( + torch.full( + (num_atoms_in_current,), batch_idx, device=device, dtype=torch.long + ) ) - - convergence_steps = torch.tensor(convergence_steps_list, dtype=torch.long, device=device) + current_atom_offset += num_atoms_in_current - end_time = time.perf_counter() + try: + if ats_final.calc is None: + print( + f"Re-attaching ASE calculator for final energy/forces for structure {batch_idx}." + ) + temp_calc = mace_mp_calculator_for_ase( + model=MaceUrls.mace_mpa_medium, + device=device, + default_dtype=str(dtype).split(".")[-1], + ) + ats_final.calc = temp_calc + final_energies_ase.append(ats_final.get_potential_energy()) + final_forces_ase_tensors.append( + torch.tensor(ats_final.get_forces(), device=device, dtype=dtype) + ) + except Exception as e: + print( + f"Could not get final energy/forces for an ASE structure {batch_idx}: {e}" + ) + final_energies_ase.append(float("nan")) + # Append a zero tensor of appropriate shape if forces fail, or handle error + # For GDState, forces are required. If any structure fails, GDState creation might fail. + # We need to ensure all_positions, etc. are also correctly populated even on failure. + # For now, let's assume if energy fails, forces might also, and GDState might be problematic. + # A robust solution would be to skip failed structures or return None. + # For now, let's make forces a zero tensor of expected shape if it fails. + if all_positions and len(all_positions[-1]) > 0: + final_forces_ase_tensors.append(torch.zeros_like(all_positions[-1])) + else: # Cannot determine shape, this path is problematic + final_forces_ase_tensors.append( + torch.empty((0, 3), device=device, dtype=dtype) + ) + + if not all_positions: # If all optimizations failed early + print("Warning: No successful ASE structures to form GDState.") + return torch.tensor(convergence_steps_list, dtype=torch.long, device=device), None + + # Concatenate all parts + concatenated_positions = torch.cat(all_positions, dim=0) + concatenated_masses = torch.cat(all_masses, dim=0) + concatenated_atomic_numbers = torch.cat(all_atomic_numbers, dim=0) + concatenated_cells = torch.stack(all_cells, dim=0) # Cells are (N_batch, 3, 3) + concatenated_batch_indices = torch.cat(all_batches_for_gd, dim=0) + + concatenated_energies = torch.tensor(final_energies_ase, device=device, dtype=dtype) + concatenated_forces = torch.cat(final_forces_ase_tensors, dim=0) + + # Check for NaN energies which might cause issues + if torch.isnan(concatenated_energies).any(): print( - f"Finished ASE optimization (frechet_filter={ase_use_frechet_filter}) " - f"in {end_time - start_time:.2f} seconds." + "Warning: NaN values found in final ASE energies. GDState energy tensor will contain NaNs." 
) - return convergence_steps, final_state_as_gd - else: - raise ValueError(f"Unknown optimizer_type: {optimizer_type}") + # Consider replacing NaNs if GDState or subsequent ops can't handle them: + # concatenated_energies = torch.nan_to_num(concatenated_energies, nan=0.0) # Example replacement + + # Create GDState instance + # pbc is global, taken from initial_state + final_state_as_gd = GDState( + positions=concatenated_positions, + masses=concatenated_masses, + cell=concatenated_cells, + pbc=initial_state.pbc, # Assuming pbc is constant and global + atomic_numbers=concatenated_atomic_numbers, + batch=concatenated_batch_indices, + energy=concatenated_energies, + forces=concatenated_forces, + ) + + convergence_steps = torch.tensor( + convergence_steps_list, dtype=torch.long, device=device + ) + + end_time = time.perf_counter() + print( + f"Finished ASE optimization (frechet_filter={ase_use_frechet_filter}) " + f"in {end_time - start_time:.2f} seconds." + ) + return convergence_steps, final_state_as_gd # --- Main Script --- @@ -401,113 +437,150 @@ def run_optimization( configs_to_run = [ { "name": "torch-sim VV-FIRE (PosOnly)", - "type": "torch_sim", "ts_md_flavor": "vv_fire", "ts_use_frechet": False, + "type": "torch_sim", + "ts_md_flavor": "vv_fire", + "ts_use_frechet": False, }, { "name": "torch-sim ASE-FIRE (PosOnly)", - "type": "torch_sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": False, + "type": "torch_sim", + "ts_md_flavor": "ase_fire", + "ts_use_frechet": False, }, { "name": "torch-sim VV-FIRE (Frechet Cell)", - "type": "torch_sim", "ts_md_flavor": "vv_fire", "ts_use_frechet": True, + "type": "torch_sim", + "ts_md_flavor": "vv_fire", + "ts_use_frechet": True, }, { "name": "torch-sim ASE-FIRE (Frechet Cell)", - "type": "torch_sim", "ts_md_flavor": "ase_fire", "ts_use_frechet": True, + "type": "torch_sim", + "ts_md_flavor": "ase_fire", + "ts_use_frechet": True, }, { - "name": "ASE FIRE (Native, PosOnly)", # Corrected name: Only optimizes positions without a cell filter - "type": "ase", "ase_use_frechet_filter": False, + "name": "ASE FIRE (Native, PosOnly)", # Corrected name: Only optimizes positions without a cell filter + "type": "ase", + "ase_use_frechet_filter": False, }, { "name": "ASE FIRE (Native Frechet Filter, CellOpt)", - "type": "ase", "ase_use_frechet_filter": True, + "type": "ase", + "ase_use_frechet_filter": True, }, ] -results_all = {} +all_results = {} for config_run in configs_to_run: print(f"\n\nStarting configuration: {config_run['name']}") # Get relevant params, providing defaults where necessary for the run_optimization call optimizer_type_val = config_run["type"] - ts_md_flavor_val = config_run.get("ts_md_flavor") # Will be None for ASE type, handled by assert + # Will be None for ASE type, handled by assert + ts_md_flavor_val = config_run.get("ts_md_flavor") ts_use_frechet_val = config_run.get("ts_use_frechet", False) ase_use_frechet_filter_val = config_run.get("ase_use_frechet_filter", False) - steps, final_state_opt = run_optimization( - initial_state=state.clone(), # Use a fresh clone for each run - optimizer_type=optimizer_type_val, - ts_md_flavor=ts_md_flavor_val, - ts_use_frechet=ts_use_frechet_val, - ase_use_frechet_filter=ase_use_frechet_filter_val, - force_tol=force_tol, - ) - results_all[config_run["name"]] = {"steps": steps, "final_state": final_state_opt} + steps: torch.Tensor | None = None + final_state_opt: SimState | GDState | None = None + + if optimizer_type_val == "torch_sim": + assert ts_md_flavor_val is not None, "ts_md_flavor must be 
provided for torch_sim" + steps, final_state_opt = run_optimization_ts( + initial_state=state.clone(), # Use a fresh clone for each run + ts_md_flavor=ts_md_flavor_val, + ts_use_frechet=ts_use_frechet_val, + force_tol=force_tol, + max_iterations_ts=max_iterations, # Pass the global max_iterations + ) + elif optimizer_type_val == "ase": + steps, final_state_opt = run_optimization_ase( + initial_state=state.clone(), # Use a fresh clone for each run + ase_use_frechet_filter=ase_use_frechet_filter_val, + force_tol=force_tol, + max_steps_ase=ase_max_optimizer_steps, # Pass the global ase_max_optimizer_steps + ) + else: + raise ValueError(f"Unknown optimizer_type: {optimizer_type_val}") + + all_results[config_run["name"]] = {"steps": steps, "final_state": final_state_opt} print("\n\n--- Overall Comparison ---") print(f"{force_tol=:.2f} eV/Å") print(f"Initial energies: {[f'{e.item():.3f}' for e in initial_energies]} eV") -for name, result_data in results_all.items(): +for name, result_data in all_results.items(): final_state_res = result_data["final_state"] steps_res = result_data["steps"] print(f"\nResults for: {name}") - if final_state_res is not None and hasattr(final_state_res, 'energy') and final_state_res.energy is not None: - energy_str = [f'{e.item():.3f}' for e in final_state_res.energy] + if ( + final_state_res is not None + and hasattr(final_state_res, "energy") + and final_state_res.energy is not None + ): + energy_str = [f"{e.item():.3f}" for e in final_state_res.energy] print(f" Final energies: {energy_str} eV") else: - print(f" Final energies: Not available or state is None") + print(" Final energies: Not available or state is None") print(f" Convergence steps: {steps_res.tolist()}") - + not_converged_indices = torch.where(steps_res == -1)[0].tolist() if not_converged_indices: print(f" Did not converge for structure indices: {not_converged_indices}") # Mean Displacement Comparisons comparison_pairs = [ - ("torch-sim ASE-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"), + ("torch-sim ASE-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"), ("torch-sim ASE-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"), ("torch-sim VV-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"), - ("torch-sim VV-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"), + ("torch-sim VV-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"), ] for name1, name2 in comparison_pairs: - if name1 in results_all and name2 in results_all: - state1 = results_all[name1]["final_state"] - state2 = results_all[name2]["final_state"] + if name1 in all_results and name2 in all_results: + state1 = all_results[name1]["final_state"] + state2 = all_results[name2]["final_state"] if state1 is None or state2 is None: print(f"\nCannot compare {name1} and {name2}, one or both states are None.") continue - + state1_list = state1.split() - + state2_list = state2.split() - + if len(state1_list) != len(state2_list): - print(f"\nCannot compare {name1} and {name2}, different number of structures.") + print( + f"\nCannot compare {name1} and {name2}, different number of structures." 
+ ) continue mean_displacements = [] for s1, s2 in zip(state1_list, state2_list, strict=True): - if s1.n_atoms == 0 or s2.n_atoms == 0 : # Handle empty states if they occur - mean_displacements.append(float('nan')) + if s1.n_atoms == 0 or s2.n_atoms == 0: # Handle empty states if they occur + mean_displacements.append(float("nan")) continue pos1_centered = s1.positions - s1.positions.mean(dim=0, keepdim=True) pos2_centered = s2.positions - s2.positions.mean(dim=0, keepdim=True) if pos1_centered.shape != pos2_centered.shape: - print(f"Warning: Shape mismatch for {name1} vs {name2} in structure. Skipping displacement calc.") - mean_displacements.append(float('nan')) - continue + print( + f"Warning: Shape mismatch for {name1} vs {name2} in structure. Skipping displacement calc." + ) + mean_displacements.append(float("nan")) + continue displacement = torch.norm(pos1_centered - pos2_centered, dim=1) mean_disp = torch.mean(displacement).item() mean_displacements.append(mean_disp) - - print(f"\nMean Disp ({name1} vs {name2}): {[f'{d:.4f}' for d in mean_displacements]} Å") + + print( + f"\nMean Disp ({name1} vs {name2}): {[f'{d:.4f}' for d in mean_displacements]} Å" + ) else: - print(f"\nSkipping displacement comparison for ({name1} vs {name2}), one or both results missing.") + print( + f"\nSkipping displacement comparison for ({name1} vs {name2}), one or both results missing." + ) # --- Plotting Results --- @@ -515,63 +588,94 @@ def run_optimization( # Names for the structures for plotting labels original_structure_formulas = [ats.get_chemical_formula() for ats in atoms_list] # Make them more concise if needed: -structure_names = ["Si_bulk", "Cu_bulk", "Fe_bulk", "Si_vac", "Cu_vac", "Fe_vac"] # Updated for 6 structures +structure_names = [ + "Si_bulk", + "Cu_bulk", + "Fe_bulk", + "Si_vac", + "Cu_vac", + "Fe_vac", +] # Updated for 6 structures if len(structure_names) != len(atoms_list): - print(f"Warning: Mismatch between custom structure_names ({len(structure_names)}) and atoms_list ({len(atoms_list)}). Using custom names.") + print( + f"Warning: Mismatch between custom structure_names ({len(structure_names)}) and atoms_list ({len(atoms_list)}). Using custom names." 
+ ) num_structures_plot = len(structure_names) # --- Plot 1: Convergence Steps (Multi-bar per structure) --- -plot_methods_fig1 = list(results_all.keys()) +plot_methods_fig1 = list(all_results.keys()) num_methods_fig1 = len(plot_methods_fig1) # Initialize with NaNs, so if a method fails completely, its bars are missing or clearly marked -steps_data_fig1 = np.full((num_structures_plot, num_methods_fig1), np.nan) +steps_data_fig1 = np.full((num_structures_plot, num_methods_fig1), np.nan) for method_idx, method_name in enumerate(plot_methods_fig1): - result_data = results_all[method_name] + result_data = all_results[method_name] if result_data["final_state"] is None or result_data["steps"] is None: # steps_data_fig1[:, method_idx] = np.nan # Already initialized with NaN print(f"Plot1: Skipping steps for {method_name} as final_state or steps is None.") continue - + steps_tensor = result_data["steps"].cpu().numpy() - penalty_steps = ase_max_optimizer_steps + 100 + penalty_steps = ase_max_optimizer_steps + 100 steps_plot_values = np.where(steps_tensor == -1, penalty_steps, steps_tensor) if len(steps_plot_values) == num_structures_plot: steps_data_fig1[:, method_idx] = steps_plot_values elif len(steps_plot_values) > num_structures_plot: - print(f"Warning: More step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Truncating.") + print( + f"Warning: More step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Truncating." + ) steps_data_fig1[:, method_idx] = steps_plot_values[:num_structures_plot] elif len(steps_plot_values) < num_structures_plot: - print(f"Warning: Fewer step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Padding with NaN.") - steps_data_fig1[:len(steps_plot_values), method_idx] = steps_plot_values + print( + f"Warning: Fewer step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Padding with NaN." 
+ ) + steps_data_fig1[: len(steps_plot_values), method_idx] = steps_plot_values # The rest will remain NaN due to initialization -fig1, ax1 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend -x_fig1 = np.arange(num_structures_plot) -width_fig1 = 0.8 / num_methods_fig1 +fig1, ax1 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend +x_fig1 = np.arange(num_structures_plot) +width_fig1 = 0.8 / num_methods_fig1 rects_all_fig1 = [] for i in range(num_methods_fig1): - rects = ax1.bar(x_fig1 - 0.4 + (i + 0.5) * width_fig1, steps_data_fig1[:, i], width_fig1, label=plot_methods_fig1[i]) + rects = ax1.bar( + x_fig1 - 0.4 + (i + 0.5) * width_fig1, + steps_data_fig1[:, i], + width_fig1, + label=plot_methods_fig1[i], + ) rects_all_fig1.append(rects) for bar_idx, bar_val in enumerate(steps_data_fig1[:, i]): - if bar_idx < len(results_all[plot_methods_fig1[i]]["steps"]): # Check bounds - original_step_val = results_all[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] - if original_step_val == -1 and not np.isnan(bar_val): # Check if it was a penalty bar - ax1.text(rects[bar_idx].get_x() + rects[bar_idx].get_width() / 2., - rects[bar_idx].get_height() - 10, - 'NC', ha='center', va='top', color='white', fontsize=7, weight='bold') - -ax1.set_ylabel('Convergence Steps (NC = Not Converged, shown at penalty)') -ax1.set_xlabel('Structure') -ax1.set_title('Convergence Steps per Structure and Method') + if bar_idx < len(all_results[plot_methods_fig1[i]]["steps"]): # Check bounds + original_step_val = ( + all_results[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] + ) + if original_step_val == -1 and not np.isnan( + bar_val + ): # Check if it was a penalty bar + ax1.text( + rects[bar_idx].get_x() + rects[bar_idx].get_width() / 2.0, + rects[bar_idx].get_height() - 10, + "NC", + ha="center", + va="top", + color="white", + fontsize=7, + weight="bold", + ) + +ax1.set_ylabel("Convergence Steps (NC = Not Converged, shown at penalty)") +ax1.set_xlabel("Structure") +ax1.set_title("Convergence Steps per Structure and Method") ax1.set_xticks(x_fig1) ax1.set_xticklabels(structure_names, rotation=45, ha="right") -ax1.legend(title="Optimization Method", bbox_to_anchor=(1.02, 1), loc='upper left') # Adjusted legend position -ax1.grid(True, axis='y', linestyle='--', alpha=0.7) -plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend +ax1.legend( + title="Optimization Method", bbox_to_anchor=(1.02, 1), loc="upper left" +) # Adjusted legend position +ax1.grid(True, axis="y", linestyle="--", alpha=0.7) +plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend # --- Plot 2: Average Final Energy Difference from Baselines --- @@ -580,37 +684,51 @@ def run_optimization( avg_energy_diffs_fig2 = [] plot_names_fig2 = [] -baseline_pos_only_data = results_all.get(baseline_ase_pos_only) -baseline_frechet_data = results_all.get(baseline_ase_frechet) +baseline_pos_only_data = all_results.get(baseline_ase_pos_only) +baseline_frechet_data = all_results.get(baseline_ase_frechet) -for name, result_data in results_all.items(): +for name, result_data in all_results.items(): if result_data["final_state"] is None or result_data["final_state"].energy is None: print(f"Plot2: Skipping energy diff for {name} as final_state or energy is None.") - if name not in plot_names_fig2: plot_names_fig2.append(name) # Keep name for consistent bar count - avg_energy_diffs_fig2.append(np.nan) # Add NaN if data missing + if name not in plot_names_fig2: + plot_names_fig2.append(name) # Keep name for consistent 
bar count + avg_energy_diffs_fig2.append(np.nan) # Add NaN if data missing continue # Ensure name is added if not already by a skip - if name not in plot_names_fig2: plot_names_fig2.append(name) - + if name not in plot_names_fig2: + plot_names_fig2.append(name) + current_energies = result_data["final_state"].energy.cpu().numpy() - + chosen_baseline_energies = None is_baseline_self = False processed_current_name = False if name == baseline_ase_pos_only or name == baseline_ase_frechet: - avg_energy_diffs_fig2.append(0.0) + avg_energy_diffs_fig2.append(0.0) is_baseline_self = True processed_current_name = True elif "torch-sim" in name: if "PosOnly" in name: - if baseline_pos_only_data and baseline_pos_only_data["final_state"] is not None and baseline_pos_only_data["final_state"].energy is not None: - chosen_baseline_energies = baseline_pos_only_data["final_state"].energy.cpu().numpy() + if ( + baseline_pos_only_data + and baseline_pos_only_data["final_state"] is not None + and baseline_pos_only_data["final_state"].energy is not None + ): + chosen_baseline_energies = ( + baseline_pos_only_data["final_state"].energy.cpu().numpy() + ) elif "Frechet Cell" in name: - if baseline_frechet_data and baseline_frechet_data["final_state"] is not None and baseline_frechet_data["final_state"].energy is not None: - chosen_baseline_energies = baseline_frechet_data["final_state"].energy.cpu().numpy() - + if ( + baseline_frechet_data + and baseline_frechet_data["final_state"] is not None + and baseline_frechet_data["final_state"].energy is not None + ): + chosen_baseline_energies = ( + baseline_frechet_data["final_state"].energy.cpu().numpy() + ) + if not is_baseline_self and not processed_current_name: if chosen_baseline_energies is not None: if current_energies.shape == chosen_baseline_energies.shape: @@ -618,12 +736,20 @@ def run_optimization( avg_energy_diffs_fig2.append(energy_diff) else: avg_energy_diffs_fig2.append(np.nan) - print(f"Plot2: Shape mismatch for energy comparison: {name} vs its baseline. " - f"{current_energies.shape} vs {chosen_baseline_energies.shape}") + print( + f"Plot2: Shape mismatch for energy comparison: {name} vs its baseline. " + f"{current_energies.shape} vs {chosen_baseline_energies.shape}" + ) else: - print(f"Plot2: No appropriate baseline for {name} or baseline data missing. Setting energy diff to NaN.") + print( + f"Plot2: No appropriate baseline for {name} or baseline data missing. Setting energy diff to NaN." + ) avg_energy_diffs_fig2.append(np.nan) - elif not processed_current_name and name not in [n for n,v in zip(plot_names_fig2, avg_energy_diffs_fig2) if not np.isnan(v)] : # Handle cases not covered + elif not processed_current_name and name not in [ + n + for n, v in zip(plot_names_fig2, avg_energy_diffs_fig2, strict=False) + if not np.isnan(v) + ]: # Handle cases not covered print(f"Plot2: Fallback for {name}, setting energy diff to NaN.") avg_energy_diffs_fig2.append(np.nan) @@ -633,84 +759,129 @@ def run_optimization( # A more robust way is to build them in parallel. 
final_plot_names_fig2 = [] final_avg_energy_diffs_fig2 = [] -all_method_names_sorted = sorted(list(results_all.keys())) # Use a fixed order +all_method_names_sorted = sorted(list(all_results.keys())) # Use a fixed order for name in all_method_names_sorted: - result_data = results_all[name] + result_data = all_results[name] final_plot_names_fig2.append(name) if result_data["final_state"] is None or result_data["final_state"].energy is None: final_avg_energy_diffs_fig2.append(np.nan) continue - + current_energies = result_data["final_state"].energy.cpu().numpy() - energy_to_append = np.nan # Default to NaN + energy_to_append = np.nan # Default to NaN if name == baseline_ase_pos_only or name == baseline_ase_frechet: energy_to_append = 0.0 elif "torch-sim" in name: baseline_to_use_energies = None if "PosOnly" in name: - if baseline_pos_only_data and baseline_pos_only_data["final_state"] is not None and baseline_pos_only_data["final_state"].energy is not None: - baseline_to_use_energies = baseline_pos_only_data["final_state"].energy.cpu().numpy() + if ( + baseline_pos_only_data + and baseline_pos_only_data["final_state"] is not None + and baseline_pos_only_data["final_state"].energy is not None + ): + baseline_to_use_energies = ( + baseline_pos_only_data["final_state"].energy.cpu().numpy() + ) elif "Frechet Cell" in name: - if baseline_frechet_data and baseline_frechet_data["final_state"] is not None and baseline_frechet_data["final_state"].energy is not None: - baseline_to_use_energies = baseline_frechet_data["final_state"].energy.cpu().numpy() - + if ( + baseline_frechet_data + and baseline_frechet_data["final_state"] is not None + and baseline_frechet_data["final_state"].energy is not None + ): + baseline_to_use_energies = ( + baseline_frechet_data["final_state"].energy.cpu().numpy() + ) + if baseline_to_use_energies is not None: if current_energies.shape == baseline_to_use_energies.shape: energy_to_append = np.mean(current_energies - baseline_to_use_energies) else: - print(f"Plot2: Shape mismatch for {name} ({current_energies.shape}) vs baseline ({baseline_to_use_energies.shape}).") + print( + f"Plot2: Shape mismatch for {name} ({current_energies.shape}) vs baseline ({baseline_to_use_energies.shape})." + ) final_avg_energy_diffs_fig2.append(energy_to_append) fig2, ax2 = plt.subplots(figsize=(12, 7)) -bars_fig2 = ax2.bar(final_plot_names_fig2, final_avg_energy_diffs_fig2, color='lightcoral') -ax2.set_ylabel('Avg. Final Energy Diff. from Corresponding ASE Baseline (eV)') -ax2.set_xlabel('Optimization Method') -ax2.set_title('Average Final Energy Difference from ASE Baselines') -ax2.axhline(0, color='black', linewidth=0.8, linestyle='--') +bars_fig2 = ax2.bar( + final_plot_names_fig2, final_avg_energy_diffs_fig2, color="lightcoral" +) +ax2.set_ylabel("Avg. Final Energy Diff. 
from Corresponding ASE Baseline (eV)") +ax2.set_xlabel("Optimization Method") +ax2.set_title("Average Final Energy Difference from ASE Baselines") +ax2.axhline(0, color="black", linewidth=0.8, linestyle="--") for bar in bars_fig2: yval = bar.get_height() - if not np.isnan(yval): - text_y_offset = 0.001 if yval >= 0 else -0.005 - va_align = 'bottom' if yval >=0 else 'top' - ax2.text(bar.get_x() + bar.get_width()/2.0, yval + text_y_offset, - f"{yval:.3f}", ha='center', va=va_align, fontsize=8, color='black') + if not np.isnan(yval): + text_y_offset = 0.001 if yval >= 0 else -0.005 + va_align = "bottom" if yval >= 0 else "top" + ax2.text( + bar.get_x() + bar.get_width() / 2.0, + yval + text_y_offset, + f"{yval:.3f}", + ha="center", + va=va_align, + fontsize=8, + color="black", + ) plt.xticks(rotation=45, ha="right") plt.tight_layout() # --- Plot 3: Mean Displacement from ASE Counterparts (Multi-bar per structure) --- -comparison_pairs_plot3_defs = [ # (ts_method_name, ase_method_name, short_label_for_legend) - ("torch-sim ASE-FIRE (PosOnly)", baseline_ase_pos_only, "TS ASE PosOnly vs ASE Native"), +comparison_pairs_plot3_defs = [ # (ts_method_name, ase_method_name, short_label_for_legend) + ( + "torch-sim ASE-FIRE (PosOnly)", + baseline_ase_pos_only, + "TS ASE PosOnly vs ASE Native", + ), ("torch-sim VV-FIRE (PosOnly)", baseline_ase_pos_only, "TS VV PosOnly vs ASE Native"), - ("torch-sim ASE-FIRE (Frechet Cell)", baseline_ase_frechet, "TS ASE Frechet vs ASE Frechet"), - ("torch-sim VV-FIRE (Frechet Cell)", baseline_ase_frechet, "TS VV Frechet vs ASE Frechet"), + ( + "torch-sim ASE-FIRE (Frechet Cell)", + baseline_ase_frechet, + "TS ASE Frechet vs ASE Frechet", + ), + ( + "torch-sim VV-FIRE (Frechet Cell)", + baseline_ase_frechet, + "TS VV Frechet vs ASE Frechet", + ), ] num_comparison_pairs_plot3 = len(comparison_pairs_plot3_defs) # rows: structures, cols: comparison_pair -disp_data_fig3 = np.full((num_structures_plot, num_comparison_pairs_plot3), np.nan) +disp_data_fig3 = np.full((num_structures_plot, num_comparison_pairs_plot3), np.nan) legend_labels_fig3 = [] -for pair_idx, (ts_method_name, ase_method_name, plot_label) in enumerate(comparison_pairs_plot3_defs): +for pair_idx, (ts_method_name, ase_method_name, plot_label) in enumerate( + comparison_pairs_plot3_defs +): legend_labels_fig3.append(plot_label) - if ts_method_name in results_all and ase_method_name in results_all: - state1_data = results_all[ts_method_name] - state2_data = results_all[ase_method_name] + if ts_method_name in all_results and ase_method_name in all_results: + state1_data = all_results[ts_method_name] + state2_data = all_results[ase_method_name] if state1_data["final_state"] is None or state2_data["final_state"] is None: - print(f"Plot3: Skipping displacement for {plot_label} due to missing state data.") + print( + f"Plot3: Skipping displacement for {plot_label} due to missing state data." + ) # Data remains NaN continue - + state1_list = state1_data["final_state"].split() state2_list = state2_data["final_state"].split() - - if len(state1_list) != len(state2_list) or len(state1_list) != num_structures_plot : - print(f"Plot3: Structure count mismatch for {plot_label}. Expected {num_structures_plot}, got S1:{len(state1_list)}, S2:{len(state2_list)}") + + if ( + len(state1_list) != len(state2_list) + or len(state1_list) != num_structures_plot + ): + print( + f"Plot3: Structure count mismatch for {plot_label}. 
" + f"Expected {num_structures_plot}, got S1:{len(state1_list)}, S2:{len(state2_list)}" + ) # Data remains NaN continue @@ -724,31 +895,40 @@ def run_optimization( displacement = torch.norm(pos1_centered - pos2_centered, dim=1) mean_disp = torch.mean(displacement).item() mean_displacements_for_this_pair.append(mean_disp) - + if len(mean_displacements_for_this_pair) == num_structures_plot: disp_data_fig3[:, pair_idx] = np.array(mean_displacements_for_this_pair) - else: # Should not happen if previous checks pass + else: # Should not happen if previous checks pass print(f"Plot3: Inner loop displacement calculation mismatch for {plot_label}") else: print(f"Plot3: Missing data for methods in pair: {plot_label}.") # Data remains NaN -fig3, ax3 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend -x_fig3 = np.arange(num_structures_plot) +fig3, ax3 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend +x_fig3 = np.arange(num_structures_plot) width_fig3 = 0.8 / num_comparison_pairs_plot3 for i in range(num_comparison_pairs_plot3): - ax3.bar(x_fig3 - 0.4 + (i + 0.5) * width_fig3, disp_data_fig3[:, i], width_fig3, label=legend_labels_fig3[i]) + ax3.bar( + x_fig3 - 0.4 + (i + 0.5) * width_fig3, + disp_data_fig3[:, i], + width_fig3, + label=legend_labels_fig3[i], + ) -ax3.set_ylabel('Mean Atomic Displacement (Å) to ASE Counterpart') -ax3.set_xlabel('Structure') -ax3.set_title('Mean Displacement of Torch-Sim Methods to ASE Counterparts (per Structure)') +ax3.set_ylabel("Mean Atomic Displacement (Å) to ASE Counterpart") +ax3.set_xlabel("Structure") +ax3.set_title( + "Mean Displacement of Torch-Sim Methods to ASE Counterparts (per Structure)" +) ax3.set_xticks(x_fig3) ax3.set_xticklabels(structure_names, rotation=45, ha="right") -ax3.legend(title="Comparison Pair", bbox_to_anchor=(1.02, 1), loc='upper left') # Adjusted legend -ax3.grid(True, axis='y', linestyle='--', alpha=0.7) -plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend +ax3.legend( + title="Comparison Pair", bbox_to_anchor=(1.02, 1), loc="upper left" +) # Adjusted legend +ax3.grid(True, axis="y", linestyle="--", alpha=0.7) +plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend plt.show() From f4c0ba8388d5d9b53f98e4f3407e49252b28d2ce Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 21 May 2025 13:20:33 -0400 Subject: [PATCH 6/7] swap matplotlib for plotly in 7.6_Compare_ASE_to_VV_FIRE.py --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 232 +++++++++--------- examples/tutorials/low_level_tutorial.py | 2 +- tests/models/test_graphpes.py | 6 +- torch_sim/models/sevennet.py | 2 +- torch_sim/properties/correlations.py | 2 +- 5 files changed, 120 insertions(+), 124 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index 8ae4b78b..f550642b 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -12,8 +12,8 @@ import time from typing import Literal -import matplotlib.pyplot as plt import numpy as np +import plotly.graph_objects as go import torch from ase.build import bulk from ase.cell import Cell @@ -598,13 +598,14 @@ def run_optimization_ase( ] # Updated for 6 structures if len(structure_names) != len(atoms_list): print( - f"Warning: Mismatch between custom structure_names ({len(structure_names)}) and atoms_list ({len(atoms_list)}). Using custom names." 
+ f"Warning: Mismatch between custom structure_names ({len(structure_names)}) and " + f"atoms_list ({len(atoms_list)}). Using custom names." ) num_structures_plot = len(structure_names) # --- Plot 1: Convergence Steps (Multi-bar per structure) --- -plot_methods_fig1 = list(all_results.keys()) +plot_methods_fig1 = list(all_results) num_methods_fig1 = len(plot_methods_fig1) # Initialize with NaNs, so if a method fails completely, its bars are missing or clearly marked steps_data_fig1 = np.full((num_structures_plot, num_methods_fig1), np.nan) @@ -634,48 +635,42 @@ def run_optimization_ase( steps_data_fig1[: len(steps_plot_values), method_idx] = steps_plot_values # The rest will remain NaN due to initialization -fig1, ax1 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend -x_fig1 = np.arange(num_structures_plot) -width_fig1 = 0.8 / num_methods_fig1 +fig1_plotly = go.Figure() -rects_all_fig1 = [] for i in range(num_methods_fig1): - rects = ax1.bar( - x_fig1 - 0.4 + (i + 0.5) * width_fig1, - steps_data_fig1[:, i], - width_fig1, - label=plot_methods_fig1[i], + fig1_plotly.add_bar( + name=plot_methods_fig1[i], + x=structure_names, # x-axis is structure names + y=steps_data_fig1[:, i], + text=[ + "NC" + if all_results[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] == -1 + and not np.isnan(steps_data_fig1[bar_idx, i]) + else "" + for bar_idx in range(num_structures_plot) + ], + textposition="inside", + insidetextanchor="middle", + textfont=dict(color="white", size=10, family="Arial, sans-serif"), ) - rects_all_fig1.append(rects) - for bar_idx, bar_val in enumerate(steps_data_fig1[:, i]): - if bar_idx < len(all_results[plot_methods_fig1[i]]["steps"]): # Check bounds - original_step_val = ( - all_results[plot_methods_fig1[i]]["steps"].cpu().numpy()[bar_idx] - ) - if original_step_val == -1 and not np.isnan( - bar_val - ): # Check if it was a penalty bar - ax1.text( - rects[bar_idx].get_x() + rects[bar_idx].get_width() / 2.0, - rects[bar_idx].get_height() - 10, - "NC", - ha="center", - va="top", - color="white", - fontsize=7, - weight="bold", - ) -ax1.set_ylabel("Convergence Steps (NC = Not Converged, shown at penalty)") -ax1.set_xlabel("Structure") -ax1.set_title("Convergence Steps per Structure and Method") -ax1.set_xticks(x_fig1) -ax1.set_xticklabels(structure_names, rotation=45, ha="right") -ax1.legend( - title="Optimization Method", bbox_to_anchor=(1.02, 1), loc="upper left" -) # Adjusted legend position -ax1.grid(True, axis="y", linestyle="--", alpha=0.7) -plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend +fig1_plotly.update_layout( + barmode="group", + title_text="Convergence Steps per Structure and Method", + xaxis_title="Structure", + yaxis_title="Convergence Steps (NC = Not Converged, shown at penalty)", + legend_title="Optimization Method", + xaxis_tickangle=-45, + yaxis_gridcolor="lightgrey", + plot_bgcolor="white", + height=600, + width=max( + 1000, 150 * num_structures_plot + ), # Adjust width based on number of structures + margin=dict(l=50, r=50, b=100, t=50, pad=4), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1), +) +fig1_plotly.update_xaxes(categoryorder="array", categoryarray=structure_names) # --- Plot 2: Average Final Energy Difference from Baselines --- @@ -705,7 +700,7 @@ def run_optimization_ase( is_baseline_self = False processed_current_name = False - if name == baseline_ase_pos_only or name == baseline_ase_frechet: + if name in (baseline_ase_pos_only, baseline_ase_frechet): 
avg_energy_diffs_fig2.append(0.0) is_baseline_self = True processed_current_name = True @@ -719,15 +714,14 @@ def run_optimization_ase( chosen_baseline_energies = ( baseline_pos_only_data["final_state"].energy.cpu().numpy() ) - elif "Frechet Cell" in name: - if ( - baseline_frechet_data - and baseline_frechet_data["final_state"] is not None - and baseline_frechet_data["final_state"].energy is not None - ): - chosen_baseline_energies = ( - baseline_frechet_data["final_state"].energy.cpu().numpy() - ) + elif "Frechet Cell" in name and ( + baseline_frechet_data + and baseline_frechet_data["final_state"] is not None + and baseline_frechet_data["final_state"].energy is not None + ): + chosen_baseline_energies = ( + baseline_frechet_data["final_state"].energy.cpu().numpy() + ) if not is_baseline_self and not processed_current_name: if chosen_baseline_energies is not None: @@ -737,12 +731,13 @@ def run_optimization_ase( else: avg_energy_diffs_fig2.append(np.nan) print( - f"Plot2: Shape mismatch for energy comparison: {name} vs its baseline. " + f"Plot2: Shape mismatch for energy comparison: {name} vs baseline. " f"{current_energies.shape} vs {chosen_baseline_energies.shape}" ) else: print( - f"Plot2: No appropriate baseline for {name} or baseline data missing. Setting energy diff to NaN." + f"Plot2: No appropriate baseline for {name} or baseline data missing. " + "Setting energy diff to NaN." ) avg_energy_diffs_fig2.append(np.nan) elif not processed_current_name and name not in [ @@ -759,7 +754,7 @@ def run_optimization_ase( # A more robust way is to build them in parallel. final_plot_names_fig2 = [] final_avg_energy_diffs_fig2 = [] -all_method_names_sorted = sorted(list(all_results.keys())) # Use a fixed order +all_method_names_sorted = sorted(all_results) # Use a fixed order for name in all_method_names_sorted: result_data = all_results[name] @@ -771,7 +766,7 @@ def run_optimization_ase( current_energies = result_data["final_state"].energy.cpu().numpy() energy_to_append = np.nan # Default to NaN - if name == baseline_ase_pos_only or name == baseline_ase_frechet: + if name in (baseline_ase_pos_only, baseline_ase_frechet): energy_to_append = 0.0 elif "torch-sim" in name: baseline_to_use_energies = None @@ -784,52 +779,53 @@ def run_optimization_ase( baseline_to_use_energies = ( baseline_pos_only_data["final_state"].energy.cpu().numpy() ) - elif "Frechet Cell" in name: - if ( - baseline_frechet_data - and baseline_frechet_data["final_state"] is not None - and baseline_frechet_data["final_state"].energy is not None - ): - baseline_to_use_energies = ( - baseline_frechet_data["final_state"].energy.cpu().numpy() - ) + elif "Frechet Cell" in name and ( + baseline_frechet_data + and baseline_frechet_data["final_state"] is not None + and baseline_frechet_data["final_state"].energy is not None + ): + baseline_to_use_energies = ( + baseline_frechet_data["final_state"].energy.cpu().numpy() + ) if baseline_to_use_energies is not None: if current_energies.shape == baseline_to_use_energies.shape: energy_to_append = np.mean(current_energies - baseline_to_use_energies) else: print( - f"Plot2: Shape mismatch for {name} ({current_energies.shape}) vs baseline ({baseline_to_use_energies.shape})." + f"Plot2: Shape mismatch for {name} ({current_energies.shape}) " + f"vs baseline ({baseline_to_use_energies.shape})." 
) final_avg_energy_diffs_fig2.append(energy_to_append) -fig2, ax2 = plt.subplots(figsize=(12, 7)) -bars_fig2 = ax2.bar( - final_plot_names_fig2, final_avg_energy_diffs_fig2, color="lightcoral" +fig2_plotly = go.Figure() +fig2_plotly.add_bar( + x=final_plot_names_fig2, + y=final_avg_energy_diffs_fig2, + marker_color="lightcoral", + text=[ + f"{yval:.3f}" if not np.isnan(yval) else "" + for yval in final_avg_energy_diffs_fig2 + ], + textposition="auto", # Let Plotly decide best position, or use 'outside'/'inside' + textfont=dict(size=10), ) -ax2.set_ylabel("Avg. Final Energy Diff. from Corresponding ASE Baseline (eV)") -ax2.set_xlabel("Optimization Method") -ax2.set_title("Average Final Energy Difference from ASE Baselines") -ax2.axhline(0, color="black", linewidth=0.8, linestyle="--") - -for bar in bars_fig2: - yval = bar.get_height() - if not np.isnan(yval): - text_y_offset = 0.001 if yval >= 0 else -0.005 - va_align = "bottom" if yval >= 0 else "top" - ax2.text( - bar.get_x() + bar.get_width() / 2.0, - yval + text_y_offset, - f"{yval:.3f}", - ha="center", - va=va_align, - fontsize=8, - color="black", - ) -plt.xticks(rotation=45, ha="right") -plt.tight_layout() +line_dict = dict(color="black", width=1, dash="dash") +x1 = len(final_plot_names_fig2) - 0.5 +fig2_plotly.update_layout( + title_text="Average Final Energy Difference from ASE Baselines", + xaxis_title="Optimization Method", + yaxis_title="Avg. Final Energy Diff. from Corresponding ASE Baseline (eV)", + xaxis_tickangle=-45, + yaxis_gridcolor="lightgrey", + plot_bgcolor="white", + shapes=[dict(type="line", y0=0, y1=0, x0=-0.5, x1=x1, line=line_dict)], + height=600, + width=max(800, 100 * len(final_plot_names_fig2)), + margin=dict(l=50, r=50, b=150, t=50, pad=4), # Increased bottom margin for labels +) # --- Plot 3: Mean Displacement from ASE Counterparts (Multi-bar per structure) --- @@ -866,9 +862,8 @@ def run_optimization_ase( if state1_data["final_state"] is None or state2_data["final_state"] is None: print( - f"Plot3: Skipping displacement for {plot_label} due to missing state data." + f"Plot3: Skipping displacement for {plot_label} due to missing state data" ) - # Data remains NaN continue state1_list = state1_data["final_state"].split() @@ -880,13 +875,13 @@ def run_optimization_ase( ): print( f"Plot3: Structure count mismatch for {plot_label}. 
" - f"Expected {num_structures_plot}, got S1:{len(state1_list)}, S2:{len(state2_list)}" + f"Expected {num_structures_plot}, got S1:{len(state1_list)}, " + f"S2:{len(state2_list)}" ) - # Data remains NaN continue mean_displacements_for_this_pair = [] - for s_idx, (s1, s2) in enumerate(zip(state1_list, state2_list, strict=True)): + for s1, s2 in zip(state1_list, state2_list, strict=True): if s1.n_atoms == 0 or s2.n_atoms == 0 or s1.n_atoms != s2.n_atoms: mean_displacements_for_this_pair.append(np.nan) continue @@ -905,30 +900,31 @@ def run_optimization_ase( print(f"Plot3: Missing data for methods in pair: {plot_label}.") # Data remains NaN -fig3, ax3 = plt.subplots(figsize=(17, 8)) # Wider for 6 structures + legend -x_fig3 = np.arange(num_structures_plot) -width_fig3 = 0.8 / num_comparison_pairs_plot3 - -for i in range(num_comparison_pairs_plot3): - ax3.bar( - x_fig3 - 0.4 + (i + 0.5) * width_fig3, - disp_data_fig3[:, i], - width_fig3, - label=legend_labels_fig3[i], - ) - -ax3.set_ylabel("Mean Atomic Displacement (Å) to ASE Counterpart") -ax3.set_xlabel("Structure") -ax3.set_title( - "Mean Displacement of Torch-Sim Methods to ASE Counterparts (per Structure)" +fig3_plotly = go.Figure() + +for idx, name in enumerate(legend_labels_fig3): + # x-axis is structure names + fig3_plotly.add_bar(name=name, x=structure_names, y=disp_data_fig3[:, idx]) + + +title = "Mean Displacement of Torch-Sim Methods to ASE Counterparts (per Structure)" +fig3_plotly.update_layout( + barmode="group", + title=dict(text=title, x=0.5, y=1), + xaxis_title="Structure", + yaxis_title="Mean Atomic Displacement (Å) to ASE Counterpart", + legend_title="Comparison Pair", + xaxis_tickangle=-45, + yaxis_gridcolor="lightgrey", + plot_bgcolor="white", + height=600, + width=max(1000, 150 * num_structures_plot), # Adjust width + margin=dict(l=50, r=50, b=100, t=50, pad=4), + legend=dict(orientation="h", yanchor="bottom", y=0.96, xanchor="right", x=1), ) -ax3.set_xticks(x_fig3) -ax3.set_xticklabels(structure_names, rotation=45, ha="right") -ax3.legend( - title="Comparison Pair", bbox_to_anchor=(1.02, 1), loc="upper left" -) # Adjusted legend -ax3.grid(True, axis="y", linestyle="--", alpha=0.7) -plt.tight_layout(rect=[0, 0, 0.83, 1]) # Fine-tune rect for legend -plt.show() +# Show Plotly figures +fig1_plotly.show() +fig2_plotly.show() +fig3_plotly.show() diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index 3470119c..f12893c3 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -114,7 +114,7 @@ # %% model_outputs = model(state) -print(f"Model outputs: {', '.join(list(model_outputs.keys()))}") +print(f"Model outputs: {', '.join(list(model_outputs))}") print(f"Energy is a batchwise property with shape: {model_outputs['energy'].shape}") print(f"Forces are an atomwise property with shape: {model_outputs['forces'].shape}") diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index 06a76dca..1470d3be 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -44,7 +44,7 @@ def test_graphpes_isolated(device: torch.device): compute_stress=False, ) ts_output = ts_model(ts.io.atoms_to_state([water_atoms], device, torch.float32)) - assert set(ts_output.keys()) == {"energy", "forces"} + assert set(ts_output) == {"energy", "forces"} assert ts_output["energy"].shape == (1,) assert gp_energy.item() == pytest.approx(ts_output["energy"].item(), abs=1e-5) @@ -69,7 +69,7 @@ def test_graphpes_periodic(device: 
torch.device): compute_stress=True, ) ts_output = ts_model(ts.io.atoms_to_state([bulk_atoms], device, torch.float32)) - assert set(ts_output.keys()) == {"energy", "forces", "stress"} + assert set(ts_output) == {"energy", "forces", "stress"} assert ts_output["energy"].shape == (1,) assert ts_output["forces"].shape == (len(bulk_atoms), 3) assert ts_output["stress"].shape == (1, 3, 3) @@ -101,7 +101,7 @@ def test_batching(device: torch.device): ) ts_output = ts_model(ts.io.atoms_to_state(systems, device, torch.float32)) - assert set(ts_output.keys()) == {"energy", "forces", "stress"} + assert set(ts_output) == {"energy", "forces", "stress"} assert ts_output["energy"].shape == (2,) assert ts_output["forces"].shape == (sum(len(s) for s in systems), 3) assert ts_output["stress"].shape == (2, 3, 3) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 1ac35798..45e59f9b 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -125,7 +125,7 @@ def __init__( self.modal = None modal_map = self.model.modal_map if modal_map: - modal_ava = list(modal_map.keys()) + modal_ava = list(modal_map) if not modal: raise ValueError(f"modal argument missing (avail: {modal_ava})") if modal not in modal_ava: diff --git a/torch_sim/properties/correlations.py b/torch_sim/properties/correlations.py index 6cfbe579..2dda340d 100644 --- a/torch_sim/properties/correlations.py +++ b/torch_sim/properties/correlations.py @@ -275,7 +275,7 @@ def _compute_correlations(self) -> None: # noqa: C901, PLR0915 self.correlations[name] = acf # Cross-correlations - names = list(self.buffers.keys()) + names = list(self.buffers) for i, name1 in enumerate(names): for name2 in names[i + 1 :]: data1 = self.buffers[name1].get_array() From 9fc4172ba14ce1af9f93424905111d7ceaf72b6f Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 22 May 2025 16:25:43 -0400 Subject: [PATCH 7/7] lint --- .../7_Others/7.6_Compare_ASE_to_VV_FIRE.py | 136 ++++++++---------- 1 file changed, 61 insertions(+), 75 deletions(-) diff --git a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py index f550642b..2fc90f3b 100644 --- a/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py +++ b/examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py @@ -5,6 +5,7 @@ # /// script # dependencies = [ # "mace-torch>=0.3.12", +# "plotly>=6.0.0", # ] # /// @@ -49,6 +50,9 @@ # Set random seed for reproducibility rng = np.random.default_rng(seed=0) +torch.manual_seed(0) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.21, cubic=True).repeat(supercell_scale) @@ -120,7 +124,7 @@ initial_energies = model(state)["energy"] -def run_optimization_ts( +def run_optimization_ts( # noqa: PLR0915 *, initial_state: SimState, ts_md_flavor: Literal["vv_fire", "ase_fire"], @@ -141,25 +145,24 @@ def run_optimization_ts( ase_cell_k = Cell(cell_tensor_k) params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) print( - f" Structure {k_idx + 1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]" + f" Structure {k_idx + 1}: Volume={ase_cell_k.volume:.2f} ų, " + f"Params=[{params_str}]" ) if ts_use_frechet: - # Uses frechet_cell_fire for combined cell and position optimization init_fn_opt, update_fn_opt = frechet_cell_fire( model=model, md_flavor=ts_md_flavor ) else: - # Uses fire for position-only optimization init_fn_opt, update_fn_opt = fire(model=model, md_flavor=ts_md_flavor) opt_state = 
init_fn_opt(initial_state.clone()) batcher = ts.InFlightAutoBatcher( - model=model, # The MaceModel wrapper + model=model, memory_scales_with="n_atoms", max_memory_scaler=1000, - max_iterations=max_iterations_ts, # Use the passed max_iterations + max_iterations=max_iterations_ts, return_indices=True, ) batcher.load_states(opt_state) @@ -229,11 +232,13 @@ def run_optimization_ts( ase_cell_k = Cell(cell_tensor_k) params_str = ", ".join([f"{p:.2f}" for p in ase_cell_k.cellpar()]) print( - f" Structure {k_idx + 1}: Volume={ase_cell_k.volume:.2f} ų, Params=[{params_str}]" + f" Structure {k_idx + 1}: Volume={ase_cell_k.volume:.2f} ų, " + f"Params=[{params_str}]" ) else: print( - "Final cell parameters (Torch-Sim): Not available (final_state_concatenated is None or has no cell)." + "Final cell parameters (Torch-Sim): Not available (final_state_concatenated " + "is None or has no cell)." ) end_time = time.perf_counter() @@ -244,7 +249,7 @@ def run_optimization_ts( return convergence_steps, final_state_concatenated -def run_optimization_ase( +def run_optimization_ase( # noqa: C901, PLR0915 *, initial_state: SimState, ase_use_frechet_filter: bool, @@ -270,7 +275,9 @@ def run_optimization_ase( initial_cell_ase = ase_atoms_orig.get_cell() initial_params_str = ", ".join([f"{p:.2f}" for p in initial_cell_ase.cellpar()]) print( - f" Initial cell (ASE Structure {i + 1}): Volume={initial_cell_ase.volume:.2f} ų, Params=[{initial_params_str}]" + f" Initial cell (ASE Structure {i + 1}): " + f"Volume={initial_cell_ase.volume:.2f} ų, " + f"Params=[{initial_params_str}]" ) ase_calc_instance = mace_mp_calculator_for_ase( @@ -288,7 +295,7 @@ def run_optimization_ase( dyn = ASEFIRE(optim_target_atoms, trajectory=None, logfile=None) try: - dyn.run(fmax=force_tol, steps=max_steps_ase) # Use passed max_steps_ase + dyn.run(fmax=force_tol, steps=max_steps_ase) if dyn.converged(): convergence_steps_list.append(dyn.nsteps) print(f"ASE structure {i + 1} converged in {dyn.nsteps} steps.") @@ -298,7 +305,7 @@ def run_optimization_ase( f"{max_steps_ase} steps. Steps taken: {dyn.nsteps}." ) convergence_steps_list.append(-1) - except Exception as e: + except Exception as e: # noqa: BLE001 print(f"ASE optimization failed for structure {i + 1}: {e}") convergence_steps_list.append(-1) @@ -308,23 +315,20 @@ def run_optimization_ase( final_cell_ase = final_ats_for_print.get_cell() final_params_str = ", ".join([f"{p:.2f}" for p in final_cell_ase.cellpar()]) print( - f" Final cell (ASE Structure {i + 1}): Volume={final_cell_ase.volume:.2f} ų, Params=[{final_params_str}]" + f" Final cell (ASE Structure {i + 1}): " + f"Volume={final_cell_ase.volume:.2f} ų, " + f"Params=[{final_params_str}]" ) final_ase_atoms_list.append(final_ats_for_print) - # Convert list of final ASE atoms objects back to a base SimState first - # to easily get positions, cell, etc. - # However, ts.io.atoms_to_state might not preserve all attributes needed for GDState directly. - # It's better to extract all required components directly from final_ase_atoms_list. - all_positions = [] all_masses = [] all_atomic_numbers = [] all_cells = [] all_batches_for_gd = [] final_energies_ase = [] - final_forces_ase_tensors = [] # List to store force tensors + final_forces_ase_tensors = [] current_atom_offset = 0 for batch_idx, ats_final in enumerate(final_ase_atoms_list): @@ -353,7 +357,8 @@ def run_optimization_ase( try: if ats_final.calc is None: print( - f"Re-attaching ASE calculator for final energy/forces for structure {batch_idx}." 
+ "Re-attaching ASE calculator for final energy/forces for " + f"structure {batch_idx}." ) temp_calc = mace_mp_calculator_for_ase( model=MaceUrls.mace_mpa_medium, @@ -365,20 +370,14 @@ def run_optimization_ase( final_forces_ase_tensors.append( torch.tensor(ats_final.get_forces(), device=device, dtype=dtype) ) - except Exception as e: + except Exception as e: # noqa: BLE001 print( f"Could not get final energy/forces for an ASE structure {batch_idx}: {e}" ) final_energies_ase.append(float("nan")) - # Append a zero tensor of appropriate shape if forces fail, or handle error - # For GDState, forces are required. If any structure fails, GDState creation might fail. - # We need to ensure all_positions, etc. are also correctly populated even on failure. - # For now, let's assume if energy fails, forces might also, and GDState might be problematic. - # A robust solution would be to skip failed structures or return None. - # For now, let's make forces a zero tensor of expected shape if it fails. if all_positions and len(all_positions[-1]) > 0: final_forces_ase_tensors.append(torch.zeros_like(all_positions[-1])) - else: # Cannot determine shape, this path is problematic + else: final_forces_ase_tensors.append( torch.empty((0, 3), device=device, dtype=dtype) ) @@ -400,18 +399,16 @@ def run_optimization_ase( # Check for NaN energies which might cause issues if torch.isnan(concatenated_energies).any(): print( - "Warning: NaN values found in final ASE energies. GDState energy tensor will contain NaNs." + "Warning: NaN values found in final ASE energies. " + "GDState energy tensor will contain NaNs." ) - # Consider replacing NaNs if GDState or subsequent ops can't handle them: - # concatenated_energies = torch.nan_to_num(concatenated_energies, nan=0.0) # Example replacement # Create GDState instance - # pbc is global, taken from initial_state final_state_as_gd = GDState( positions=concatenated_positions, masses=concatenated_masses, cell=concatenated_cells, - pbc=initial_state.pbc, # Assuming pbc is constant and global + pbc=initial_state.pbc, atomic_numbers=concatenated_atomic_numbers, batch=concatenated_batch_indices, energy=concatenated_energies, @@ -460,7 +457,7 @@ def run_optimization_ase( "ts_use_frechet": True, }, { - "name": "ASE FIRE (Native, PosOnly)", # Corrected name: Only optimizes positions without a cell filter + "name": "ASE FIRE (Native, PosOnly)", "type": "ase", "ase_use_frechet_filter": False, }, @@ -475,9 +472,7 @@ def run_optimization_ase( for config_run in configs_to_run: print(f"\n\nStarting configuration: {config_run['name']}") - # Get relevant params, providing defaults where necessary for the run_optimization call optimizer_type_val = config_run["type"] - # Will be None for ASE type, handled by assert ts_md_flavor_val = config_run.get("ts_md_flavor") ts_use_frechet_val = config_run.get("ts_use_frechet", False) ase_use_frechet_filter_val = config_run.get("ase_use_frechet_filter", False) @@ -488,18 +483,18 @@ def run_optimization_ase( if optimizer_type_val == "torch_sim": assert ts_md_flavor_val is not None, "ts_md_flavor must be provided for torch_sim" steps, final_state_opt = run_optimization_ts( - initial_state=state.clone(), # Use a fresh clone for each run + initial_state=state.clone(), ts_md_flavor=ts_md_flavor_val, ts_use_frechet=ts_use_frechet_val, force_tol=force_tol, - max_iterations_ts=max_iterations, # Pass the global max_iterations + max_iterations_ts=max_iterations, ) elif optimizer_type_val == "ase": steps, final_state_opt = run_optimization_ase( - 
initial_state=state.clone(), # Use a fresh clone for each run + initial_state=state.clone(), ase_use_frechet_filter=ase_use_frechet_filter_val, force_tol=force_tol, - max_steps_ase=ase_max_optimizer_steps, # Pass the global ase_max_optimizer_steps + max_steps_ase=ase_max_optimizer_steps, ) else: raise ValueError(f"Unknown optimizer_type: {optimizer_type_val}") @@ -530,7 +525,6 @@ def run_optimization_ase( if not_converged_indices: print(f" Did not converge for structure indices: {not_converged_indices}") -# Mean Displacement Comparisons comparison_pairs = [ ("torch-sim ASE-FIRE (PosOnly)", "ASE FIRE (Native, PosOnly)"), ("torch-sim ASE-FIRE (Frechet Cell)", "ASE FIRE (Native Frechet Filter, CellOpt)"), @@ -559,14 +553,15 @@ def run_optimization_ase( mean_displacements = [] for s1, s2 in zip(state1_list, state2_list, strict=True): - if s1.n_atoms == 0 or s2.n_atoms == 0: # Handle empty states if they occur + if s1.n_atoms == 0 or s2.n_atoms == 0: mean_displacements.append(float("nan")) continue pos1_centered = s1.positions - s1.positions.mean(dim=0, keepdim=True) pos2_centered = s2.positions - s2.positions.mean(dim=0, keepdim=True) if pos1_centered.shape != pos2_centered.shape: print( - f"Warning: Shape mismatch for {name1} vs {name2} in structure. Skipping displacement calc." + f"Warning: Shape mismatch for {name1} vs {name2} in structure. " + "Skipping displacement calc." ) mean_displacements.append(float("nan")) continue @@ -575,19 +570,18 @@ def run_optimization_ase( mean_displacements.append(mean_disp) print( - f"\nMean Disp ({name1} vs {name2}): {[f'{d:.4f}' for d in mean_displacements]} Å" + f"\nMean Disp ({name1} vs {name2}): " + f"{[f'{d:.4f}' for d in mean_displacements]} Å" ) else: print( - f"\nSkipping displacement comparison for ({name1} vs {name2}), one or both results missing." + f"\nSkipping displacement comparison for ({name1} vs {name2}), " + "one or both results missing." ) # --- Plotting Results --- - -# Names for the structures for plotting labels original_structure_formulas = [ats.get_chemical_formula() for ats in atoms_list] -# Make them more concise if needed: structure_names = [ "Si_bulk", "Cu_bulk", @@ -595,7 +589,7 @@ def run_optimization_ase( "Si_vac", "Cu_vac", "Fe_vac", -] # Updated for 6 structures +] if len(structure_names) != len(atoms_list): print( f"Warning: Mismatch between custom structure_names ({len(structure_names)}) and " @@ -607,13 +601,11 @@ def run_optimization_ase( # --- Plot 1: Convergence Steps (Multi-bar per structure) --- plot_methods_fig1 = list(all_results) num_methods_fig1 = len(plot_methods_fig1) -# Initialize with NaNs, so if a method fails completely, its bars are missing or clearly marked steps_data_fig1 = np.full((num_structures_plot, num_methods_fig1), np.nan) for method_idx, method_name in enumerate(plot_methods_fig1): result_data = all_results[method_name] if result_data["final_state"] is None or result_data["steps"] is None: - # steps_data_fig1[:, method_idx] = np.nan # Already initialized with NaN print(f"Plot1: Skipping steps for {method_name} as final_state or steps is None.") continue @@ -625,22 +617,25 @@ def run_optimization_ase( steps_data_fig1[:, method_idx] = steps_plot_values elif len(steps_plot_values) > num_structures_plot: print( - f"Warning: More step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Truncating." + f"Warning: More step values ({len(steps_plot_values)}) than " + f"structure names ({num_structures_plot}) for {method_name}. " + "Truncating." 
)
        steps_data_fig1[:, method_idx] = steps_plot_values[:num_structures_plot]
    elif len(steps_plot_values) < num_structures_plot:
        print(
-            f"Warning: Fewer step values ({len(steps_plot_values)}) than structure names ({num_structures_plot}) for {method_name}. Padding with NaN."
+            f"Warning: Fewer step values ({len(steps_plot_values)}) than "
+            f"structure names ({num_structures_plot}) for {method_name}. "
+            "Padding with NaN."
        )
        steps_data_fig1[: len(steps_plot_values), method_idx] = steps_plot_values
-        # The rest will remain NaN due to initialization

fig1_plotly = go.Figure()

for i in range(num_methods_fig1):
    fig1_plotly.add_bar(
        name=plot_methods_fig1[i],
-        x=structure_names,  # x-axis is structure names
+        x=structure_names,
        y=steps_data_fig1[:, i],
        text=[
            "NC"
@@ -664,9 +659,7 @@ def run_optimization_ase(
    yaxis_gridcolor="lightgrey",
    plot_bgcolor="white",
    height=600,
-    width=max(
-        1000, 150 * num_structures_plot
-    ),  # Adjust width based on number of structures
+    width=max(1000, 150 * num_structures_plot),
    margin=dict(l=50, r=50, b=100, t=50, pad=4),
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
)
@@ -686,11 +679,10 @@ def run_optimization_ase(
    if result_data["final_state"] is None or result_data["final_state"].energy is None:
        print(f"Plot2: Skipping energy diff for {name} as final_state or energy is None.")
        if name not in plot_names_fig2:
-            plot_names_fig2.append(name)  # Keep name for consistent bar count
-            avg_energy_diffs_fig2.append(np.nan)  # Add NaN if data missing
+            plot_names_fig2.append(name)
+            avg_energy_diffs_fig2.append(np.nan)
        continue

-    # Ensure name is added if not already by a skip
    if name not in plot_names_fig2:
        plot_names_fig2.append(name)
@@ -744,17 +736,13 @@ def run_optimization_ase(
    n
    for n, v in zip(plot_names_fig2, avg_energy_diffs_fig2, strict=False)
    if not np.isnan(v)
-    ]:  # Handle cases not covered
+    ]:
        print(f"Plot2: Fallback for {name}, setting energy diff to NaN.")
        avg_energy_diffs_fig2.append(np.nan)

-
-# Ensure plot_names_fig2 and avg_energy_diffs_fig2 have the same length
-# This can happen if a name was added to plot_names_fig2 but its energy_diff calculation failed or was skipped.
-# A more robust way is to build them in parallel.
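# The block below builds the plot-name and energy-difference lists in
# parallel, which is what the comments removed above were asking for. A
# minimal sketch of the same pattern, with hypothetical method names and
# energies (illustrative only, not part of the patch):
#
#     import numpy as np
#
#     results = {"vv_fire": np.array([-10.2, -8.1]), "ase_fire": np.array([-10.1, np.nan])}
#     baseline = np.array([-10.0, -8.0])
#     names, avg_diffs = [], []
#     for name in sorted(results):  # fixed iteration order keeps the lists aligned
#         names.append(name)
#         # np.nanmean skips structures whose final energy was unavailable
#         avg_diffs.append(float(np.nanmean(results[name] - baseline)))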
final_plot_names_fig2 = [] final_avg_energy_diffs_fig2 = [] -all_method_names_sorted = sorted(all_results) # Use a fixed order +all_method_names_sorted = sorted(all_results) for name in all_method_names_sorted: result_data = all_results[name] @@ -764,7 +752,7 @@ def run_optimization_ase( continue current_energies = result_data["final_state"].energy.cpu().numpy() - energy_to_append = np.nan # Default to NaN + energy_to_append = np.nan if name in (baseline_ase_pos_only, baseline_ase_frechet): energy_to_append = 0.0 @@ -808,7 +796,7 @@ def run_optimization_ase( f"{yval:.3f}" if not np.isnan(yval) else "" for yval in final_avg_energy_diffs_fig2 ], - textposition="auto", # Let Plotly decide best position, or use 'outside'/'inside' + textposition="auto", textfont=dict(size=10), ) @@ -824,12 +812,13 @@ def run_optimization_ase( shapes=[dict(type="line", y0=0, y1=0, x0=-0.5, x1=x1, line=line_dict)], height=600, width=max(800, 100 * len(final_plot_names_fig2)), - margin=dict(l=50, r=50, b=150, t=50, pad=4), # Increased bottom margin for labels + margin=dict(l=50, r=50, b=150, t=50, pad=4), ) # --- Plot 3: Mean Displacement from ASE Counterparts (Multi-bar per structure) --- -comparison_pairs_plot3_defs = [ # (ts_method_name, ase_method_name, short_label_for_legend) +# look at sets of: (ts_method_name, ase_method_name, short_label_for_legend) +comparison_pairs_plot3_defs = [ ( "torch-sim ASE-FIRE (PosOnly)", baseline_ase_pos_only, @@ -893,17 +882,15 @@ def run_optimization_ase( if len(mean_displacements_for_this_pair) == num_structures_plot: disp_data_fig3[:, pair_idx] = np.array(mean_displacements_for_this_pair) - else: # Should not happen if previous checks pass + else: print(f"Plot3: Inner loop displacement calculation mismatch for {plot_label}") else: print(f"Plot3: Missing data for methods in pair: {plot_label}.") - # Data remains NaN fig3_plotly = go.Figure() for idx, name in enumerate(legend_labels_fig3): - # x-axis is structure names fig3_plotly.add_bar(name=name, x=structure_names, y=disp_data_fig3[:, idx]) @@ -923,7 +910,6 @@ def run_optimization_ase( legend=dict(orientation="h", yanchor="bottom", y=0.96, xanchor="right", x=1), ) - # Show Plotly figures fig1_plotly.show() fig2_plotly.show()
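For readers skimming the diff: the mean-displacement metric used in both the pairwise comparisons and Plot 3 reduces to removing the rigid translation between two relaxed structures and then averaging the per-atom distances. A minimal standalone sketch with random coordinates, mirroring the centering done in the script (illustrative only):

import torch

pos1 = torch.randn(8, 3)  # stand-in for one method's relaxed positions
pos2 = pos1 + 0.05 * torch.randn(8, 3)  # stand-in for the ASE counterpart

# Center both structures so a uniform translation does not register as displacement
p1 = pos1 - pos1.mean(dim=0, keepdim=True)
p2 = pos2 - pos2.mean(dim=0, keepdim=True)
mean_disp = torch.norm(p1 - p2, dim=1).mean().item()
print(f"Mean atomic displacement: {mean_disp:.4f} Å")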
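A usage note on the closing show() calls: they open interactive figures, which will not work in headless runs such as CI. A minimal alternative sketch, assuming Plotly's optional kaleido backend is installed and using illustrative file names:

# write_html needs no extra dependencies; write_image requires kaleido.
for idx, fig in enumerate((fig1_plotly, fig2_plotly, fig3_plotly), start=1):
    fig.write_html(f"fire_comparison_fig{idx}.html")
    fig.write_image(f"fire_comparison_fig{idx}.png", scale=2)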