33import numpy as np
44import argparse
55
6- import oneflow as flow
7- from oneflow .utils .data import DataLoader
8-
9- from flowvision import datasets , transforms
10- from flowvision .data import create_transform
11-
12-
# Small benchmark dataset archive hosted on OneFlow's CI object storage.
ONEREC_URL = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/nanodataset.zip"
# Expected checksum of the downloaded archive — presumably the MD5 of the
# zip above; confirm against the download/verify helper that consumes it.
MD5 = "7f5cde8b5a6c411107517ac9b00f29db"
158
@@ -64,99 +57,84 @@ def ensure_dataset():
6457 shutil .unpack_archive (absolute_file_path )
6558 return str (pathlib .Path .cwd () / "nanodataset" )
6659
60+
# Number of timed batches fetched per benchmark run (excludes the warm-up).
swin_dataloader_loop_count = 200
6862
63+
def print_rank_0(*args, **kwargs):
    """Print only on the rank-0 process (RANK env var, default "0")."""
    if int(os.getenv("RANK", "0")) == 0:
        print(*args, **kwargs)
7368
7469
class SubsetRandomSampler(flow.utils.data.Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.
    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        # Draw a fresh random permutation of positions on every epoch/iteration.
        order = flow.randperm(len(self.indices)).tolist()
        for position in order:
            yield self.indices[position]

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        # Stored for API parity with distributed samplers; __iter__ does not
        # seed from it, so shuffling is independent of the epoch number.
        self.epoch = epoch
94-
def build_transform():
    """Build the swin training-time image transform.

    This should always dispatch to transforms_imagenet_train.
    """
    options = dict(
        input_size=224,
        is_training=True,
        color_jitter=0.4,
        auto_augment="rand-m9-mstd0.5-inc1",
        re_prob=0.25,
        re_mode="pixel",
        re_count=1,
        interpolation="bicubic",
    )
    return create_transform(**options)
108-
109-
# swin-transformer imagenet dataloader
def build_dataset(imagenet_path):
    """Build the ImageNet training ImageFolder dataset under <imagenet_path>/train."""
    root = os.path.join(imagenet_path, "train")
    return datasets.ImageFolder(root, transform=build_transform())
117-
118-
def build_loader(imagenet_path, batch_size, num_wokers):
    """Build the training dataset and a rank-sharded, randomly sampled DataLoader."""
    dataset_train = build_dataset(imagenet_path=imagenet_path)

    # Shard the dataset across ranks: each rank owns a strided slice
    # starting at its own rank index, stepping by the world size.
    rank = flow.env.get_rank()
    world_size = flow.env.get_world_size()
    indices = np.arange(rank, len(dataset_train), world_size)

    data_loader_train = DataLoader(
        dataset_train,
        sampler=SubsetRandomSampler(indices),
        batch_size=batch_size,
        num_workers=num_wokers,
        drop_last=True,
    )

    return dataset_train, data_loader_train
136-
137-
def run(mode, imagenet_path, batch_size, num_wokers):
    """Benchmark the swin-transformer ImageNet training dataloader.

    Selects the PyTorch/timm stack when ``mode == "torch"``, otherwise the
    OneFlow/flowvision stack, builds the training transform, dataset and
    DataLoader, performs a 5-batch warm-up, then times
    ``swin_dataloader_loop_count`` batch fetches.

    Args:
        mode: "torch" for PyTorch + timm, anything else for OneFlow + flowvision.
        imagenet_path: root directory containing a ``train`` subfolder.
        batch_size: batch size passed to the DataLoader.
        num_wokers: DataLoader worker count.  (Name kept, typo and all,
            for backward compatibility with keyword callers.)

    Returns:
        Total wall-clock seconds spent fetching the timed batches.
    """
    if mode == "torch":
        import torch as flow
        from torch.utils.data import DataLoader

        from torchvision import datasets, transforms
        from timm.data import create_transform
    else:
        import oneflow as flow
        from oneflow.utils.data import DataLoader

        from flowvision import datasets, transforms
        from flowvision.data import create_transform

    def build_transform():
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=224,
            is_training=True,
            color_jitter=0.4,
            auto_augment="rand-m9-mstd0.5-inc1",
            re_prob=0.25,
            re_mode="pixel",
            re_count=1,
            interpolation="bicubic",
        )

    # swin-transformer imagenet dataloader
    def build_dataset(imagenet_path):
        root = os.path.join(imagenet_path, "train")
        return datasets.ImageFolder(root, transform=build_transform())

    def build_loader(imagenet_path, batch_size, num_wokers):
        # Shuffled, non-sharded loader; workers and drop_last match the
        # original swin benchmark configuration.
        dataset_train = build_dataset(imagenet_path=imagenet_path)
        data_loader_train = DataLoader(
            dataset_train,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_wokers,
            drop_last=True,
        )
        return dataset_train, data_loader_train

    def get_time():
        # Fix: use run()'s own arguments rather than the module-level
        # ``args`` namespace, so callers' parameters are actually honored.
        dataset_train, data_loader_train = build_loader(
            imagenet_path, batch_size, num_wokers
        )
        data_loader_train_iter = iter(data_loader_train)

        # Warm up so worker start-up cost is excluded from the measurement.
        for _ in range(5):
            samples, targets = next(data_loader_train_iter)

        start_time = time.time()
        for _ in range(swin_dataloader_loop_count):
            samples, targets = next(data_loader_train_iter)
        return time.time() - start_time

    return get_time()
