@@ -33,7 +33,10 @@ def __init__(
3333 transform = None ,
3434 pre_transform = None ,
3535 pre_filter = None ,
36+ source_file = "mdcath_source.h5" ,
37+ file_basename = "mdcath_dataset" ,
3638 numAtoms = 5000 ,
39+ numNoHAtoms = None ,
3740 numResidues = 1000 ,
3841 temperatures = ["348" ],
3942 skip_frames = 1 ,
@@ -52,16 +55,21 @@ def __init__(
5255 Root directory where the dataset should be stored. Data will be downloaded to 'root/'.
5356 numAtoms: int
5457 Max number of atoms in the protein structure.
58+ source_file: str
59+ Name of the source file with the information about the protein structures. Default is "mdcath_source.h5".
60+ file_basename: str
61+ Base name of the hdf5 files. Default is "mdcath_dataset".
5562 numNoHAtoms: int
56- Max number of non-hydrogen atoms in the protein structure.
63+ Max number of non-hydrogen atoms in the protein structure, not available for original mdcath dataset. Default is None.
64+ Be sure to have the attribute 'numNoHAtoms' in the source file.
5765 numResidues: int
5866 Max number of residues in the protein structure.
5967 temperatures: list
6068 List of temperatures (in Kelvin) to download. Default is ["348"]. Available temperatures are ['320', '348', '379', '413', '450']
6169 skip_frames: int
6270 Number of frames to skip in the trajectory. Default is 1.
6371 pdb_list: list or str
64- List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'mdcath_source.h5 ' will be downloaded.
72+ List of PDB IDs to download or path to a file with the PDB IDs. If None, all available PDB IDs from 'source_file ' will be downloaded.
6573 The filters will be applied to the PDB IDs in this list in any case. Default is None.
6674 min_gyration_radius: float
6775 Minimum gyration radius (in nm) of the protein structure. Default is None.
@@ -76,7 +84,9 @@ def __init__(
7684 """
7785
7886 self .url = "https://huggingface.co/datasets/compsciencelab/mdCATH/resolve/main/"
79- self .source_file = "mdcath_source.h5"
87+ self .source_file = source_file
88+ self .file_basename = file_basename
89+ self .numNoHAtoms = numNoHAtoms
8090 self .root = root
8191 os .makedirs (root , exist_ok = True )
8292 self .numAtoms = numAtoms
@@ -103,33 +113,35 @@ def __init__(
103113
104114 @property
105115 def raw_file_names (self ):
106- return [f"mdcath_dataset_ { pdb_id } .h5" for pdb_id in self .processed .keys ()]
116+ return [f"{ self . file_basename } _ { pdb_id } .h5" for pdb_id in self .processed .keys ()]
107117
108118 @property
109119 def raw_dir (self ):
110120 # Override the raw_dir property to return the root directory
111- # The files will be downloaded to the root directory
121+ # The files will be downloaded to the root directory, compatible only with original mdcath dataset
112122 return self .root
113123
114124 def _ensure_source_file (self ):
115125 """Ensure the source file is downloaded before processing."""
116126 source_path = os .path .join (self .root , self .source_file )
117127 if not os .path .exists (source_path ):
128+ assert self .source_file == "mdcath_source.h5" , "Only 'mdcath_source.h5' is supported as source file for download."
118129 logger .info (f"Downloading source file { self .source_file } " )
119130 urllib .request .urlretrieve (opj (self .url , self .source_file ), source_path )
120131
121132 def download (self ):
122133 for pdb_id in self .processed .keys ():
123- file_name = f"mdcath_dataset_ { pdb_id } .h5"
134+ file_name = f"{ self . file_basename } _ { pdb_id } .h5"
124135 file_path = opj (self .raw_dir , file_name )
125136 if not os .path .exists (file_path ):
137+ assert self .file_basename == "mdcath_dataset" , "Only 'mdcath_dataset' is supported as file_basename for download."
126138 # Download the file if it does not exist
127139 urllib .request .urlretrieve (opj (self .url , 'data' , file_name ), file_path )
128140
129141 def calculate_dataset_size (self ):
130142 total_size_bytes = 0
131143 for pdb_id in self .processed .keys ():
132- file_name = f"mdcath_dataset_ { pdb_id } .h5"
144+ file_name = f"{ self . file_basename } _ { pdb_id } .h5"
133145 total_size_bytes += os .path .getsize (opj (self .root , file_name ))
134146 total_size_mb = round (total_size_bytes / (1024 * 1024 ), 4 )
135147 return total_size_mb
@@ -161,7 +173,8 @@ def _evaluate_replica(self, pdb_id, temp, replica, pdb_group):
161173 self .numFrames is not None and pdb_group [temp ][replica ].attrs ["numFrames" ] < self .numFrames ,
162174 self .min_gyration_radius is not None and pdb_group [temp ][replica ].attrs ["min_gyration_radius" ] < self .min_gyration_radius ,
163175 self .max_gyration_radius is not None and pdb_group [temp ][replica ].attrs ["max_gyration_radius" ] > self .max_gyration_radius ,
164- self ._evaluate_structure (pdb_group , temp , replica )
176+ self ._evaluate_structure (pdb_group , temp , replica ),
177+ self .numNoHAtoms is not None and pdb_group .attrs ["numNoHAtoms" ] > self .numNoHAtoms ,
165178 ]
166179 if any (conditions ):
167180 return
@@ -180,7 +193,7 @@ def len(self):
180193 return self .num_conformers
181194
182195 def _setup_idx (self ):
183- files = [opj (self .root , f"mdcath_dataset_ { pdb_id } .h5" ) for pdb_id in self .processed .keys ()]
196+ files = [opj (self .root , f"{ self . file_basename } _ { pdb_id } .h5" ) for pdb_id in self .processed .keys ()]
184197 self .idx = []
185198 for i , (pdb , group_info ) in enumerate (self .processed .items ()):
186199 for temp , replica , num_frames in group_info :
0 commit comments