| 
5 | 5 | 
 
  | 
6 | 6 | 
 
  | 
7 | 7 | try:  | 
 | 8 | +    from collections.abc import Callable  | 
 | 9 | + | 
 | 10 | +    from ase.build import bulk, fcc100, molecule  | 
8 | 11 |     from huggingface_hub.utils._auth import get_token  | 
9 | 12 | 
 
  | 
 | 13 | +    import torch_sim as ts  | 
10 | 14 |     from torch_sim.models.fairchem import FairChemModel  | 
11 | 15 | 
 
  | 
12 | 16 | except ImportError:  | 
 | 
15 | 19 | 
 
  | 
16 | 20 | @pytest.fixture  | 
17 | 21 | def eqv2_uma_model_pbc(device: torch.device) -> FairChemModel:  | 
18 |  | -    """Use the UMA model which is available in fairchem-core-2.2.0+."""  | 
 | 22 | +    """UMA model for periodic boundary condition systems."""  | 
19 | 23 |     cpu = device.type == "cpu"  | 
20 | 24 |     return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)  | 
21 | 25 | 
 
  | 
22 | 26 | 
 
  | 
23 |  | -@pytest.fixture  | 
24 |  | -def eqv2_uma_model_non_pbc(device: torch.device) -> FairChemModel:  | 
25 |  | -    """Use the UMA model for non-PBC systems."""  | 
 | 27 | +# Removed calculator consistency tests since we're using predictor interface only  | 
 | 28 | + | 
 | 29 | + | 
 | 30 | +@pytest.mark.skipif(  | 
 | 31 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 32 | +)  | 
 | 33 | +@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"])  | 
 | 34 | +def test_task_initialization(task_name: str) -> None:  | 
 | 35 | +    """Test that different UMA task names work correctly."""  | 
 | 36 | +    model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True)  | 
 | 37 | +    assert model.task_name.value == task_name  | 
 | 38 | +    assert hasattr(model, "predictor")  | 
 | 39 | + | 
 | 40 | + | 
 | 41 | +@pytest.mark.skipif(  | 
 | 42 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 43 | +)  | 
 | 44 | +@pytest.mark.parametrize(  | 
 | 45 | +    ("task_name", "systems_func"),  | 
 | 46 | +    [  | 
 | 47 | +        (  | 
 | 48 | +            "omat",  | 
 | 49 | +            lambda: [  | 
 | 50 | +                bulk("Si", "diamond", a=5.43),  | 
 | 51 | +                bulk("Al", "fcc", a=4.05),  | 
 | 52 | +                bulk("Fe", "bcc", a=2.87),  | 
 | 53 | +                bulk("Cu", "fcc", a=3.61),  | 
 | 54 | +            ],  | 
 | 55 | +        ),  | 
 | 56 | +        (  | 
 | 57 | +            "omol",  | 
 | 58 | +            lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")],  | 
 | 59 | +        ),  | 
 | 60 | +    ],  | 
 | 61 | +)  | 
 | 62 | +def test_homogeneous_batching(  | 
 | 63 | +    task_name: str, systems_func: Callable, device: torch.device, dtype: torch.dtype  | 
 | 64 | +) -> None:  | 
 | 65 | +    """Test batching multiple systems with the same task."""  | 
 | 66 | +    systems = systems_func()  | 
 | 67 | + | 
 | 68 | +    # Add molecular properties for molecules  | 
 | 69 | +    if task_name == "omol":  | 
 | 70 | +        for mol in systems:  | 
 | 71 | +            mol.info.update({"charge": 0, "spin": 1})  | 
 | 72 | + | 
 | 73 | +    model = FairChemModel(  | 
 | 74 | +        model=None, model_name="uma-s-1", task_name=task_name, cpu=device.type == "cpu"  | 
 | 75 | +    )  | 
 | 76 | +    state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)  | 
 | 77 | +    results = model(state)  | 
 | 78 | + | 
 | 79 | +    # Check batch dimensions  | 
 | 80 | +    assert results["energy"].shape == (4,)  | 
 | 81 | +    assert results["forces"].shape[0] == sum(len(s) for s in systems)  | 
 | 82 | +    assert results["forces"].shape[1] == 3  | 
 | 83 | + | 
 | 84 | +    # Check that different systems have different energies  | 
 | 85 | +    energies = results["energy"]  | 
 | 86 | +    unique_energies = torch.unique(energies, dim=0)  | 
 | 87 | +    assert len(unique_energies) > 1, "Different systems should have different energies"  | 
 | 88 | + | 
 | 89 | + | 
 | 90 | +@pytest.mark.skipif(  | 
 | 91 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 92 | +)  | 
 | 93 | +def test_heterogeneous_tasks(device: torch.device, dtype: torch.dtype) -> None:  | 
 | 94 | +    """Test different task types work with appropriate systems."""  | 
 | 95 | +    # Test molecule, material, and catalysis systems separately  | 
 | 96 | +    test_cases = [  | 
 | 97 | +        ("omol", [molecule("H2O")]),  | 
 | 98 | +        ("omat", [bulk("Pt", cubic=True)]),  | 
 | 99 | +        ("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]),  | 
 | 100 | +    ]  | 
 | 101 | + | 
 | 102 | +    for task_name, systems in test_cases:  | 
 | 103 | +        if task_name == "omol":  | 
 | 104 | +            systems[0].info.update({"charge": 0, "spin": 1})  | 
 | 105 | + | 
 | 106 | +        model = FairChemModel(  | 
 | 107 | +            model=None,  | 
 | 108 | +            model_name="uma-s-1",  | 
 | 109 | +            task_name=task_name,  | 
 | 110 | +            cpu=device.type == "cpu",  | 
 | 111 | +        )  | 
 | 112 | +        state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)  | 
 | 113 | +        results = model(state)  | 
 | 114 | + | 
 | 115 | +        assert "energy" in results  | 
 | 116 | +        assert "forces" in results  | 
 | 117 | +        assert results["energy"].shape[0] == 1  | 
 | 118 | +        assert results["forces"].dim() == 2  | 
 | 119 | +        assert results["forces"].shape[1] == 3  | 
 | 120 | + | 
 | 121 | + | 
 | 122 | +@pytest.mark.skipif(  | 
 | 123 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 124 | +)  | 
 | 125 | +@pytest.mark.parametrize(  | 
 | 126 | +    ("systems_func", "expected_count"),  | 
 | 127 | +    [  | 
 | 128 | +        (lambda: [bulk("Si", "diamond", a=5.43)], 1),  # Single system  | 
 | 129 | +        (  | 
 | 130 | +            lambda: [  | 
 | 131 | +                bulk("H", "bcc", a=2.0),  | 
 | 132 | +                bulk("Li", "bcc", a=3.0),  | 
 | 133 | +                bulk("Si", "diamond", a=5.43),  | 
 | 134 | +                bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)),  | 
 | 135 | +            ],  | 
 | 136 | +            4,  | 
 | 137 | +        ),  # Mixed sizes  | 
 | 138 | +        (  | 
 | 139 | +            lambda: [  | 
 | 140 | +                bulk(element, "fcc", a=4.0)  | 
 | 141 | +                for element in ["Al", "Cu", "Ni", "Pd", "Pt"] * 3  | 
 | 142 | +            ],  | 
 | 143 | +            15,  | 
 | 144 | +        ),  # Large batch  | 
 | 145 | +    ],  | 
 | 146 | +)  | 
 | 147 | +def test_batch_size_variations(  | 
 | 148 | +    systems_func: Callable, expected_count: int, device: torch.device, dtype: torch.dtype  | 
 | 149 | +) -> None:  | 
 | 150 | +    """Test batching with different numbers and sizes of systems."""  | 
 | 151 | +    systems = systems_func()  | 
 | 152 | + | 
 | 153 | +    model = FairChemModel(  | 
 | 154 | +        model=None, model_name="uma-s-1", task_name="omat", cpu=device.type == "cpu"  | 
 | 155 | +    )  | 
 | 156 | +    state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)  | 
 | 157 | +    results = model(state)  | 
 | 158 | + | 
 | 159 | +    assert results["energy"].shape == (expected_count,)  | 
 | 160 | +    assert results["forces"].shape[0] == sum(len(s) for s in systems)  | 
 | 161 | +    assert results["forces"].shape[1] == 3  | 
 | 162 | +    assert torch.isfinite(results["energy"]).all()  | 
 | 163 | +    assert torch.isfinite(results["forces"]).all()  | 
 | 164 | + | 
 | 165 | + | 
 | 166 | +@pytest.mark.skipif(  | 
 | 167 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 168 | +)  | 
 | 169 | +@pytest.mark.parametrize("compute_stress", [True, False])  | 
 | 170 | +def test_stress_computation(  | 
 | 171 | +    *, compute_stress: bool, device: torch.device, dtype: torch.dtype  | 
 | 172 | +) -> None:  | 
 | 173 | +    """Test stress tensor computation."""  | 
 | 174 | +    systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]  | 
 | 175 | + | 
 | 176 | +    model = FairChemModel(  | 
 | 177 | +        model=None,  | 
 | 178 | +        model_name="uma-s-1",  | 
 | 179 | +        task_name="omat",  | 
 | 180 | +        cpu=device.type == "cpu",  | 
 | 181 | +        compute_stress=compute_stress,  | 
 | 182 | +    )  | 
 | 183 | +    state = ts.io.atoms_to_state(systems, device=device, dtype=dtype)  | 
 | 184 | +    results = model(state)  | 
 | 185 | + | 
 | 186 | +    if compute_stress:  | 
 | 187 | +        assert "stress" in results  | 
 | 188 | +        assert results["stress"].shape == (2, 3, 3)  | 
 | 189 | +        assert torch.isfinite(results["stress"]).all()  | 
 | 190 | +    else:  | 
 | 191 | +        assert "stress" not in results  | 
 | 192 | + | 
 | 193 | + | 
 | 194 | +@pytest.mark.skipif(  | 
 | 195 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 196 | +)  | 
 | 197 | +def test_device_consistency(dtype: torch.dtype) -> None:  | 
 | 198 | +    """Test device consistency between model and data."""  | 
 | 199 | +    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  | 