160138
161139
162140if __name__ == "__main__" :
@@ -173,11 +151,21 @@ def run(mode, imagenet_path, batch_size, num_wokers):
173151 pytorch_data_loader_total_time = run (
174152 "torch" , args .imagenet_path , args .batch_size , args .num_workers
175153 )
176- oneflow_data_loader_time = oneflow_data_loader_total_time / swin_dataloader_loop_count
177- pytorch_data_loader_time = pytorch_data_loader_total_time / swin_dataloader_loop_count
154+ oneflow_data_loader_time = (
155+ oneflow_data_loader_total_time / swin_dataloader_loop_count
156+ )
157+ pytorch_data_loader_time = (
158+ pytorch_data_loader_total_time / swin_dataloader_loop_count
159+ )
178160
179161 relative_speed = oneflow_data_loader_time / pytorch_data_loader_time
180162
181- print_rank_0 (f"OneFlow swin dataloader time: { oneflow_data_loader_time :.3f} s (= { oneflow_data_loader_total_time :.3f} s / { swin_dataloader_loop_count } , num_workers={ args .num_workers } )" )
182- print_rank_0 (f"PyTorch swin dataloader time: { pytorch_data_loader_time :.3f} s (= { pytorch_data_loader_total_time :.3f} s / { swin_dataloader_loop_count } , num_workers={ args .num_workers } )" )
183- print_rank_0 (f"Relative speed: { pytorch_data_loader_time / oneflow_data_loader_time :.3f} (= { pytorch_data_loader_time :.3f} s / { oneflow_data_loader_time :.3f} s)" )
163+ print_rank_0 (
164+ f"OneFlow swin dataloader time: { oneflow_data_loader_time :.3f} s (= { oneflow_data_loader_total_time :.3f} s / { swin_dataloader_loop_count } , num_workers={ args .num_workers } )"
165+ )
166+ print_rank_0 (
167+ f"PyTorch swin dataloader time: { pytorch_data_loader_time :.3f} s (= { pytorch_data_loader_total_time :.3f} s / { swin_dataloader_loop_count } , num_workers={ args .num_workers } )"
168+ )
169+ print_rank_0 (
170+ f"Relative speed: { pytorch_data_loader_time / oneflow_data_loader_time :.3f} (= { pytorch_data_loader_time :.3f} s / { oneflow_data_loader_time :.3f} s)"
171+ )
0 commit comments