This repository contains code for implementing and training a View-Decoupled Transformer (VDT) model for Person Re-Identification (ReID) on the multi-view CARGO dataset. It addresses the challenge of matching individuals across aerial and ground camera perspectives.
Person ReID in multi-view scenarios (like drone surveillance) is challenging due to drastic viewpoint changes. This project leverages the VDT architecture, which adapts a Vision Transformer (ViT) to explicitly disentangle view-invariant identity features (using a meta token) from view-specific nuisance features (using a view token) through hierarchical subtractive separation within transformer blocks.
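As a rough illustration of the subtractive separation described above, the toy sketch below subtracts the view token from the meta token after every block. It is not the code in model_vdt.py: the exact place where the subtraction happens is an assumption, and `toy_block` is a hypothetical stand-in for a real transformer block (which would apply self-attention and an MLP over all tokens).

```python
from typing import Callable, List, Tuple

Vector = List[float]
Block = Callable[[Vector, Vector], Tuple[Vector, Vector]]


def subtract(a: Vector, b: Vector) -> Vector:
    """Element-wise a - b."""
    return [x - y for x, y in zip(a, b)]


def toy_block(meta: Vector, view: Vector) -> Tuple[Vector, Vector]:
    # Hypothetical stand-in: returns the tokens unchanged.
    # A real block would update both tokens through attention and an MLP.
    return meta, view


def hierarchical_separation(meta: Vector, view: Vector,
                            blocks: List[Block]) -> Tuple[Vector, Vector]:
    """Run the blocks in order, removing the view token from the meta token each time."""
    for block in blocks:
        meta, view = block(meta, view)
        # Assumed separation step: meta <- meta - view
        meta = subtract(meta, view)
    return meta, view


meta_out, view_out = hierarchical_separation(
    [1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [toy_block] * 2
)
print(meta_out, view_out)
```

Because the subtraction is repeated at every block, the view-specific component is removed progressively across depth rather than once at the output, which is what "hierarchical" refers to here.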
- Implementation of the View-Decoupled Transformer (VDTModel in model_vdt.py).
- Training script (train.py) with support for:
  - Combined ID (Cross-Entropy) and Triplet loss.
  - Optional Orthogonal loss to enforce feature decoupling (--use_orthogonal_loss).
  - Optional Random Rectangle Occlusion augmentation (--use_random_rect_occlusion).
  - Cosine Annealing learning rate scheduler with warmup.
  - Checkpointing and resuming training.
- Evaluation scripts:
  - evaluate_current.py: Evaluates a model trained on CARGO, reports all standard protocols (G<->G, A<->A, A<->G, G->A, ALL), and saves features for visualization.
  - evaluate_external.py: Evaluates a model trained on a different dataset (e.g., Market-1501) on CARGO, handling potential positional embedding mismatches.
- Visualization script (plot_features.py): Generates t-SNE plots of learned meta and view features to analyze feature decoupling and clustering.
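To make the optional orthogonal loss listed above more concrete, here is a minimal sketch in plain Python. It assumes the loss is the mean absolute cosine similarity between each sample's meta and view features, and that it is added to the ID and Triplet losses with a scalar weight; the actual formulation in losses.py may differ, and the function names below are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-12) -> float:
    """Cosine of the angle between two vectors (eps guards against zero norms)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def orthogonal_loss(meta_feats: List[Sequence[float]],
                    view_feats: List[Sequence[float]]) -> float:
    """Assumed form: mean |cos(meta_i, view_i)| over the batch; 0 when fully orthogonal."""
    sims = [abs(cosine_similarity(m, v)) for m, v in zip(meta_feats, view_feats)]
    return sum(sims) / len(sims)


def total_loss(id_loss: float, triplet_loss: float, ortho_loss: float,
               ortho_weight: float = 0.5, use_orthogonal: bool = True) -> float:
    """Assumed combination: ID + Triplet, plus the weighted orthogonal term if enabled."""
    total = id_loss + triplet_loss
    if use_orthogonal:
        total += ortho_weight * ortho_loss
    return total


print(orthogonal_loss([[1.0, 0.0]], [[0.0, 1.0]]))  # orthogonal pair
print(orthogonal_loss([[1.0, 0.0]], [[2.0, 0.0]]))  # parallel pair
```

The loss reaches its minimum when meta and view features point in orthogonal directions, which is the sense in which it "enforces decoupling".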
For an in-depth explanation of the model architecture, loss functions, training/evaluation flow, and code structure, see the Technical Explanation document (Technical_Explanation.md).
- Prerequisites:
  - Python 3.8+
  - PyTorch >= 1.10
  - An environment manager like conda or venv is recommended.
- Clone Repository:
    git clone <your-repo-url>
    cd <your-repo-name>
- Create Environment (Recommended):
    # Using conda
    conda create -n vdt_reid python=3.10
    conda activate vdt_reid
    # Or using venv
    python -m venv venv
    source venv/bin/activate  # Linux/macOS
    # venv\Scripts\activate   # Windows
- Install Dependencies:
    pip install -r requirements.txt
- Dataset Preparation (CARGO):
  - Download or prepare the CARGO dataset.
  - Ensure it follows the standard ReID structure:
      <data_path>/                  # e.g., /home/user/datasets/CARGO
      ├── train/
      │   ├── 0001/                 # Person ID 1
      │   │   ├── 0001_c1_f001.jpg
      │   │   └── ...
      │   └── 0002/
      │       └── ...
      ├── query/
      │   ├── 0001_c2_f002.jpg
      │   └── ...
      └── gallery/
          ├── 0001_c3_f003.jpg
          └── ...
  - Note camera IDs: In CARGO, cameras 1-5 are typically Aerial, and 6-13 are Ground.
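The file names and the camera-ID note above suggest how identity, camera, and view labels could be recovered from a path. The sketch below is an assumption based only on the example names (such as 0001_c1_f001.jpg) and the "typically" stated camera ranges; `parse_filename` and `Sample` are hypothetical names, and datasets.py may parse files differently.

```python
import re
from typing import NamedTuple


class Sample(NamedTuple):
    pid: int    # person ID
    camid: int  # camera ID
    view: str   # "aerial" or "ground"


# Assumed pattern: <pid>_c<camid>_<anything>, e.g. 0001_c1_f001.jpg
_PATTERN = re.compile(r"^(\d+)_c(\d+)_")


def parse_filename(name: str) -> Sample:
    """Extract person ID, camera ID, and view type from an image file name."""
    match = _PATTERN.match(name)
    if match is None:
        raise ValueError(f"unexpected file name: {name!r}")
    pid, camid = int(match.group(1)), int(match.group(2))
    if 1 <= camid <= 5:
        view = "aerial"   # cameras 1-5 are typically aerial
    elif 6 <= camid <= 13:
        view = "ground"   # cameras 6-13 are typically ground
    else:
        raise ValueError(f"camera ID out of the expected 1-13 range: {camid}")
    return Sample(pid, camid, view)


print(parse_filename("0001_c1_f001.jpg"))
```

The view label derived here is what would let an evaluation split queries and gallery images into aerial and ground subsets.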
Use train.py to train the VDT model.
Example: Start Training
python train.py \
--data_path /path/to/your/CARGO_dataset \
--img_size 224 224 \
--epochs 50 \
--lr 3e-5 \
--warmup_epochs 10 \
--scheduler cosine \
--batch_size 32 \
--use_random_rect_occlusion \
--use_orthogonal_loss \
--loss_ortho_weight 0.5 \
--output_dir vdt_checkpoints

- --data_path: Required. Path to the CARGO dataset root.
- --output_dir: Directory to save checkpoints and logs.
- --use_orthogonal_loss: Flag to enable the orthogonal loss.
- --loss_ortho_weight: Weight for the orthogonal loss (if enabled).
- Other arguments control image size, batch size, learning rate, epochs, etc. (see config.py or run python train.py -h).
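The --lr, --warmup_epochs, --scheduler cosine, and --epochs values in the example above together define a learning-rate schedule. The sketch below shows one common shape for such a schedule, assuming linear warmup to the base rate followed by cosine decay to a minimum of 0; the warmup form, the minimum rate, and the per-epoch (rather than per-step) update are assumptions, and train.py may behave differently.

```python
import math


def lr_at_epoch(epoch: int, base_lr: float = 3e-5, warmup_epochs: int = 10,
                total_epochs: int = 50, min_lr: float = 0.0) -> float:
    """Learning rate for a 0-based epoch index under warmup + cosine annealing."""
    if epoch < warmup_epochs:
        # Assumed linear warmup: reaches base_lr at the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup phase that has elapsed, in [0, 1).
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(e) for e in range(50)]
for e in (0, 9, 10, 30, 49):
    print(e, schedule[e])
```

With these defaults the rate rises for the first 10 epochs, peaks at 3e-5, and then decays smoothly over the remaining 40 epochs.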
Example: Resume Training
python train.py \
--data_path /path/to/your/CARGO_dataset \
--output_dir vdt_checkpoints \
--epochs 70 \
--resume vdt_checkpoints/checkpoint.pt
# Add other args matching the original run if needed

- --resume: Path to the checkpoint.pt file to load model, optimizer, and scheduler state.
- --epochs: Set the total desired epochs (training will run from loaded_epoch + 1 up to epochs - 1).
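The epoch arithmetic in the --epochs note can be checked with a small sketch. It assumes the checkpoint stores the 0-based index of the last completed epoch as loaded_epoch; `remaining_epochs` is a hypothetical helper, not a function from train.py, and the value 49 assumes the original 50-epoch run finished.

```python
def remaining_epochs(loaded_epoch: int, total_epochs: int) -> range:
    """0-based epoch indices still to run: loaded_epoch + 1 up to total_epochs - 1."""
    return range(loaded_epoch + 1, total_epochs)


# A finished 50-epoch run (last completed index 49) resumed with --epochs 70.
epochs_to_run = list(remaining_epochs(loaded_epoch=49, total_epochs=70))
print(epochs_to_run[0], epochs_to_run[-1], len(epochs_to_run))
```

Under this convention, passing an --epochs value that does not exceed loaded_epoch + 1 yields an empty range, so no further training would take place.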
Example: Fine-tuning (Resume with Lower LR)
python train.py \
--data_path /path/to/your/CARGO_dataset \
--output_dir vdt_checkpoints \
--epochs 70 \
--lr 1e-5 \
--resume vdt_checkpoints/checkpoint.pt
# Add other args matching the original run if needed

- Resumes from the checkpoint but overrides the learning rate to a lower value (e.g., 1e-5) for fine-tuning.
Two scripts are provided for evaluation:
1. Evaluate Current CARGO Model (evaluate_current.py)
Evaluates the latest checkpoint from a CARGO training run and reports metrics for all protocols (AA, GG, AG, GA, ALL). It also saves features for visualization.
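For readers unfamiliar with the metrics usually reported in ReID evaluation, the sketch below computes Rank-1 accuracy and mean Average Precision (mAP) from a query-by-gallery distance matrix. It is a simplified illustration, not the compute_metrics code in engine.py: it assumes these are the reported metrics, and it omits the same-camera filtering that standard ReID protocols often apply. Each protocol (for example aerial queries against ground gallery images) would correspond to calling it on the matching subset.

```python
from typing import List, Sequence, Tuple


def rank1_and_map(dist: Sequence[Sequence[float]],
                  query_pids: Sequence[int],
                  gallery_pids: Sequence[int]) -> Tuple[float, float]:
    """dist[i][j] is the distance between query i and gallery image j."""
    hits = 0
    average_precisions: List[float] = []
    for q, row in enumerate(dist):
        # Gallery indices sorted from most to least similar.
        order = sorted(range(len(row)), key=row.__getitem__)
        matches = [gallery_pids[j] == query_pids[q] for j in order]
        if not any(matches):
            continue  # query identity absent from the gallery: skip it
        hits += matches[0]
        found, precisions = 0, []
        for rank, is_match in enumerate(matches, start=1):
            if is_match:
                found += 1
                precisions.append(found / rank)
        average_precisions.append(sum(precisions) / len(precisions))
    valid = len(average_precisions)
    if valid == 0:
        raise ValueError("no query has a matching gallery identity")
    return hits / valid, sum(average_precisions) / valid


rank1, mean_ap = rank1_and_map(
    dist=[[0.1, 0.9, 0.5], [0.8, 0.2, 0.3]],
    query_pids=[1, 1],
    gallery_pids=[1, 2, 1],
)
print(rank1, mean_ap)
```

The script itself is invoked as follows: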
python Showcase_Code/evaluate_current.py \
--data_path /path/to/your/CARGO_dataset \
--img_size 224 224 \
--model_arch vit_base_patch16_224 \
--batch_size 32 \
--device cuda \
--resume vdt_checkpoints/checkpoint.pt # Path to your CARGO checkpoint

2. Evaluate External Model (evaluate_external.py)
Evaluates a checkpoint trained on a different dataset (e.g., Market-1501) on the CARGO dataset. Handles potential positional embedding size mismatches.
python Showcase_Code/evaluate_external.py \
--data_path /path/to/your/CARGO_dataset \
--checkpoint_path /path/to/external/model/checkpoint.pt \
--ckpt_num_classes 751 \
--img_size 224 224 \
--model_arch vit_base_patch16_224 \
--batch_size 32 \
--device cuda

- --checkpoint_path: Required. Path to the external checkpoint.
- --ckpt_num_classes: Required. Number of ID classes the external checkpoint was trained on (e.g., 751 for Market-1501).
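A positional embedding mismatch arises when the external checkpoint was trained with a different patch grid than the one used for evaluation. One common remedy is to interpolate the grid of patch position embeddings to the new size. The sketch below shows bilinear resizing for a single embedding channel in plain Python; it is an assumption about the general technique, not the method used in evaluate_external.py, which may use a different interpolation mode or a library routine. `resize_grid` is a hypothetical name.

```python
from typing import List


def resize_grid(grid: List[List[float]], new_h: int, new_w: int) -> List[List[float]]:
    """Bilinearly resize a 2-D grid of floats, keeping the corner values fixed."""
    old_h, old_w = len(grid), len(grid[0])

    def source_coord(i: int, new: int, old: int) -> float:
        # Map an output index to a fractional input coordinate.
        return 0.0 if new == 1 else i * (old - 1) / (new - 1)

    out = []
    for i in range(new_h):
        y = source_coord(i, new_h, old_h)
        y0 = int(y)
        y1 = min(y0 + 1, old_h - 1)
        wy = y - y0
        row = []
        for j in range(new_w):
            x = source_coord(j, new_w, old_w)
            x0 = int(x)
            x1 = min(x0 + 1, old_w - 1)
            wx = x - x0
            top = grid[y0][x0] * (1 - wx) + grid[y0][x1] * wx
            bottom = grid[y1][x0] * (1 - wx) + grid[y1][x1] * wx
            row.append(top * (1 - wy) + bottom * wy)
        out.append(row)
    return out


resized = resize_grid([[0.0, 2.0], [4.0, 6.0]], 3, 3)
print(resized)
```

In a real model this resizing would be applied to every embedding channel of the patch grid, while extra tokens such as a class token would be carried over unchanged.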
Use plot_features.py to generate t-SNE plots of the learned meta and view features saved by evaluate_current.py.
Prerequisites:
pip install matplotlib scikit-learn

Run Plotting:
# Color by view token type (Meta vs. View)
python Showcase_Code/plot_features.py \
--features_path vdt_checkpoints/features_for_plot.pt \
--output_plot_path vdt_checkpoints/feature_plot_view_colored.png \
--color_by view
# Color by person ID
python Showcase_Code/plot_features.py \
--features_path vdt_checkpoints/features_for_plot.pt \
--output_plot_path vdt_checkpoints/feature_plot_id_colored.png \
--color_by id

Plots will be saved to the specified output path (defaulting to the vdt_checkpoints directory).
.
├── Showcase_Code/
│   ├── config.py              # Argument parsing setup
│   ├── datasets.py            # CARGO dataset loading and transforms
│   ├── engine.py              # Training loop (train_one_epoch) & evaluation logic (evaluate, compute_metrics)
│   ├── losses.py              # Loss function implementations (CE, Triplet, Orthogonal)
│   ├── model_vdt.py           # VDTModel and VDTBlock definitions
│   ├── train.py               # Main training script
│   ├── evaluate_current.py    # Script to evaluate CARGO checkpoints & save features
│   ├── evaluate_external.py   # Script to evaluate external checkpoints on CARGO
│   └── plot_features.py       # Script for t-SNE visualization
├── requirements.txt           # Python dependencies
├── README.md                  # This file
└── Technical_Explanation.md   # Detailed technical documentation
... (other files like poster.png, vdt_checkpoints/)