scripts/swin_dataloader_compare_speed_with_pytorch.py: 188 changes (98 additions, 90 deletions)
@@ -3,13 +3,6 @@
import numpy as np
import argparse

import oneflow as flow
from oneflow.utils.data import DataLoader

from flowvision import datasets, transforms
from flowvision.data import create_transform


ONEREC_URL = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/nanodataset.zip"
MD5 = "7f5cde8b5a6c411107517ac9b00f29db"

@@ -64,99 +57,104 @@ def ensure_dataset():
shutil.unpack_archive(absolute_file_path)
return str(pathlib.Path.cwd() / "nanodataset")


swin_dataloader_loop_count = 200


def print_rank_0(*args, **kwargs):
rank = int(os.getenv("RANK", "0"))
if rank == 0:
print(*args, **kwargs)


class SubsetRandomSampler(flow.utils.data.Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
"""

def __init__(self, indices):
self.epoch = 0
self.indices = indices

def __iter__(self):
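        # yield the stored indices in a fresh random order on each pass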
return (self.indices[i] for i in flow.randperm(len(self.indices)).tolist())

def __len__(self):
return len(self.indices)

def set_epoch(self, epoch):
self.epoch = epoch


def build_transform():
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=224,
is_training=True,
color_jitter=0.4,
auto_augment="rand-m9-mstd0.5-inc1",
re_prob=0.25,
re_mode="pixel",
re_count=1,
interpolation="bicubic",
)
return transform


# Swin Transformer ImageNet dataloader
def build_dataset(imagenet_path):
transform = build_transform()
prefix = "train"
root = os.path.join(imagenet_path, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
return dataset


def build_loader(imagenet_path, batch_size, num_workers):
dataset_train = build_dataset(imagenet_path=imagenet_path)

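    # shard the dataset across ranks: each rank takes every world_size-th index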
indices = np.arange(
flow.env.get_rank(), len(dataset_train), flow.env.get_world_size()
)
sampler_train = SubsetRandomSampler(indices)

data_loader_train = DataLoader(
dataset_train,
sampler=sampler_train,
batch_size=batch_size,
        num_workers=num_workers,
drop_last=True,
)

return dataset_train, data_loader_train


def run(mode, imagenet_path, batch_size, num_workers):
if mode == "torch":
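        # alias torch as flow so the shared timing code below runs unchanged on both frameworks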
import torch as flow
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from timm.data import create_transform

        dataset_train, data_loader_train = build_loader(
            imagenet_path, batch_size, num_workers
        )
data_loader_train_iter = iter(data_loader_train)

        # warm up: exclude worker startup and prefetch cost from the timed loop
        for idx in range(5):
            samples, targets = next(data_loader_train_iter)

start_time = time.time()
for idx in range(swin_dataloader_loop_count):
            samples, targets = next(data_loader_train_iter)
total_time = time.time() - start_time
return total_time
else:
import oneflow as flow
from oneflow.utils.data import DataLoader

from flowvision import datasets, transforms
from flowvision.data import create_transform

class SubsetRandomSampler(flow.utils.data.Sampler):
Review comment (Contributor): Can this class be removed? Specifying shuffle=True when creating the DataLoader should be enough.

Reply (Author): Sounds good.

r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
"""

def __init__(self, indices):
self.epoch = 0
self.indices = indices

def __iter__(self):
return (self.indices[i] for i in flow.randperm(len(self.indices)).tolist())

def __len__(self):
return len(self.indices)

def set_epoch(self, epoch):
self.epoch = epoch
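
        # Per the review comment above, a minimal sketch of the suggested
        # simplification: drop the custom sampler and let the DataLoader
        # shuffle (assumes plain single-process random sampling is all that
        # is needed here):
        #
        #     data_loader_train = DataLoader(
        #         dataset_train,
        #         batch_size=batch_size,
        #         shuffle=True,
        #         num_workers=num_workers,
        #         drop_last=True,
        #     )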

def build_transform():
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=224,
is_training=True,
color_jitter=0.4,
auto_augment="rand-m9-mstd0.5-inc1",
re_prob=0.25,
re_mode="pixel",
re_count=1,
interpolation="bicubic",
)
return transform

        # Swin Transformer ImageNet dataloader
def build_dataset(imagenet_path):
transform = build_transform()
prefix = "train"
root = os.path.join(imagenet_path, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
return dataset

        def build_loader(imagenet_path, batch_size, num_workers):
dataset_train = build_dataset(imagenet_path=imagenet_path)

indices = np.arange(0, len(dataset_train), 1)
sampler_train = SubsetRandomSampler(indices)

data_loader_train = DataLoader(
dataset_train,
sampler=sampler_train,
batch_size=batch_size,
                num_workers=num_workers,
drop_last=True,
)

return dataset_train, data_loader_train

def get_time():
            dataset_train, data_loader_train = build_loader(
                imagenet_path, batch_size, num_workers
            )
data_loader_train_iter = iter(data_loader_train)

            # warm up: exclude worker startup and prefetch cost from the timed loop
            for idx in range(5):
                samples, targets = next(data_loader_train_iter)

start_time = time.time()
for idx in range(swin_dataloader_loop_count):
                samples, targets = next(data_loader_train_iter)
total_time = time.time() - start_time
return total_time

return get_time()


if __name__ == "__main__":
@@ -173,11 +171,21 @@ def run(mode, imagenet_path, batch_size, num_workers):
pytorch_data_loader_total_time = run(
"torch", args.imagenet_path, args.batch_size, args.num_workers
)
oneflow_data_loader_time = oneflow_data_loader_total_time / swin_dataloader_loop_count
pytorch_data_loader_time = pytorch_data_loader_total_time / swin_dataloader_loop_count
oneflow_data_loader_time = (
oneflow_data_loader_total_time / swin_dataloader_loop_count
)
pytorch_data_loader_time = (
pytorch_data_loader_total_time / swin_dataloader_loop_count
)

relative_speed = oneflow_data_loader_time / pytorch_data_loader_time

print_rank_0(f"OneFlow swin dataloader time: {oneflow_data_loader_time:.3f}s (= {oneflow_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})")
print_rank_0(f"PyTorch swin dataloader time: {pytorch_data_loader_time:.3f}s (= {pytorch_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})")
print_rank_0(f"Relative speed: {pytorch_data_loader_time / oneflow_data_loader_time:.3f} (= {pytorch_data_loader_time:.3f}s / {oneflow_data_loader_time:.3f}s)")
print_rank_0(
f"OneFlow swin dataloader time: {oneflow_data_loader_time:.3f}s (= {oneflow_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})"
)
print_rank_0(
f"PyTorch swin dataloader time: {pytorch_data_loader_time:.3f}s (= {pytorch_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})"
)
print_rank_0(
f"Relative speed: {pytorch_data_loader_time / oneflow_data_loader_time:.3f} (= {pytorch_data_loader_time:.3f}s / {oneflow_data_loader_time:.3f}s)"
)
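
To run the comparison (a usage sketch; the flag names are assumed from the
args.* attributes used in the script, and the ./nanodataset path from
ensure_dataset()):

    python3 scripts/swin_dataloader_compare_speed_with_pytorch.py \
        --imagenet_path ./nanodataset \
        --batch_size 32 \
        --num_workers 4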