diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 22b00249d..8cdf0d1ae 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -10,6 +10,7 @@ import psutil
 from torchmdnet.datasets import Custom, HDF5, Ace
 from torchmdnet.utils import write_as_hdf5
+from torch_geometric.loader import DataLoader
 import h5py
 import glob
 
 
@@ -297,3 +298,36 @@ def test_ace(tmpdir):
     assert len(dataset_v2) == 8
     f2.flush()
     f2.close()
+
+
+@mark.parametrize("num_files", [1, 3])
+@mark.parametrize("tile_embed", [True, False])
+@mark.parametrize("batch_size", [1, 5])
+def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir):
+    """Ensure that the HDF5 dataset's get returns the same samples
+    whether the dataset is loaded with or without caching."""
+
+    # Set up the necessary input files
+    _ = write_sample_npy_files(True, True, tmpdir, num_files)
+    files = {}
+    files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
+    files["z"] = sorted(glob.glob(join(tmpdir, "embed*")))
+    files["y"] = sorted(glob.glob(join(tmpdir, "energy*")))
+    files["neg_dy"] = sorted(glob.glob(join(tmpdir, "forces*")))
+
+    write_as_hdf5(files, join(tmpdir, "test.hdf5"), tile_embed)
+    # Assert the file is present on disk
+    assert os.path.isfile(join(tmpdir, "test.hdf5")), "HDF5 file was not created"
+
+    data = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=0)  # no caching
+    data_cached = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=256)  # caching
+    assert len(data) == len(data_cached), "Number of samples does not match"
+
+    dl = DataLoader(data, batch_size)
+    dl_cached = DataLoader(data_cached, batch_size)
+
+    for sample_cached, sample in zip(dl_cached, dl):
+        assert np.allclose(sample_cached.pos, sample.pos), "Sample has incorrect coords"
+        assert np.allclose(sample_cached.z, sample.z), "Sample has incorrect atom numbers"
+        assert np.allclose(sample_cached.y, sample.y), "Sample has incorrect energy"
+        assert np.allclose(sample_cached.neg_dy, sample.neg_dy), "Sample has incorrect forces"
\ No newline at end of file
diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py
index c647b7d2a..179b637e4 100644
--- a/torchmdnet/datasets/hdf.py
+++ b/torchmdnet/datasets/hdf.py
@@ -125,8 +125,13 @@ def get(self, idx):
         if self.index is None:
             self._setup_index()
         *fields_data, i = self.index[idx]
+        # Assume the first element of fields_data is 'pos', per the definition of self.fields
+        size = len(fields_data[0])
         for (name, _, dtype), d in zip(self.fields, fields_data):
-            tensor_input = [[d[i]]] if d.ndim == 1 else d[i]
+            if d.ndim == 1:
+                tensor_input = [d[i]] if len(d) == size else d[:]
+            else:
+                tensor_input = d[i]
             data[name] = torch.tensor(tensor_input, dtype=dtype)
         return data
 
diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py
index f97888556..1a1f7cd5a 100644
--- a/torchmdnet/utils.py
+++ b/torchmdnet/utils.py
@@ -346,12 +346,13 @@ class MissingEnergyException(Exception):
     pass
 
 
-def write_as_hdf5(files, hdf5_dataset):
+def write_as_hdf5(files, hdf5_dataset, tile_embed=True):
     """Transform the input numpy files to hdf5 format compatible with the HDF5 Dataset class.
     The input files to this function are the same as the ones required by the Custom dataset.
     Args:
         files (dict): Dictionary of numpy input files. Must contain "pos", "z" and at least one of "y" or "neg_dy".
         hdf5_dataset (string): Path to the output HDF5 dataset.
+        tile_embed (bool): Whether to tile the embeddings to match the number of samples. Default: True
     Example:
         >>> files = {}
         >>> files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
@@ -370,7 +371,10 @@ def write_as_hdf5(files, hdf5_dataset):
             group = f.create_group(str(i))
             num_samples = coord_data.shape[0]
             group.create_dataset("pos", data=coord_data)
-            group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1)))
+            if tile_embed:
+                group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1)))
+            else:
+                group.create_dataset("types", data=embed_data)
             if "y" in files:
                 energy_data = np.load(files["y"][i], mmap_mode="r")
                 group.create_dataset("energy", data=energy_data)
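
The snippet below is a minimal, hypothetical sketch (not part of the patch) of how the new tile_embed=False path is meant to be consumed, assuming only the write_as_hdf5 and HDF5(..., dataset_preload_limit=...) signatures shown in the hunks above; the file names, array shapes, and the "untiled.hdf5" output path are invented for illustration.

import tempfile
from os.path import join

import numpy as np

from torchmdnet.datasets import HDF5
from torchmdnet.utils import write_as_hdf5

with tempfile.TemporaryDirectory() as tmpdir:
    # Deliberately use n_samples != n_atoms so the len(d) == size branch in HDF5.get is unambiguous.
    n_samples, n_atoms = 4, 3
    np.save(join(tmpdir, "coords_0.npy"), np.random.rand(n_samples, n_atoms, 3).astype(np.float32))
    np.save(join(tmpdir, "embed_0.npy"), np.array([1, 6, 8], dtype=np.int64))  # a single row of atom types
    np.save(join(tmpdir, "energy_0.npy"), np.random.rand(n_samples, 1))
    np.save(join(tmpdir, "forces_0.npy"), np.random.rand(n_samples, n_atoms, 3).astype(np.float32))

    files = {
        "pos": [join(tmpdir, "coords_0.npy")],
        "z": [join(tmpdir, "embed_0.npy")],
        "y": [join(tmpdir, "energy_0.npy")],
        "neg_dy": [join(tmpdir, "forces_0.npy")],
    }

    # Store the atom types once per group instead of tiling them over every conformation.
    write_as_hdf5(files, join(tmpdir, "untiled.hdf5"), tile_embed=False)

    # With caching disabled, HDF5.get reads straight from the file; a 1D "types"
    # dataset whose length differs from the number of samples is returned whole,
    # so every conformation shares the same atom-type vector.
    dataset = HDF5(join(tmpdir, "untiled.hdf5"), dataset_preload_limit=0)
    sample = dataset.get(0)
    assert sample.pos.shape == (n_atoms, 3)
    assert sample.z.shape == (n_atoms,)

Storing a single types row shrinks the file when all conformations in a group share the same atoms, which is exactly the case the len(d) == size check in HDF5.get distinguishes from per-sample 1D fields such as energies.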