1010import psutil
1111from torchmdnet .datasets import Custom , HDF5 , Ace
1212from torchmdnet .utils import write_as_hdf5
13+ from torch_geometric .loader import DataLoader
1314import h5py
1415import glob
1516
@@ -297,3 +298,36 @@ def test_ace(tmpdir):
297298 assert len (dataset_v2 ) == 8
298299 f2 .flush ()
299300 f2 .close ()
301+
302+
@mark.parametrize("num_files", [1, 3])
@mark.parametrize("tile_embed", [True, False])
@mark.parametrize("batch_size", [1, 5])
def test_hdf5_with_and_without_caching(num_files, tile_embed, batch_size, tmpdir):
    """Check that HDF5 batches match with and without dataset preloading.

    Sample ``.npy`` files are written into ``tmpdir`` and converted to a
    single HDF5 file. That file is then opened twice, once with
    ``dataset_preload_limit=0`` (intended as "no caching") and once with
    ``dataset_preload_limit=256`` (intended as "caching"). Both datasets are
    iterated through a ``DataLoader`` and every batch is compared field by
    field.

    Args:
        num_files: Forwarded to ``write_sample_npy_files`` as its last argument.
        tile_embed: Forwarded to ``write_as_hdf5`` as its third argument.
        batch_size: Batch size used for both data loaders.
        tmpdir: Temporary directory in which all files are created.
    """
    # Create the raw input files, then collect them per dataset field.
    write_sample_npy_files(True, True, tmpdir, num_files)
    patterns = {
        "pos": "coords*",
        "z": "embed*",
        "y": "energy*",
        "neg_dy": "forces*",
    }
    files = {
        field: sorted(glob.glob(join(tmpdir, pattern)))
        for field, pattern in patterns.items()
    }

    # Convert to HDF5 and make sure the file actually landed on disk.
    hdf5_path = join(tmpdir, "test.hdf5")
    write_as_hdf5(files, hdf5_path, tile_embed)
    assert os.path.isfile(hdf5_path), "HDF5 file was not created"

    # Same file, opened with preloading disabled (limit 0) and enabled (limit 256).
    uncached = HDF5(hdf5_path, dataset_preload_limit=0)
    cached = HDF5(hdf5_path, dataset_preload_limit=256)
    assert len(uncached) == len(cached), "Number of samples does not match"

    uncached_loader = DataLoader(uncached, batch_size)
    cached_loader = DataLoader(cached, batch_size)

    # (attribute on the batch, failure message) pairs, checked in this order.
    field_checks = (
        ("pos", "Sample has incorrect coords"),
        ("z", "Sample has incorrect atom numbers"),
        ("y", "Sample has incorrect energy"),
        ("neg_dy", "Sample has incorrect forces"),
    )
    for cached_batch, uncached_batch in zip(cached_loader, uncached_loader):
        for attr, failure_message in field_checks:
            assert np.allclose(
                getattr(cached_batch, attr), getattr(uncached_batch, attr)
            ), failure_message
0 commit comments