diff --git a/examples/tutorials/metatensor_tutorial.py b/examples/tutorials/metatensor_tutorial.py index d2afe47e..944d4348 100644 --- a/examples/tutorials/metatensor_tutorial.py +++ b/examples/tutorials/metatensor_tutorial.py @@ -3,8 +3,8 @@ # Dependencies # /// script # dependencies = [ -# "metatrain[pet] >=2025.4", -# "metatensor-torch >=0.7,<0.8" +# "metatrain[pet]>=2025.4", +# "metatensor-torch>=0.7,<0.8" # ] # /// # diff --git a/pyproject.toml b/pyproject.toml index d7f6eaba..aa8211be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ test = [ ] mace = ["mace-torch>=0.3.11"] mattersim = ["mattersim>=0.1.2"] -metatensor = ["metatensor-torch >=0.7,<0.8", "metatrain[pet] >=2025.4"] +metatensor = ["metatensor-torch>=0.7,<0.8", "metatrain[pet]>=2025.4"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.0.34", "mace-torch>=0.3.11"] diff --git a/torch_sim/models/metatensor.py b/torch_sim/models/metatensor.py index 18065cc0..63d468a2 100644 --- a/torch_sim/models/metatensor.py +++ b/torch_sim/models/metatensor.py @@ -223,9 +223,14 @@ def forward( # noqa: C901, PLR0915 ) # Calculate the required neighbor list(s) for all the systems + + # move data to CPU because vesin only supports CPU for now + systems = [system.to(device="cpu") for system in systems] vesin.torch.metatensor.compute_requested_neighbors( systems, system_length_unit="Angstrom", model=self._model ) + # move back to the proper device + systems = [system.to(device=self.device) for system in systems] # Get model output model_outputs = self._model(