Skip to content

Commit 38e47c5

Browse files
authored
Fix swin dataloader import bug (#334)
* fix import bug * refine * code format * fix comment
1 parent dbbeb0f commit 38e47c5

File tree

1 file changed

+78
-90
lines changed

1 file changed

+78
-90
lines changed

scripts/swin_dataloader_compare_speed_with_pytorch.py

Lines changed: 78 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,6 @@
33
import numpy as np
44
import argparse
55

6-
import oneflow as flow
7-
from oneflow.utils.data import DataLoader
8-
9-
from flowvision import datasets, transforms
10-
from flowvision.data import create_transform
11-
12-
136
ONEREC_URL = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/nanodataset.zip"
147
MD5 = "7f5cde8b5a6c411107517ac9b00f29db"
158

@@ -64,99 +57,84 @@ def ensure_dataset():
6457
shutil.unpack_archive(absolute_file_path)
6558
return str(pathlib.Path.cwd() / "nanodataset")
6659

60+
6761
swin_dataloader_loop_count = 200
6862

63+
6964
def print_rank_0(*args, **kwargs):
7065
rank = int(os.getenv("RANK", "0"))
7166
if rank == 0:
7267
print(*args, **kwargs)
7368

7469

75-
class SubsetRandomSampler(flow.utils.data.Sampler):
76-
r"""Samples elements randomly from a given list of indices, without replacement.
77-
Arguments:
78-
indices (sequence): a sequence of indices
79-
"""
80-
81-
def __init__(self, indices):
82-
self.epoch = 0
83-
self.indices = indices
84-
85-
def __iter__(self):
86-
return (self.indices[i] for i in flow.randperm(len(self.indices)).tolist())
87-
88-
def __len__(self):
89-
return len(self.indices)
90-
91-
def set_epoch(self, epoch):
92-
self.epoch = epoch
93-
94-
95-
def build_transform():
96-
# this should always dispatch to transforms_imagenet_train
97-
transform = create_transform(
98-
input_size=224,
99-
is_training=True,
100-
color_jitter=0.4,
101-
auto_augment="rand-m9-mstd0.5-inc1",
102-
re_prob=0.25,
103-
re_mode="pixel",
104-
re_count=1,
105-
interpolation="bicubic",
106-
)
107-
return transform
108-
109-
110-
# swin-transformer imagenet dataloader
111-
def build_dataset(imagenet_path):
112-
transform = build_transform()
113-
prefix = "train"
114-
root = os.path.join(imagenet_path, prefix)
115-
dataset = datasets.ImageFolder(root, transform=transform)
116-
return dataset
117-
118-
119-
def build_loader(imagenet_path, batch_size, num_wokers):
120-
dataset_train = build_dataset(imagenet_path=imagenet_path)
121-
122-
indices = np.arange(
123-
flow.env.get_rank(), len(dataset_train), flow.env.get_world_size()
124-
)
125-
sampler_train = SubsetRandomSampler(indices)
126-
127-
data_loader_train = DataLoader(
128-
dataset_train,
129-
sampler=sampler_train,
130-
batch_size=batch_size,
131-
num_workers=num_wokers,
132-
drop_last=True,
133-
)
134-
135-
return dataset_train, data_loader_train
136-
137-
13870
def run(mode, imagenet_path, batch_size, num_wokers):
13971
if mode == "torch":
14072
import torch as flow
14173
from torch.utils.data import DataLoader
14274

14375
from torchvision import datasets, transforms
14476
from timm.data import create_transform
145-
146-
dataset_train, data_loader_train = build_loader(
147-
args.imagenet_path, args.batch_size, args.num_workers
148-
)
149-
data_loader_train_iter = iter(data_loader_train)
150-
151-
# warm up
152-
for idx in range(5):
153-
samples, targets = data_loader_train_iter.__next__()
154-
155-
start_time = time.time()
156-
for idx in range(swin_dataloader_loop_count):
157-
samples, targets = data_loader_train_iter.__next__()
158-
total_time = time.time() - start_time
159-
return total_time
77+
else:
78+
import oneflow as flow
79+
from oneflow.utils.data import DataLoader
80+
81+
from flowvision import datasets, transforms
82+
from flowvision.data import create_transform
83+
84+
def build_transform():
85+
# this should always dispatch to transforms_imagenet_train
86+
transform = create_transform(
87+
input_size=224,
88+
is_training=True,
89+
color_jitter=0.4,
90+
auto_augment="rand-m9-mstd0.5-inc1",
91+
re_prob=0.25,
92+
re_mode="pixel",
93+
re_count=1,
94+
interpolation="bicubic",
95+
)
96+
return transform
97+
98+
# swin-transformer imagenet dataloader
99+
def build_dataset(imagenet_path):
100+
transform = build_transform()
101+
prefix = "train"
102+
root = os.path.join(imagenet_path, prefix)
103+
dataset = datasets.ImageFolder(root, transform=transform)
104+
return dataset
105+
106+
def build_loader(imagenet_path, batch_size, num_wokers):
107+
dataset_train = build_dataset(imagenet_path=imagenet_path)
108+
109+
indices = np.arange(0, len(dataset_train), 1)
110+
111+
data_loader_train = DataLoader(
112+
dataset_train,
113+
shuffle=True,
114+
batch_size=batch_size,
115+
num_workers=num_wokers,
116+
drop_last=True,
117+
)
118+
119+
return dataset_train, data_loader_train
120+
121+
def get_time():
122+
dataset_train, data_loader_train = build_loader(
123+
args.imagenet_path, args.batch_size, args.num_workers
124+
)
125+
data_loader_train_iter = iter(data_loader_train)
126+
127+
# warm up
128+
for idx in range(5):
129+
samples, targets = data_loader_train_iter.__next__()
130+
131+
start_time = time.time()
132+
for idx in range(swin_dataloader_loop_count):
133+
samples, targets = data_loader_train_iter.__next__()
134+
total_time = time.time() - start_time
135+
return total_time
136+
137+
return get_time()
160138

161139

162140
if __name__ == "__main__":
@@ -173,11 +151,21 @@ def run(mode, imagenet_path, batch_size, num_wokers):
173151
pytorch_data_loader_total_time = run(
174152
"torch", args.imagenet_path, args.batch_size, args.num_workers
175153
)
176-
oneflow_data_loader_time = oneflow_data_loader_total_time / swin_dataloader_loop_count
177-
pytorch_data_loader_time = pytorch_data_loader_total_time / swin_dataloader_loop_count
154+
oneflow_data_loader_time = (
155+
oneflow_data_loader_total_time / swin_dataloader_loop_count
156+
)
157+
pytorch_data_loader_time = (
158+
pytorch_data_loader_total_time / swin_dataloader_loop_count
159+
)
178160

179161
relative_speed = oneflow_data_loader_time / pytorch_data_loader_time
180162

181-
print_rank_0(f"OneFlow swin dataloader time: {oneflow_data_loader_time:.3f}s (= {oneflow_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})")
182-
print_rank_0(f"PyTorch swin dataloader time: {pytorch_data_loader_time:.3f}s (= {pytorch_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})")
183-
print_rank_0(f"Relative speed: {pytorch_data_loader_time / oneflow_data_loader_time:.3f} (= {pytorch_data_loader_time:.3f}s / {oneflow_data_loader_time:.3f}s)")
163+
print_rank_0(
164+
f"OneFlow swin dataloader time: {oneflow_data_loader_time:.3f}s (= {oneflow_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})"
165+
)
166+
print_rank_0(
167+
f"PyTorch swin dataloader time: {pytorch_data_loader_time:.3f}s (= {pytorch_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})"
168+
)
169+
print_rank_0(
170+
f"Relative speed: {pytorch_data_loader_time / oneflow_data_loader_time:.3f} (= {pytorch_data_loader_time:.3f}s / {oneflow_data_loader_time:.3f}s)"
171+
)

0 commit comments

Comments (0)