zhulinchng/generative-data-augmentation

Generative Data Augmentation

This repository contains the code for the experiments in the 'Improving Fine-Grained Image Classification Using Diffusion-Based Generated Synthetic Images' dissertation.

Try the HuggingFace 🤗 demo

Generative Data Augmentation | Generative Augmented Classifiers

Data Augmentation with Image-to-Image Stable Diffusion using Prompt Interpolation

The code for generating images using the Stable Diffusion model with prompt interpolation is in the imageGen.py script.
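
As a rough illustration of what prompt interpolation means, the sketch below linearly blends two prompt embedding vectors. It is an assumption about the general technique, not a copy of imageGen.py: the helper name `interpolate_embeddings`, the toy 3-dimensional vectors, and the choice of plain linear interpolation are all hypothetical, and the actual script may interpolate differently.

```python
from typing import List, Sequence


def interpolate_embeddings(
    emb_a: Sequence[float], emb_b: Sequence[float], steps: int
) -> List[List[float]]:
    """Linearly blend emb_a into emb_b over evenly spaced weights (hypothetical helper)."""
    if len(emb_a) != len(emb_b):
        raise ValueError("embeddings must have the same length")
    if steps < 2:
        raise ValueError("steps must be at least 2 to include both endpoints")
    blended = []
    for i in range(steps):
        t = i / (steps - 1)  # 0.0 at the first prompt, 1.0 at the second
        blended.append([(1 - t) * a + t * b for a, b in zip(emb_a, emb_b)])
    return blended


# Toy stand-ins for text-encoder outputs; real embeddings are far larger.
emb_first = [1.0, 0.0, 2.0]
emb_second = [0.0, 1.0, 4.0]
path = interpolate_embeddings(emb_first, emb_second, steps=3)
print(path)
```

Each intermediate vector could then condition one image-to-image generation, so the outputs drift gradually from the first prompt toward the second.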

Overview of Synthesis Process

To replicate the results, use the following seed values:

  1. Imagenette: 18316237598377439927
  2. Imagewoof: 4796730343513556238
  3. Stanford Dogs: 1127962904372660145

For Imagenette, set prompt_format to "A photo of a {class}". For Imagewoof and Stanford Dogs, set prompt_format to "A photo of a {class}, a type of dog".

Leave the rest of the parameters as default.
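
The settings above can be collected in one place as in the sketch below. The dictionary names and the `build_prompt` helper are hypothetical and are not taken from imageGen.py; only the seed values and prompt formats come from this README, and the class names are illustrative. Note that the Imagenette seed is larger than 2**63 - 1, so it fits only in an unsigned 64-bit integer.

```python
# Seeds and prompt formats copied from the README text.
SEEDS = {
    "Imagenette": 18316237598377439927,
    "Imagewoof": 4796730343513556238,
    "Stanford Dogs": 1127962904372660145,
}
PROMPT_FORMATS = {
    "Imagenette": "A photo of a {class}",
    "Imagewoof": "A photo of a {class}, a type of dog",
    "Stanford Dogs": "A photo of a {class}, a type of dog",
}


def build_prompt(dataset: str, class_name: str) -> str:
    """Fill the {class} placeholder for a dataset (hypothetical helper)."""
    # "class" is a Python keyword, so it is passed through a dict, not class=...
    return PROMPT_FORMATS[dataset].format(**{"class": class_name})


def fits_unsigned_64bit(seed: int) -> bool:
    """Check that a seed can be stored in an unsigned 64-bit integer."""
    return 0 <= seed < 2**64


# Class names below are examples only.
print(build_prompt("Imagenette", "tench"))
print(build_prompt("Imagewoof", "beagle"))
print(all(fits_unsigned_64bit(s) for s in SEEDS.values()))
```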

Run the imageGen_trace.py script independently to generate the trace files for the generated images.

Requirements

  • Both Linux and Windows are supported; however, Linux is recommended for performance and compatibility reasons.
  • 64-bit Python 3.11.* and PyTorch 2.1.2. See the PyTorch website for installation instructions.
  • The experiments were conducted on: AMD Ryzen 7 5700G, NVIDIA GeForce RTX 3080 12GB, 32GB RAM.
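
Before running the scripts, it may help to confirm that the interpreter matches the versions listed above. The following is a minimal sketch, not part of the repository; the function name `check_environment` is hypothetical, and it only reports mismatches rather than fixing them.

```python
import importlib.util
import sys
from typing import List


def check_environment() -> List[str]:
    """Return a list of mismatches against the stated requirements (hypothetical helper)."""
    problems = []
    if sys.version_info[:2] != (3, 11):
        found = f"{sys.version_info.major}.{sys.version_info.minor}"
        problems.append(f"Python 3.11 expected, found {found}")
    if sys.maxsize <= 2**32:
        problems.append("64-bit Python expected")
    if importlib.util.find_spec("torch") is None:
        problems.append("PyTorch is not installed")
    else:
        import torch

        # Local builds carry a suffix such as "+cu121", so compare the prefix only.
        if not str(torch.__version__).startswith("2.1.2"):
            problems.append(f"PyTorch 2.1.2 expected, found {torch.__version__}")
    return problems


for problem in check_environment():
    print("WARNING:", problem)
```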

Setup

git clone https://github.com/zhulinchng/generative-data-augmentation.git
cd generative-data-augmentation
# Set up the data directory structure as shown in the "Preparing the Dataset" section
conda create --name $env_name python=3.11.* # Replace $env_name with your environment name
conda activate $env_name
# Visit PyTorch website https://pytorch.org/get-started/previous-versions/#v212 for PyTorch installation instructions.
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url # Obtain the correct URL from the PyTorch website
pip install -r requirements.txt
# Configure the imageGen.py script with the target parameters.
torchrun imageGen.py

Files

  • Image Generation generates the images based on the input images from the original dataset and interpolated prompts.
    • Image Generation Trace generates the trace files for the generated images, for analysis and debugging purposes.
  • Classifier Training trains the classifiers on the training set of the desired dataset.
  • Analysis Notebook is a series of Jupyter notebooks for analyzing the generated images and classifiers.
  • Tools contains the helper scripts for the experiments.
    • analysis.py contains the analysis functions for the experiments.
    • classes.py contains the classes for the dataset and the classifiers.
    • data.py contains the data loading functions for the dataset.
    • synth.py contains the functions for the synthetic dataset generation.
    • transform.py contains the transformations for the dataset.
    • utils.py contains the utility scripts for training and evaluation of the classifiers.
  • Results contains the results of the experiments.

Datasets

Original Datasets

  1. Imagenette
  2. Imagewoof
  3. Stanford Dogs

Generated Datasets

  1. Imagenette
  2. Imagewoof
  3. Stanford Dogs

The generated datasets are located in the synthetic folder. The file lists for the noisy and clean versions can be found in the text files within the metadata folder.

Preparing the Dataset

Follow the directory structure below for the datasets:

data/
├── imagenette-original/
│   ├── train/
│   │   ├── n01440764/
│   │   │   ├── n01440764_1775.JPEG
│   │   │   └── ...
│   │   └── ...
│   └── val/
│       ├── n01440764/
│       │   ├── n01440764_1775.JPEG
│       │   └── ...
│       └── ...
├── imagenette-augmented-noisy/
│   ├── train/
│   │   ├── n01440764/
│   │   │   ├── n01440764_1775.JPEG
│   │   │   └── ...
│   │   └── ...
│   └── val/
│       ├── n01440764/
│       │   ├── n01440764_1775.JPEG
│       │   └── ...
│       └── ...
├── imagenette-augmented-clean/
│   ├── train/
│   │   ├── n01440764/
│   │   │   ├── n01440764_1775.JPEG
│   │   │   └── ...
│   │   └── ...
│   └── val/
│       ├── n01440764/
│       │   ├── n01440764_1775.JPEG
│       │   └── ...
│       └── ...
├── imagenette-synthetic/
│   ├── train/
│   │   ├── n01440764/
│   │   │   ├── n01440764_1775.JPEG
│   │   │   └── ...
│   │   └── ...
│   └── val/
│       ├── n01440764/
│       │   ├── n01440764_1775.JPEG
│       │   └── ...
│       └── ...
└── ...
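
A quick way to check that a dataset folder follows this layout is to count the files in each class folder under train/ and val/. The sketch below is not part of the repository; the helper name `count_images` is hypothetical, and the path `data/imagenette-original` is taken from the tree above.

```python
from pathlib import Path
from typing import Dict


def count_images(dataset_dir: Path) -> Dict[str, Dict[str, int]]:
    """Count files per class folder under train/ and val/ (hypothetical helper)."""
    counts: Dict[str, Dict[str, int]] = {}
    for split in ("train", "val"):
        split_dir = dataset_dir / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        counts[split] = {
            class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return counts


if __name__ == "__main__":
    root = Path("data/imagenette-original")
    if root.is_dir():
        for split, per_class in count_images(root).items():
            print(split, per_class)
```

A missing train/ or val/ folder raises an error immediately, which surfaces layout mistakes before a long generation or training run.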

Models used in the Experiments

  • Images for data augmentation are generated with the default runwayml/stable-diffusion-v1-5 model, which can be obtained from the HuggingFace model hub.
  • CLIP analysis is performed using the openai/clip-vit-large-patch14 model, which can be obtained from the HuggingFace model hub.

Trained Classifiers

The trained classifiers are available on the HuggingFace model hub: Generative Augmented Classifiers.

Training Overview

In the demo, you have the option to select the classifier to evaluate the generated images and download the model for further evaluation. Alternatively, navigate to the models folder to download the classifiers.

Results

CLIP Similarity Scores between 2 Classes across Datasets

KDE Plot

Imagewoof Similarity Histogram Original

Imagewoof Similarity Histogram Augmented

Similarity Overview

Classifier Performance

| Model | Methods | Imagenette Acc@1 | Imagewoof Acc@1 | Stanford Dogs Acc@1 | Stanford Dogs Acc@5 |
|---|---|---|---|---|---|
| ResNet-18 | Original | 89.91 | 83.43 | 58.12 | 84.48 |
| ResNet-18 | Synthetic | (-24.69) 65.22 | (-35.17) 48.26 | (-44.71) 13.41 | (-48.23) 36.25 |
| ResNet-18 | Original + Synthetic (Noisy) | (+1.55) 91.46 | (+1.93) 85.37 | (+2.56) 60.69 | (+2.34) 86.82 |
| ResNet-18 | Original + Synthetic (Clean) | (+1.73) 91.64 | (+3.18) 86.61 | - | - |
| MobileNetV2 | Original | 92.28 | 86.77 | 64.93 | 89.87 |
| MobileNetV2 | Synthetic | (-25.15) 67.13 | (-40.57) 46.19 | (-48.67) 16.26 | (-48.28) 41.60 |
| MobileNetV2 | Original + Synthetic (Noisy) | (-0.13) 92.15 | (+0.92) 87.68 | (+1.29) 66.22 | (+1.32) 91.19 |
| MobileNetV2 | Original + Synthetic (Clean) | (+0.61) 92.89 | (+1.15) 87.91 | - | - |
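
The values in parentheses appear to be the change in accuracy relative to the Original row for the same model and dataset. The sketch below reproduces that reading for the ResNet-18 Imagenette Acc@1 column; the helper name `delta_vs_baseline` is hypothetical, and the numbers are copied from the table. Some other cells differ from this calculation by 0.01, presumably because the reported deltas were computed from unrounded accuracies.

```python
# ResNet-18 Imagenette Acc@1 values copied from the results table.
baseline = 89.91  # Original
variants = {
    "Synthetic": 65.22,
    "Original + Synthetic (Noisy)": 91.46,
    "Original + Synthetic (Clean)": 91.64,
}


def delta_vs_baseline(accuracy: float, base: float) -> float:
    """Difference in percentage points from the baseline, rounded to two decimals."""
    return round(accuracy - base, 2)


deltas = {name: delta_vs_baseline(acc, baseline) for name, acc in variants.items()}
for name, delta in deltas.items():
    print(f"{name}: ({delta:+.2f}) {variants[name]:.2f}")
```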