This project is a challenge to implement an object detection pipeline for the PASCAL VOC 2007 dataset using JAX and Equinox. You'll be adapting key components from the YOLOv10 model (originally in PyTorch) to a JAX-based framework.
This project uses JAX, a high-performance numerical computation library, and Equinox, a library for building neural networks in JAX.
Primary Goal: Train an object detection model on the VOC 2007 dataset by implementing a prediction head and loss function inspired by YOLOv10 within the provided JAX/Equinox boilerplate.
- Prerequisites
- Project Structure
- Understanding the VOC 2007 Dataset
- Your Task: Implementing Object Detection
- Crash Course: JAX and Equinox for PyTorch Users
- Data Loading and Preprocessing
- Configuration
- Evaluation (for Object Detection)
- Important Considerations & Debugging Tips
- Resources
- Basic understanding of machine learning and deep learning concepts.
- Experience with Python.
- Familiarity with other DL frameworks like PyTorch is highly beneficial; no prior JAX/Equinox experience is strictly necessary.
- Access to a machine with a GPU/TPU is recommended for faster training, though not strictly required.
The project is organized as follows:
```
.
├── hackathon/
│   ├── __init__.py
│   ├── augmentations.py    # Image augmentation functions (TensorFlow based)
│   ├── classification.py   # WORKING EXAMPLE: Image classification task
│   ├── config.py           # Configuration loading
│   ├── dataloader.py       # TensorFlow Datasets (TFDS) loading logic
│   ├── metrics.py          # (Potentially) Metrics like Dice coefficient
│   ├── objectdetection.py  # BOILERPLATE: Your primary workspace for object detection
│   ├── preprocessing.py    # Image preprocessing functions (TensorFlow based)
│   ├── tasks.py            # Task-specific processing function dispatcher
│   └── utils.py            # Utility functions
├── classification.py       # Main script to run training/evaluation for a classification task
├── objectdetection.py      # Main script to update to run the object detection pipeline
├── .config                 # Example configuration file (rename from .config.example if provided)
└── README.md               # This file
```
- `hackathon/classification.py`: A fully working example of an image classification task.
- `hackathon/objectdetection.py`: Where you'll spend most of your time. It contains boilerplate code for an object detection task, with several `NotImplementedError` sections that you need to fill in.
- `augmentations.py` & `preprocessing.py`: These files use TensorFlow for image operations. Note that the data pipeline (loading, preprocessing, augmentation) is TensorFlow-based, and the resulting NumPy arrays are then fed into JAX for model training.
- `config.py`: Manages project configurations. See the `.config` file for available options.
The base neural network used for feature extraction is a custom Convolutional Neural Network (CNN). Its definition comes from another open-source project, Equimo, which implements multiple computer vision models using Equinox.
The PASCAL VOC 2007 dataset is a standard benchmark for object detection. Key characteristics:
- Images: Contains thousands of color images of various scenes.
- Object Classes: 20 object classes (e.g., person, car, cat, dog, chair).
- Annotations: Each image is annotated with:
- Bounding boxes for every object instance.
- The class label for each bounding box.
- Challenges:
- Multiple objects per image.
- Objects of varying sizes and scales.
- Occlusion and truncation of objects.
Your model will need to predict the class and bounding box coordinates (typically (x_{min}, y_{min}, x_{max}, y_{max})) for each object in an image.
Your primary task is to complete the `hackathon/objectdetection.py` script. This involves:

1. Understanding the Data
   - Familiarize yourself with how bounding boxes and labels are provided by `dataloader.py` for the VOC dataset. The `preprocessing` function in `objectdetection.py` receives `sample[objects_key]["bbox"]` and `sample[objects_key]["label"]`. You'll need to process these appropriately.
   - The `augmentations.py` file contains a `create_global_crops` function that has a `NotImplementedError` section for handling bounding boxes during augmentations. You will need to implement this to ensure bounding boxes are correctly transformed along with the images. The `random_resized_crop` function already has bbox transformation logic you can adapt.

2. Implementing the YOLOv10-inspired Prediction Head
   - In `objectdetection.py`, the `loss_fn` calls `model(...)`. This `model` (your custom CNN + prediction head) will output predictions. You need to design and implement a prediction head that takes features from the base CNN and outputs predictions in a format suitable for object detection, following the YOLOv10 network.
   - Refer to the YOLOv10 architecture for inspiration on how the prediction head is structured. You'll need to implement these layers using Equinox modules (`eqx.nn.Conv2d`, `eqx.nn.Linear`, etc.). The Equimo repository has many examples of advanced layer types.

3. Implementing the YOLOv10-inspired Loss Function
   - The `loss_fn` in `objectdetection.py` currently has a `NotImplementedError`. You need to implement the loss calculation. This will likely involve several components:
     - Classification Loss: for the class of detected objects.
     - Regression Loss: for the bounding box coordinates.
     - Objectness Loss: to determine whether an anchor/grid cell contains an object.
   - You'll need to match ground-truth bounding boxes with predicted boxes to calculate these losses. This matching is a critical part of object detection.

4. Completing `preprocessing` and `augmentations`
   - The `preprocessing` function in `objectdetection.py` needs to be completed to handle bounding box data alongside images. This might involve converting bounding box formats or preparing them for augmentation.
   - The `augmentations` function in `objectdetection.py` also needs completion. Crucially, the `create_global_crops` call within it needs to correctly handle bounding box transformations. The `bboxes` argument to `create_global_crops` (in `augmentations.py`) is where you'll pass your bounding box data, and you'll need to implement the logic within `create_global_crops` to transform these bboxes according to the image augmentations.

5. Implementing Evaluation
   - The evaluation loop in `evaluate` within `objectdetection.py` has a `NotImplementedError`. You'll need to implement a suitable evaluation metric for object detection, such as mean Average Precision (mAP). This involves processing model predictions on the validation set and comparing them against ground truth.
You are not expected to port the entire YOLOv10 model. The focus is on the prediction heads and the loss computation.
- YOLOv10 Repository: https://github.com/THU-MIG/yolov10
- Study the PyTorch implementation of the heads and loss functions in the YOLOv10 repository. Your goal is to translate the logic and mathematical operations into JAX/Equinox.
- `hackathon/objectdetection.py`:
  - `preprocessing()`: Implement logic to handle image and object (bounding box, label) data.
  - `augmentations()`: Ensure bounding boxes are correctly passed to and processed by `create_global_crops`.
  - `loss_fn()`: Implement the object detection loss.
  - `evaluate()`: Implement the evaluation loop and mAP calculation.
  - You may also need to define new Equinox modules for the prediction head(s) here or in a separate file.
- `hackathon/augmentations.py`:
  - `create_global_crops()`: Implement the bounding box transformation logic when `bboxes` are provided.
- Start Simple: Begin by understanding the structure of a single detection head layer.
- Equinox Equivalents:
  - PyTorch `nn.Conv2d` -> `eqx.nn.Conv2d`
  - PyTorch `nn.Linear` -> `eqx.nn.Linear`
  - Activation functions (ReLU, Sigmoid, etc.) are available in `jax.nn`.
- Shape Management: Pay very close attention to tensor shapes at each step. JAX is stricter about shapes than PyTorch and will not silently broadcast mismatched shapes.
- Loss Components: Break down the YOLO loss into its constituent parts and implement them one by one.
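To make that decomposition concrete, here is a minimal sketch of combining per-anchor loss components, assuming the ground-truth matching step has already produced aligned targets and an assignment mask. All names are illustrative, and the L1 box term is a simple stand-in (YOLOv10 itself uses CIoU and distribution focal loss):

```python
import jax.numpy as jnp

def sigmoid_bce(logits, targets):
    # Numerically stable binary cross-entropy with logits
    # (optax.sigmoid_binary_cross_entropy computes the same quantity).
    return jnp.maximum(logits, 0) - logits * targets + jnp.log1p(jnp.exp(-jnp.abs(logits)))

def detection_loss(cls_logits, box_preds, cls_targets, box_targets, obj_mask):
    """Per-anchor loss on already-matched predictions (a simplified sketch).

    cls_logits: [N, num_classes]  raw class scores per anchor
    cls_targets: [N, num_classes] one-hot (all-zero rows for background)
    box_preds, box_targets: [N, 4]
    obj_mask: [N] 1.0 where an anchor is assigned to a ground-truth box
    """
    cls_loss = sigmoid_bce(cls_logits, cls_targets).sum(axis=-1)
    # L1 stand-in for the regression term.
    box_loss = jnp.abs(box_preds - box_targets).sum(axis=-1)
    n_pos = jnp.maximum(obj_mask.sum(), 1.0)
    # Classification applies to all anchors; regression only to positives.
    return (cls_loss.sum() + (obj_mask * box_loss).sum()) / n_pos
```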
JAX and Equinox have some fundamental differences from PyTorch. Understanding these is crucial.
- Functional Programming: JAX encourages a functional programming style. Functions should ideally be pure (no side effects; output depends only on input).
- Immutability: JAX arrays (and Equinox model parameters) are immutable. Operations that "modify" an array actually return a new array.
Unlike PyTorch's global random seed, JAX requires explicit handling of pseudo-random number generator (PRNG) keys.
- Create a key: `key = jax.random.PRNGKey(seed)`
- Split keys for different operations: `key, subkey = jax.random.split(key)`
- Pass keys to functions that require randomness (e.g., initializers, dropout, sampling).
- Why? This makes random operations reproducible and easier to reason about, especially with transformations like `jax.vmap` and `jax.pmap`.
Example:

```python
import jax
import jax.random as jr

key = jr.PRNGKey(0)
key_dropout, key_init, key_data_aug = jr.split(key, 3)
# Use key_dropout for a dropout layer
# Use key_init for weight initialization
# Use key_data_aug for a random data augmentation
```
In our codebase, you'll see `key` being passed around, especially in `train_step` and `compute_cls_accuracy`.
JAX can compile your Python functions into highly optimized XLA (Accelerated Linear Algebra) code using `jax.jit`.
- Usage: Decorate your function with `@jax.jit` (or `@eqx.filter_jit` for Equinox modules).
- Benefits: Significant speedups, especially for numerical computations.
- Gotchas:
  - The function is traced on its first call with specific input shapes and types. Subsequent calls with different shapes/types will trigger a re-compilation.
  - Control flow (if/else, loops) based on tensor values can be tricky. Try to make control flow depend on static values, or use JAX control-flow primitives like `jax.lax.cond` or `jax.lax.scan`.
  - Side effects (like printing or modifying external variables) inside JITted functions behave unexpectedly or are ignored. Use `jax.debug.print` for debugging.
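For example, branching on a traced value must go through `jax.lax.cond` (a small illustrative sketch, not code from this repository):

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_or_zero(x):
    # A Python `if x.sum() > 0:` here would fail under jit, because the
    # condition is a traced value; jax.lax.cond selects a branch at runtime.
    return jax.lax.cond(
        x.sum() > 0,
        lambda v: jnp.clip(v, 0.0, 1.0),  # true branch
        lambda v: jnp.zeros_like(v),      # false branch
        x,
    )
```

Both branches must return outputs with the same shape and dtype; `jax.lax.scan` plays the analogous role for loops.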
`jax.vmap` is a powerful transformation that maps a function written for single data points over batches of data.

- PyTorch: Batching is often implicit in layers (e.g., `nn.Conv2d` expects a batch dimension).
- JAX/Equinox: You typically write functions/modules for a single data sample; `jax.vmap` then handles the batching: `jax.vmap(model)(batch_of_images)`.
- The `in_axes` argument specifies which arguments to map over (e.g., `in_axes=(0, None)` means map over the first dimension of the first argument and broadcast the second argument).
Example from `classification.py`:

```python
# model is defined to process a single image
logits = jax.vmap(model, in_axes=(0, None))(images, key)
# Here, `images` is a batch of images (e.g., shape [B, H, W, C])
# `key` is a single PRNG key, broadcast to each image in the batch.
# `model` is applied to each image in `images`.
```
You must use `jax.vmap` (or ensure your model inherently handles batches, which is less common in Equinox) when processing batches of data.
Equinox provides a simple way to create neural network modules, similar to `torch.nn.Module`.

- Structure: Parameters (weights, biases) are attributes of the class.
- `eqx.partition` / `eqx.filter`: Equinox modules are "PyTrees." These helpers separate learnable parameters from static configuration or non-learnable components.
  - `params, static = eqx.partition(model, eqx.is_array)` splits array attributes (learnable parameters) from everything else (`eqx.filter` returns just the matching half).
  - `model = eqx.combine(params, static)` reconstructs the model. This is crucial for optimizers, which only need to update the learnable parameters.
Example from `classification.py`:

```python
# In train_step:
model = eqx.combine(params, static)  # Reconstruct model
# ...
loss, grads = jax.value_and_grad(loss_for_step)(params)  # Get grads only for params
# params = optax.apply_updates(params, updates)  # Update only params
```
JAX provides functions for automatic differentiation.
- `jax.grad(fun)`: Returns a function that computes the gradient of `fun`.
- `jax.value_and_grad(fun)`: Returns a function that computes both the value of `fun` and its gradient. This is often more efficient.
The `loss_fn` in the provided code is designed to take `params` (the learnable parts of the model) as its first argument, which is what `jax.grad` or `jax.value_and_grad` will differentiate with respect to.
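A toy illustration of `jax.value_and_grad` on a dict of parameters (the names here are made up for the example):

```python
import jax
import jax.numpy as jnp

def mse(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.array(2.0), "b": jnp.array(0.0)}
x, y = jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0])
# One call evaluates the loss and its gradient w.r.t. the first argument.
loss, grads = jax.value_and_grad(mse)(params, x, y)
# grads has the same dict structure as params.
```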
JAX and Equinox extensively use the concept of "PyTrees." A PyTree is a container of leaf elements, like lists, tuples, and dictionaries. JAX transformations (`jit`, `vmap`, `grad`) can operate on entire PyTrees, and Equinox modules are PyTrees. This allows you to handle complex nested structures of parameters and data seamlessly.
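For instance, `jax.tree_util.tree_map` applies a function to every leaf of a nested structure (the parameter names below are invented for the example):

```python
import jax
import jax.numpy as jnp

params = {"head": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}, "scale": jnp.array(3.0)}
# tree_map applies the function to every array leaf, preserving the nesting.
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, params)
```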
Since JAX arrays are immutable, you don't update them in-place.
```python
# PyTorch (in-place)
params[0] += update_val

# JAX (functional update)
params = params.at[0].add(update_val)

# Or, for entire arrays, often just:
params = new_params
```

Optax optimizers handle this correctly when applying updates: `params = optax.apply_updates(params, updates)`.
- `hackathon/dataloader.py`: Uses `tensorflow_datasets` (TFDS) to load data.
  - The `load_tfds` function handles dataset loading, preprocessing, augmentation, batching, and conversion to NumPy arrays.
- `hackathon/preprocessing.py`: Contains TensorFlow-based image preprocessing functions like `prepare_image` (resizing) and `normalize`.
- `hackathon/augmentations.py`: Contains TensorFlow-based image augmentation functions like `random_resized_crop`, `random_horizontal_flip`, `color_jitter`, etc.
  - Crucially for object detection: The `random_resized_crop` function has logic to transform bounding boxes. You will need to ensure this logic is correctly used, and potentially adapted, when you implement bounding box handling in `create_global_crops` and the `augmentations` function in `objectdetection.py`.
- `hackathon/tasks.py`: The `get_processing_functions` utility dynamically selects the correct `preprocessing`, `augmentations`, and `postprocessing` functions based on the dataset and task type.
The pipeline is: TFDS -> TensorFlow preprocessing/augmentations -> NumPy arrays -> JAX model.
Project configurations are managed by `hackathon/config.py` using the `everett` library, which reads from environment variables or a `.config` file.

An example `.config` file might look like:
```ini
seed=42

# Object Detection - VOC
num_classes_voc=20     # Number of classes in VOC
image_size_voc=416     # Example input size for YOLO-like models
batch_size_voc=16
learning_rate_voc=1e-4
warmup_steps_voc=100
num_epochs_voc=50
# ... other VOC-specific params ...

# image params
patch_size=16          # If your base CNN is ViT-like
normalization_mean=0.485,0.456,0.406
normalization_std=0.229,0.224,0.225
global_crops_size=224
# ...
```
Make sure your `.config` file has the necessary parameters for `voc` (e.g., `num_classes_voc`, `image_size_voc`, `batch_size_voc_scratch`, etc.). The provided `config.py` lists all expected configurations.
For object detection, a common evaluation metric is mean Average Precision (mAP).
- You will need to implement mAP calculation in the `evaluate` function of `objectdetection.py`.
- This involves:
  - Getting predictions from your model on the validation set (bounding boxes, class scores, objectness scores).
  - Applying Non-Maximum Suppression (NMS) to filter redundant boxes.
  - Comparing predicted boxes to ground-truth boxes using Intersection over Union (IoU).
  - Calculating precision and recall for each class at various confidence thresholds.
  - Computing the Average Precision (AP) for each class (area under the precision-recall curve).
  - Averaging APs across all classes to get mAP.
This is a non-trivial task, so plan your time accordingly. You might find existing JAX implementations of mAP or its components helpful as a reference.
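As one building block, IoU for a pair of boxes in `[xmin, ymin, xmax, ymax]` format can be computed as below (a sketch; your evaluation code will need a batched/vectorized variant, e.g. via `jax.vmap`):

```python
import jax.numpy as jnp

def iou(box_a, box_b):
    """IoU for two boxes in [xmin, ymin, xmax, ymax] format."""
    x1 = jnp.maximum(box_a[0], box_b[0])
    y1 = jnp.maximum(box_a[1], box_b[1])
    x2 = jnp.minimum(box_a[2], box_b[2])
    y2 = jnp.minimum(box_a[3], box_b[3])
    # Clip so non-overlapping boxes give zero intersection, not negative area.
    inter = jnp.clip(x2 - x1, 0.0) * jnp.clip(y2 - y1, 0.0)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-7)
```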
- Start with `classification.py`: Understand it thoroughly. It's your reference for JAX/Equinox patterns.
- Incremental Development: Implement and test small pieces at a time.
  - Can you get the data shapes right for the head?
  - Can you compute one component of the loss?
- Debugging JITted Code:
  - `jax.debug.print("message: {x}", x=array)`: Prints values inside JITted functions. It does not print during tracing, only during execution.
  - `jax.debug.breakpoint()`: If you run your code under a Python debugger.
  - Temporarily disable `@jax.jit` or `@eqx.filter_jit` for easier debugging with standard Python print statements, but remember performance will be much slower.
- Shape Mismatches: This will be a common source of errors. Print shapes frequently!
  - `print(array.shape)` (outside JIT)
  - `jax.debug.print("array shape: {s}", s=array.shape)` (inside JIT)
- PRNG Key Management: Ensure keys are split and threaded through your functions correctly. Reusing keys will lead to correlated "random" numbers.
- Bounding Box Formats: Be consistent with your bounding box coordinate format (e.g., `[xmin, ymin, xmax, ymax]` vs. `[x_center, y_center, width, height]`, normalized vs. absolute). VOC typically uses `[ymin, xmin, ymax, xmax]` (absolute); TFDS might provide them normalized. Check `dataloader.py` and `preprocessing.py`.
- Numerical Stability: Watch out for `NaN`s in your loss. These can come from log(0), division by zero, or large gradients. Use `jax.numpy.clip` or add small epsilon values (`1e-7` or `jnp.finfo(dtype).eps`) where appropriate.
- TensorFlow vs. JAX operations: Remember that the data loading pipeline uses TensorFlow ops, while your model and loss function will use JAX ops (`jax.numpy`, often aliased as `jnp`).
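A conversion helper along these lines may be useful for the bounding-box format issue above (a hedged sketch; verify the actual format emitted by `dataloader.py` before relying on it):

```python
import jax.numpy as jnp

def tfds_to_xyxy(bboxes, height, width):
    """Convert normalized [ymin, xmin, ymax, xmax] boxes (the TFDS convention)
    to absolute [xmin, ymin, xmax, ymax]. bboxes: [N, 4]."""
    ymin, xmin, ymax, xmax = jnp.split(bboxes, 4, axis=-1)
    return jnp.concatenate(
        [xmin * width, ymin * height, xmax * width, ymax * height], axis=-1
    )
```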
- JAX Documentation: https://jax.readthedocs.io/
- Equinox Documentation: https://docs.kidger.site/equinox/
- Optax (Optimizer library) Documentation: https://optax.readthedocs.io/
- YOLOv10 Paper & Repository: (Linked above)
- PASCAL VOC Dataset: http://host.robots.ox.ac.uk/pascal/VOC/
- TFDS VOC2007: https://www.tensorflow.org/datasets/catalog/voc#voc2007 (Details on how TFDS structures the data)