Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ dependencies = [
"torch_geometric",
"lightning",
"tqdm",
"pandas",
]

[project.urls]
Expand Down
12 changes: 2 additions & 10 deletions torchmdnet/datasets/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchmdnet.datasets.memdataset import MemmappedDataset
from torch_geometric.data import Data
from tqdm import tqdm
import pandas as pd


class Ace(MemmappedDataset):
Expand Down Expand Up @@ -133,7 +132,6 @@ def __init__(
paths=None,
max_gradient=None,
subsample_molecules=1,
index_csv=None,
):
assert isinstance(paths, (str, list))

Expand All @@ -143,13 +141,8 @@ def __init__(
self.paths = paths
self.max_gradient = max_gradient
self.subsample_molecules = int(subsample_molecules)
if index_csv is not None:
df = pd.read_csv(index_csv, dtype=int, converters={"name": str})
self.mol_indexes = {mol_id: i for i, mol_id in enumerate(df.name)}

props = ["y", "neg_dy", "q", "pq", "dp"]
if index_csv is not None:
props += ["mol_idx"]
super().__init__(
root,
transform,
Expand Down Expand Up @@ -239,7 +232,7 @@ def _load_confs_2_0(mol, n_atoms):
def sample_iter(self, mol_ids=False):
assert self.subsample_molecules > 0

for path in tqdm(self.raw_paths, desc="Files"):
for i_path, path in tqdm(enumerate(self.raw_paths), desc="Files"):
h5 = h5py.File(path)
assert h5.attrs["layout"] == "Ace"
version = h5.attrs["layout_version"]
Expand Down Expand Up @@ -285,10 +278,9 @@ def sample_iter(self, mol_ids=False):
z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
)
if mol_ids:
args["i_path"] = i_path
args["mol_id"] = mol_id
args["i_conf"] = i_conf
if "mol_idx" in self.properties:
args["mol_idx"] = self.mol_indexes[mol_id]

data = Data(**args)

Expand Down
18 changes: 0 additions & 18 deletions torchmdnet/datasets/memdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ def __init__(
self.mmaps["dp"] = np.memmap(
fnames["dp"], mode="r", dtype=np.float32, shape=(num_all_confs, 3)
)
if "mol_idx" in self.properties:
self.mmaps["mol_idx"] = np.memmap(
fnames["mol_idx"], mode="r", dtype=np.uint64
)

assert self.mmaps["idx"][0] == 0
assert self.mmaps["idx"][-1] == len(self.mmaps["z"])
Expand Down Expand Up @@ -178,13 +174,6 @@ def process(self):
dtype=np.float32,
shape=(num_all_confs, 3),
)
if "mol_idx" in self.properties:
mmaps["mol_idx"] = np.memmap(
fnames["mol_idx"] + ".tmp",
mode="w+",
dtype=np.uint64,
shape=(num_all_confs,),
)

print("Storing data...")
i_atom = 0
Expand All @@ -204,8 +193,6 @@ def process(self):
mmaps["pq"][i_atom:i_next_atom] = data.pq
if "dp" in self.properties:
mmaps["dp"][i_conf] = data.dp
if "mol_idx" in self.properties:
mmaps["mol_idx"][i_conf] = data.mol_idx
i_atom = i_next_atom

mmaps["idx"][-1] = num_all_atoms
Expand All @@ -231,8 +218,6 @@ def process(self):
os.rename(fnames["pq"] + ".tmp", fnames["pq"])
if "dp" in self.properties:
os.rename(fnames["dp"] + ".tmp", fnames["dp"])
if "mol_idx" in self.properties:
os.rename(fnames["mol_idx"] + ".tmp", fnames["mol_idx"])

def len(self):
return len(self.mmaps["idx"]) - 1
Expand All @@ -249,7 +234,6 @@ def get(self, idx):
- :obj:`q`: Total charge of the molecule.
- :obj:`pq`: Partial charges of the atoms.
- :obj:`dp`: Dipole moment of the molecule.
- :obj:`mol_idx`: The index of the molecule of the conformer.

Args:
idx (int): Index of the data object.
Expand All @@ -272,8 +256,6 @@ def get(self, idx):
props["pq"] = pt.tensor(self.mmaps["pq"][atoms])
if "dp" in self.properties:
props["dp"] = pt.tensor(self.mmaps["dp"][idx])
# if "mol_idx" in self.properties:
# props["mol_idx"] = pt.tensor(self.mmaps["mol_idx"][idx], dtype=pt.int64).view(1, 1)
return Data(z=z, pos=pos, **props)

def __del__(self):
Expand Down
Loading