34 changes: 34 additions & 0 deletions tests/test_datasets.py
@@ -10,6 +10,7 @@
 import psutil
 from torchmdnet.datasets import Custom, HDF5, Ace
 from torchmdnet.utils import write_as_hdf5
+from torch_geometric.loader import DataLoader
 import h5py
 import glob

@@ -297,3 +298,36 @@ def test_ace(tmpdir):
     assert len(dataset_v2) == 8
     f2.flush()
     f2.close()
+
+
+@mark.parametrize("num_files", [1, 3])
+@mark.parametrize("tile_embed", [True, False])
+@mark.parametrize("batch_size", [1, 5])
+def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir):
+    """Ensure that HDF5.get() returns the same samples whether the dataset is
+    loaded with caching or without it."""
+
+    # Set up the necessary input files
+    _ = write_sample_npy_files(True, True, tmpdir, num_files)
+    files = {}
+    files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
+    files["z"] = sorted(glob.glob(join(tmpdir, "embed*")))
+    files["y"] = sorted(glob.glob(join(tmpdir, "energy*")))
+    files["neg_dy"] = sorted(glob.glob(join(tmpdir, "forces*")))
+
+    write_as_hdf5(files, join(tmpdir, "test.hdf5"), tile_embed)
+    # Assert that the file was written to disk
+    assert os.path.isfile(join(tmpdir, "test.hdf5")), "HDF5 file was not created"
+
+    data = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=0)  # no caching
+    data_cached = HDF5(join(tmpdir, "test.hdf5"), dataset_preload_limit=256)  # caching
+    assert len(data) == len(data_cached), "Number of samples does not match"
+
+    dl = DataLoader(data, batch_size)
+    dl_cached = DataLoader(data_cached, batch_size)
+
+    for sample_cached, sample in zip(dl_cached, dl):
+        assert np.allclose(sample_cached.pos, sample.pos), "Sample has incorrect coords"
+        assert np.allclose(sample_cached.z, sample.z), "Sample has incorrect atom numbers"
+        assert np.allclose(sample_cached.y, sample.y), "Sample has incorrect energy"
+        assert np.allclose(sample_cached.neg_dy, sample.neg_dy), "Sample has incorrect forces"
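Note on what `tile_embed` toggles on disk: with tiling, the `types` dataset holds one row of atom types per sample; without it, a single shared row per group. A minimal sketch of the two layouts using plain h5py (the names `types_tiled`/`types_flat` and all shapes are illustrative for a side-by-side comparison; the real writer always names the dataset `types`):

```python
import h5py
import numpy as np

num_samples, num_atoms = 5, 3
types = np.array([1, 6, 8])  # illustrative atomic numbers

with h5py.File("layout_demo.hdf5", "w") as f:
    g = f.create_group("0")
    g.create_dataset("pos", data=np.random.randn(num_samples, num_atoms, 3))
    # tile_embed=True: one row per sample -> shape (num_samples, num_atoms)
    g.create_dataset("types_tiled", data=np.tile(types, (num_samples, 1)))
    # tile_embed=False: one shared row -> shape (num_atoms,)
    g.create_dataset("types_flat", data=types)

with h5py.File("layout_demo.hdf5", "r") as f:
    print(f["0/types_tiled"].shape)  # (5, 3)
    print(f["0/types_flat"].shape)   # (3,)
```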
7 changes: 6 additions & 1 deletion torchmdnet/datasets/hdf.py
@@ -125,8 +125,13 @@ def get(self, idx):
         if self.index is None:
             self._setup_index()
         *fields_data, i = self.index[idx]
+        # Assuming the first element of fields_data is 'pos' based on the definition of self.fields
+        size = len(fields_data[0])
         for (name, _, dtype), d in zip(self.fields, fields_data):
-            tensor_input = [[d[i]]] if d.ndim == 1 else d[i]
+            if d.ndim == 1:
+                tensor_input = [d[i]] if len(d) == size else d[:]
+            else:
+                tensor_input = d[i]
             data[name] = torch.tensor(tensor_input, dtype=dtype)
         return data

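Why the `len(d) == size` check is needed: within a group, `pos` has one entry per sample, so a 1-D dataset whose length equals `len(pos)` is per-sample data (e.g. energies) and must be indexed, while a 1-D dataset of any other length is a non-tiled `types` array shared by all samples and is returned whole. A minimal sketch of that dispatch, with illustrative shapes:

```python
import numpy as np

num_samples, num_atoms = 5, 3  # illustrative sizes
pos = np.random.randn(num_samples, num_atoms, 3)
energy = np.random.randn(num_samples)  # 1-D, len == num_samples
types = np.array([1, 6, 8])            # 1-D, len == num_atoms (not tiled)

size = len(pos)  # number of samples in the group
i = 0            # sample index
for name, d in [("energy", energy), ("types", types)]:
    if d.ndim == 1:
        # Per-sample scalars are indexed; shared arrays are taken whole.
        value = [d[i]] if len(d) == size else d[:]
    else:
        value = d[i]
    print(name, np.asarray(value).shape)  # energy (1,), types (3,)
```

One consequence of deciding by length alone: a non-tiled group that happened to contain exactly as many samples as atoms would be misread as per-sample data, so the two cases are only distinguishable when those counts differ.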
8 changes: 6 additions & 2 deletions torchmdnet/utils.py
@@ -346,12 +346,13 @@ class MissingEnergyException(Exception):
     pass


-def write_as_hdf5(files, hdf5_dataset):
+def write_as_hdf5(files, hdf5_dataset, tile_embed=True):
     """Transform the input numpy files to hdf5 format compatible with the HDF5 Dataset class.
     The input files to this function are the same as the ones required by the Custom dataset.
     Args:
         files (dict): Dictionary of numpy input files. Must contain "pos", "z" and at least one of "y" or "neg_dy".
         hdf5_dataset (string): Path to the output HDF5 dataset.
+        tile_embed (bool): Whether to tile the embeddings to match the number of samples. Default: True
     Example:
         >>> files = {}
         >>> files["pos"] = sorted(glob.glob(join(tmpdir, "coords*")))
@@ -370,7 +371,10 @@ def write_as_hdf5(files, hdf5_dataset):
             group = f.create_group(str(i))
             num_samples = coord_data.shape[0]
             group.create_dataset("pos", data=coord_data)
-            group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1)))
+            if tile_embed:
+                group.create_dataset("types", data=np.tile(embed_data, (num_samples, 1)))
+            else:
+                group.create_dataset("types", data=embed_data)
             if "y" in files:
                 energy_data = np.load(files["y"][i], mmap_mode="r")
                 group.create_dataset("energy", data=energy_data)