This repo implements the SimCLR algorithm on Vision Transformers (ViT) for both GPUs and TPUs, with hyperparameters following "An Empirical Study of Training Self-Supervised Vision Transformers".
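For orientation, SimCLR trains the encoder with a contrastive loss (NT-Xent) computed over two augmented views of each image. The pure-Python sketch below only illustrates that loss on small lists of embeddings. It is not this repo's implementation. The function names and the default temperature of 0.2 are assumptions chosen for illustration.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _normalize(v: Vector) -> List[float]:
    # L2-normalize so that dot products become cosine similarities
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent_loss(z1: List[Vector], z2: List[Vector], temperature: float = 0.2) -> float:
    """Illustrative NT-Xent loss (hypothetical helper, not this repo's code).

    z1[i] and z2[i] are embeddings of two augmented views of the same image i.
    Every other embedding in the combined 2N batch acts as a negative.
    The default temperature is an assumption for illustration.
    """
    assert len(z1) == len(z2) and len(z1) > 0
    n = len(z1)
    z = [_normalize(v) for v in list(z1) + list(z2)]  # 2N unit vectors
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of the same image
        # temperature-scaled similarities to every sample except itself
        logits = {
            j: sum(a * b for a, b in zip(z[i], z[j])) / temperature
            for j in range(2 * n)
            if j != i
        }
        # numerically stable log-sum-exp over the 2N - 1 candidates
        m = max(logits.values())
        log_denom = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total += log_denom - logits[pos]  # -log softmax probability of the positive
    return total / (2 * n)
```

When the two views of each image agree, the positive pair dominates the softmax and the loss is close to zero. Swapping the views so that each sample matches the wrong image makes the loss large.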
Install PyTorch and its dependencies, plus PyTorch/XLA if running on TPUs.
Finally, install timm for vision transformers: pip3 install timm.
Download ImageNet-1k to a shared directory that can be accessed from all nodes (e.g. /checkpoint/ronghanghu/megavlt_paths/imagenet-1k). The directory should have the following structure:
/checkpoint/ronghanghu/megavlt_paths/imagenet-1k
|_ train
|  |_ <n0......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-N-name>.JPEG
|  |_ ...
|  |_ <n1......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-M-name>.JPEG
|  |  |_...
|  |  |_...
|_ val
|  |_ <n0......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-N-name>.JPEG
|  |_ ...
|  |_ <n1......>
|  |  |_<im-1-name>.JPEG
|  |  |_...
|  |  |_<im-M-name>.JPEG
|  |  |_...
|  |  |_...
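Before launching a multi-node job, it can save time to confirm that the dataset matches the layout above. The stdlib sketch below counts the .JPEG files in each class directory of train and val. The helper name count_images is hypothetical and not part of this repo.

```python
import os
from typing import Dict, Tuple


def count_images(
    data_dir: str,
    splits: Tuple[str, ...] = ("train", "val"),
    ext: str = ".JPEG",
) -> Dict[str, Dict[str, int]]:
    """Hypothetical helper: return {split: {class_dir: num_images}}.

    Raises FileNotFoundError if a split directory is missing.
    """
    summary: Dict[str, Dict[str, int]] = {}
    for split in splits:
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        per_class: Dict[str, int] = {}
        for cls in sorted(os.listdir(split_dir)):
            cls_dir = os.path.join(split_dir, cls)
            if os.path.isdir(cls_dir):
                # count only files with the expected image extension
                per_class[cls] = sum(1 for f in os.listdir(cls_dir) if f.endswith(ext))
        summary[split] = per_class
    return summary
```

Calling count_images("/checkpoint/ronghanghu/megavlt_paths/imagenet-1k") on a complete copy of ImageNet-1k should report 1000 class directories in each split. A class with zero images usually points to an incomplete download or extraction.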
Launch the training on GPUs or TPUs as follows.
Make sure SAVE_DIR is a shared directory that can be accessed from all nodes. For TPUs, one can use an NFS directory on GCP.
On GPUs (e.g. using 64 V100 GPUs):
SAVE_DIR="/private/home/ronghanghu/workspace/simclr_vit_release/save_gpu64"
srun \
  --mem=300g --nodes=8 --gres=gpu:8 --partition=learnlab,learnfair \
  --time=4300 --constraint=volta32gb --cpus-per-task=40 \
python3 run_simclr_vit.py \
  world_size=64 \
  ckpt_dir=$SAVE_DIR \
  data_dir=/checkpoint/ronghanghu/megavlt_paths/imagenet-1k
(Append use_pytorch_amp=True to the command above to enable automatic mixed precision.)
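The world_size=64 argument matches the srun allocation of 8 nodes with 8 GPUs each. The sketch below spells out that arithmetic and one common way of assigning global ranks (node-major ordering). The helper names are hypothetical. How run_simclr_vit.py actually assigns ranks is not documented here.

```python
from typing import List, Tuple


def world_size(num_nodes: int, gpus_per_node: int) -> int:
    # total number of training processes, one per GPU
    return num_nodes * gpus_per_node


def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    # Hypothetical node-major ordering: node 0 owns ranks 0..gpus_per_node-1,
    # node 1 owns the next block, and so on.
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError("local_rank must be in [0, gpus_per_node)")
    return node_rank * gpus_per_node + local_rank


def rank_table(num_nodes: int, gpus_per_node: int) -> List[Tuple[int, int, int]]:
    # (node_rank, local_rank, global_rank) for every process in the job
    return [
        (n, g, global_rank(n, g, gpus_per_node))
        for n in range(num_nodes)
        for g in range(gpus_per_node)
    ]
```

world_size(8, 8) gives 64, the value passed on the command line. With a different allocation (e.g. --nodes=4), the world_size argument presumably needs to change to match.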
On TPUs (e.g. using a v3-256 TPU pod):
SAVE_DIR="/checkpoint/ronghanghu/workspace/simclr_vit_release/save_tpu_v3-256"
TPU_NAME=megavlt-256  # change to your TPU name
# use absolute paths with torch_xla.distributed.xla_dist
sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR  # workaround for permission issue
python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod \
  --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 \
  -- \
python3 $(realpath run_simclr_vit.py) \
  device=xla \
  ckpt_dir=$SAVE_DIR \
  data_dir=/checkpoint/ronghanghu/megavlt_paths/imagenet-1k
Suppose the final checkpoint from the previous step is /checkpoint/ronghanghu/workspace/simclr_vit_release/save_tpu_v3-256/simclr_vit_epoch_300.ckpt. Evaluate it as follows. The expected linear evaluation accuracy is around 0.739 on both GPUs and TPUs.
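The checkpoints in SAVE_DIR appear to follow the pattern simclr_vit_epoch_<N>.ckpt, based on the example path above. The sketch below picks the checkpoint with the highest epoch number so that PRETRAINED_MODEL does not have to be found by hand. The helper name and the assumption that every checkpoint uses this naming pattern are hypothetical.

```python
import os
import re
from typing import Optional

# Assumed naming pattern, inferred from simclr_vit_epoch_300.ckpt
_CKPT_RE = re.compile(r"^simclr_vit_epoch_(\d+)\.ckpt$")


def latest_checkpoint(save_dir: str) -> Optional[str]:
    """Hypothetical helper: path of the checkpoint with the largest epoch, or None."""
    best_epoch, best_name = -1, None
    for name in os.listdir(save_dir):
        m = _CKPT_RE.match(name)
        # compare epochs numerically so that epoch_300 beats epoch_99
        if m and int(m.group(1)) > best_epoch:
            best_epoch, best_name = int(m.group(1)), name
    return os.path.join(save_dir, best_name) if best_name is not None else None
```

The epochs are parsed as integers on purpose. A plain lexicographic sort of the file names would rank simclr_vit_epoch_99.ckpt above simclr_vit_epoch_300.ckpt.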
Make sure SAVE_DIR is a shared directory that can be accessed from all nodes. For TPUs, one can use an NFS directory on GCP.
On GPUs (e.g. using 64 V100 GPUs):
PRETRAINED_MODEL=/private/home/ronghanghu/workspace/simclr_vit_release/save_gpu64/simclr_vit_epoch_300.ckpt
# SAVE_DIR can be the same or a different directory from SSL training
SAVE_DIR="/private/home/ronghanghu/workspace/simclr_vit_release/save_gpu64"
srun \
  --mem=300g --nodes=8 --gres=gpu:8 --partition=learnlab,learnfair \
  --time=4300 --constraint=volta32gb --cpus-per-task=40 \
python3 $(realpath run_linear_eval_vit.py) \
  world_size=64 \
  ckpt_dir=$SAVE_DIR \
  data_dir=/checkpoint/ronghanghu/megavlt_paths/imagenet-1k \
  linear_eval.pretrained_ckpt_path=$PRETRAINED_MODEL
On TPUs (e.g. using a v3-256 TPU pod):
PRETRAINED_MODEL=/checkpoint/ronghanghu/workspace/simclr_vit_release/save_tpu_v3-256/simclr_vit_epoch_300.ckpt
# SAVE_DIR can be the same or a different directory from SSL training
SAVE_DIR="/checkpoint/ronghanghu/workspace/simclr_vit_release/save_tpu_v3-256"
TPU_NAME=megavlt-256  # change to your TPU name
# use absolute paths with torch_xla.distributed.xla_dist
sudo mkdir -p $SAVE_DIR && sudo chmod -R 777 $SAVE_DIR  # workaround for permission issue
python3 -m torch_xla.distributed.xla_dist \
  --tpu=${TPU_NAME} --restart-tpuvm-pod \
  --env LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 \
  -- \
python3 $(realpath run_linear_eval_vit.py) \
  device=xla \
  ckpt_dir=$SAVE_DIR \
  data_dir=/checkpoint/ronghanghu/megavlt_paths/imagenet-1k \
  linear_eval.pretrained_ckpt_path=$PRETRAINED_MODEL
Following the PyTorch/XLA performance profiling guide, on a TPU VM node one can first start a TensorBoard session with tensorboard --logdir . and then launch one of the training scripts below. Once training has run for a while (e.g. after 100 steps, when the speed has stabilized), capture the profile from localhost:3294 in the Profile tab of TensorBoard.
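Capturing a profile only works once the training script's profiler server is accepting connections on localhost:3294. The stdlib sketch below polls that port before you start the capture in TensorBoard. The function name, the timeout values and the polling approach are assumptions for illustration and are not part of PyTorch/XLA.

```python
import socket
import time


def wait_for_port(
    host: str = "localhost",
    port: int = 3294,
    timeout_s: float = 60.0,
    interval_s: float = 1.0,
) -> bool:
    """Hypothetical helper: True once a TCP connect to host:port succeeds, False after timeout_s."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            # a successful connect means some process is listening on the port
            with socket.create_connection((host, port), timeout=interval_s):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval_s)
```

A True result only shows that something is listening on the port. It does not check that the listener is the profiler server.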
Run profiling with fake data (no actual data loading) on a single VM node with 8 TPU cores:
export PT_XLA_DEBUG=1
export XLA_HLO_DEBUG=1
python3 run_simclr_vit_profiler.py \
  device=xla \
  fake_data=True \
  batch_size=128 lr=0.0  # zero lr to avoid divergence
Run profiling with real data on a single VM node with 8 TPU cores:
export PT_XLA_DEBUG=1
export XLA_HLO_DEBUG=1
python3 run_simclr_vit_profiler.py \
  device=xla \
  data_dir=/checkpoint/ronghanghu/megavlt_paths/imagenet-1k \
  batch_size=128 lr=0.0  # zero lr to avoid divergence
Run profiling with fake data, but loaded through the PyTorch DataLoader, on a single VM node with 8 TPU cores:
export PT_XLA_DEBUG=1
export XLA_HLO_DEBUG=1
python3 run_simclr_vit_profiler_fakewithdataloader.py \
  device=xla \
  fake_data=True \
  batch_size=128 lr=0.0  # zero lr to avoid divergence