26 | 200 |     cpu = device.type == "cpu"  | 
27 |  | -    return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)  | 
28 | 201 | 
 
  | 
 | 202 | +    model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)  | 
 | 203 | +    system = bulk("Si", "diamond", a=5.43)  | 
 | 204 | +    state = ts.io.atoms_to_state([system], device=device, dtype=dtype)  | 
29 | 205 | 
 
  | 
30 |  | -# Removed calculator consistency tests since we're using predictor interface only  | 
 | 206 | +    results = model(state)  | 
 | 207 | +    assert results["energy"].device == device  | 
 | 208 | +    assert results["forces"].device == device  | 
 | 209 | + | 
 | 210 | + | 
 | 211 | +@pytest.mark.skipif(  | 
 | 212 | +    get_token() is None, reason="Requires HuggingFace authentication for UMA model access"  | 
 | 213 | +)  | 
 | 214 | +def test_empty_batch_error() -> None:  | 
 | 215 | +    """Test that empty batches raise appropriate errors."""  | 
 | 216 | +    model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True)  | 
 | 217 | +    with pytest.raises((ValueError, RuntimeError, IndexError)):  | 
 | 218 | +        model(ts.io.atoms_to_state([], device="cpu", dtype=torch.float32))  | 
31 | 219 | 
 
  | 
32 | 220 | 
 
  | 
33 | 221 | test_fairchem_uma_model_outputs = pytest.mark.skipif(  | 
 | 
0 commit comments