
Conversation

@AntonioMirarchi
Contributor

This PR ensures consistency in the HDF5 dataset class. Specifically, when self.cached=True, the _preload_data(self) function correctly handles the case where a 1D embed array is shared across all samples. When self.cached=False, the dataset instead builds tensors on the fly, assuming:

tensor_input = [[d[i]]] if d.ndim == 1 else d[i]

However, this assumption caused errors when embed was shared across samples, since it incorrectly indexed a single element of the shared 1D embed array instead of returning the per-node embeddings for the whole sample.

This PR fixes the issue and adds a pytest case that verifies consistency with and without caching.
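
For reference, a minimal sketch of the behaviour the fix aims for in the non-cached path. The names here (get_embed, embed_dset, sample_idx) are illustrative, not the actual code in the HDF5 class:

```python
import numpy as np
import torch

def get_embed(embed_dset, sample_idx):
    # Hypothetical helper, not the actual HDF5 dataset code.
    d = np.asarray(embed_dset)
    if d.ndim == 1:
        # Shared case: one 1D array of atom types used by every sample,
        # so return it whole instead of indexing element `sample_idx`.
        return torch.tensor(d, dtype=torch.long)
    # Per-sample case: one row of atom types per sample.
    return torch.tensor(d[sample_idx], dtype=torch.long)
```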

@AntonioMirarchi
Contributor Author

@stefdoerr @sef43 can you review this?

Only one consideration regarding the need to obtain a tensor with torch.Size([1, 1]) if d.ndim == 1: previously, [[d[i]]] was used, whereas now, to maintain consistency with the _preload_data() function, I return a tensor of torch.Size([1]) using [d[i]]. I think it's fine because module.py takes care of it here.
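
For reference, the shape difference mentioned above, with toy values:

```python
import torch

d = [1.5, 2.0, 3.1]  # toy stand-in for a 1D per-sample scalar field
i = 0

old = torch.tensor([[d[i]]])  # previous behaviour: shape torch.Size([1, 1])
new = torch.tensor([d[i]])    # current behaviour:  shape torch.Size([1])
```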

@stefdoerr
Collaborator

Can't we change the dataloader to tile the atom types instead of duplicating information in the file (and wasting space on disk)?
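
For illustration, one way such tiling could look; this is a sketch of the suggestion, not the actual dataloader code, and the names are hypothetical:

```python
import numpy as np

# Store the atom types once ...
types = np.array([1, 6, 8, 1])          # one entry per atom, written to disk once
n_samples = 100

# ... and expand them to a per-sample view at load time instead of
# duplicating the same row n_samples times inside the HDF5 file.
tiled = np.tile(types, (n_samples, 1))  # shape (n_samples, n_atoms), built in memory
```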

@AntonioMirarchi
Contributor Author

If you are referring to the write_as_hdf5() function, it is only used in tests/test_datasets.py.

@stefdoerr merged commit d616c8a into torchmd:main Feb 6, 2025
6 checks passed
@AntonioMirarchi deleted the hdf5 branch February 6, 2025 14:19