1717
1818from __future__ import annotations
1919
20+ from collections .abc import Mapping
2021import os
2122
2223import numpy as np
@@ -44,13 +45,27 @@ def _info(self):
4445 homepage = 'https://github.com/deepmind/pg19' ,
4546 )
4647
def _get_paths(self, data_dir: str) -> Mapping[str, str]:
  """Builds the paths of the dataset files located under `data_dir`.

  Args:
    data_dir: Directory that holds the metadata file and the split
      sub-directories.

  Returns:
    A mapping with the keys 'metadata', 'train', 'validation' and 'test'.
    'metadata' maps to `<data_dir>/metadata.csv`; every other key maps to
    the sub-directory of `data_dir` that has the same name as the key.
  """
  # The three split directories are named exactly like their keys, so they
  # are derived from a single tuple instead of being spelled out one by one.
  paths = {'metadata': os.path.join(data_dir, 'metadata.csv')}
  for split in ('train', 'validation', 'test'):
    paths[split] = os.path.join(data_dir, split)
  return paths
55+
4756 def _split_generators (self , dl_manager ):
4857 """Returns SplitGenerators."""
4958 del dl_manager # Unused
5059
5160 metadata_dict = dict ()
52- metadata_path = os .path .join (_DATA_DIR , 'metadata.csv' )
53- metadata = tf .io .gfile .GFile (metadata_path ).read ().splitlines ()
61+ if self .data_dir and all (
62+ map (os .path .exists , self ._get_paths (self .data_dir ).values ())
63+ ):
64+ data_dir = self ._data_dir
65+ else :
66+ data_dir = _DATA_DIR
67+ paths = self ._get_paths (data_dir )
68+ metadata = tf .io .gfile .GFile (paths ['metadata' ]).read ().splitlines ()
5469
5570 for row in metadata :
5671 row_split = row .split (',' )
@@ -62,21 +77,21 @@ def _split_generators(self, dl_manager):
6277 name = tfds .Split .TRAIN ,
6378 gen_kwargs = {
6479 'metadata' : metadata_dict ,
65- 'filepath' : os . path . join ( _DATA_DIR , 'train' ) ,
80+ 'filepath' : paths [ 'train' ] ,
6681 },
6782 ),
6883 tfds .core .SplitGenerator (
6984 name = tfds .Split .VALIDATION ,
7085 gen_kwargs = {
7186 'metadata' : metadata_dict ,
72- 'filepath' : os . path . join ( _DATA_DIR , 'validation' ) ,
87+ 'filepath' : paths [ 'validation' ] ,
7388 },
7489 ),
7590 tfds .core .SplitGenerator (
7691 name = tfds .Split .TEST ,
7792 gen_kwargs = {
7893 'metadata' : metadata_dict ,
79- 'filepath' : os . path . join ( _DATA_DIR , 'test' ) ,
94+ 'filepath' : paths [ 'test' ] ,
8095 },
8196 ),
8297 ]
0 commit comments