This is an implementation of several unsupervised object discovery models (Slot Attention, SLATE, GNM) in PyTorch.
The initial code for this repo was forked from untitled-ai/slot_attention.
Requirements:
- Poetry
- Python >= 3.9
- CUDA-enabled computing device
- Clone the repo: `git clone https://github.com/HHousen/slot-attention-pytorch/ && cd slot-attention-pytorch`.
- Install the requirements and activate the environment: `poetry install`, then `poetry shell`.
- Download the CLEVR (with masks) dataset (or the original CLEVR dataset by running `./data_scripts/download_clevr.sh /tmp/CLEVR`). More details about the datasets are below.
- Modify the hyperparameters in `object_discovery/params.py` to fit your needs. Make sure to change `data_root` to the location of your dataset.
- Train a model: `python -m slot_attention.train`.
Code to load these models can be adapted from `predict.py`.
| Model | Dataset | Download |
|---|---|---|
| Slot Attention | CLEVR6 Masks | Hugging Face |
| Slot Attention | Sketchy | Hugging Face |
| GNM | CLEVR6 Masks | Hugging Face |
| Slot Attention | ClevrTex6 | Hugging Face |
| GNM | ClevrTex6 | Hugging Face |
| SLATE | CLEVR6 Masks | Hugging Face |
Train a model by running `python -m slot_attention.train`.
Hyperparameters can be changed in `object_discovery/params.py`. `training_params` has global parameters that apply to all model types. These parameters can be overridden if the same key is present in `slot_attention_params` or `slate_params`. Change the global parameter `model_type` to `sa` to use Slot Attention (`SlotAttentionModel` in `slot_attention_model.py`) or `slate` to use SLATE (`SLATE` in `slate_model.py`). This setting determines which model's parameter set is merged with `training_params`, as illustrated in the sketch below.
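As a rough illustration of the merging behavior described above (the names mirror `object_discovery/params.py`, but the values and merge code here are a simplified sketch, not the repo's actual implementation):

```python
# Simplified sketch of the parameter merging described above. The names
# (training_params, slot_attention_params, slate_params, model_type) mirror
# object_discovery/params.py; the values and merge logic are illustrative only.
training_params = {"model_type": "sa", "lr": 4e-4, "batch_size": 64}
slot_attention_params = {"num_slots": 7, "lr": 2e-4}  # "lr" overrides the global value
slate_params = {"num_slots": 7, "d_model": 192}

model_specific = {
    "sa": slot_attention_params,
    "slate": slate_params,
}[training_params["model_type"]]

# Model-specific keys take precedence over the global defaults.
params = {**training_params, **model_specific}
print(params["lr"])  # 0.0002 for model_type == "sa"
```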
Perform inference by modifying and running the `predict.py` script.
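A hedged sketch of what such a script might look like is below; the constructor arguments and checkpoint path are hypothetical placeholders, so defer to `predict.py` for the real loading code:

```python
# Hypothetical sketch of loading a trained checkpoint for inference, loosely
# following what predict.py does. The constructor arguments and the checkpoint
# path are placeholders, not the repo's actual values.
import torch
from object_discovery.slot_attention_model import SlotAttentionModel

model = SlotAttentionModel(resolution=(128, 128), num_slots=7, num_iterations=3)  # assumed args
state = torch.load("checkpoint.ckpt", map_location="cpu")
# Lightning-style checkpoints nest the weights under "state_dict".
model.load_state_dict(state.get("state_dict", state))
model.eval()

with torch.no_grad():
    images = torch.rand(1, 3, 128, 128)  # dummy batch of RGB images
    outputs = model(images)
```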
Our implementations are based on several open-source repositories.
- Slot Attention ("Object-Centric Learning with Slot Attention"): untitled-ai/slot_attention & Official
- SLATE ("Illiterate DALL-E Learns to Compose"): Official
- GNM ("Generative Neurosymbolic Machines"): karazijal/clevrtex & Official
Select a dataset by changing the `dataset` parameter in `object_discovery/params.py` to the name of the dataset: `clevr`, `shapes3d`, or `ravens`. Then, set the `data_root` parameter to the location of the data. The code for loading supported datasets is in `object_discovery/data.py`.
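For example (a hypothetical snippet; the actual definitions in `object_discovery/params.py` may be structured differently):

```python
# Hypothetical illustration of the two parameters described above; see
# object_discovery/params.py for the real definitions.
training_params = dict(
    dataset="clevr",                          # or "shapes3d", "ravens"
    data_root="./data/clevr_with_masks.h5",   # placeholder path to your data
)
```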
- CLEVR: Download by executing `download_clevr.sh`.
- CLEVR (with masks): Original TFRecords Download / Our HDF5 PyTorch Version.
- This dataset is a regenerated version of CLEVR but with ground-truth segmentation masks. This enables the training script to calculate Adjusted Rand Index (ARI) during validation runs.
- The dataset contains 100,000 images with a resolution of 240x320 pixels. The dataloader splits them into 70K train, 15K validation, and 15K test. Test images are not used by the `object_discovery/train.py` script.
- We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the `data_scripts/preprocess_clevr_with_masks.py` script, which takes approximately 2 hours to execute depending on your machine. A minimal sketch of reading the converted HDF5 file appears after this list.
- 3D Shapes: Official Google Cloud Bucket
- RAVENS Robot Data: Train & Test
- We generated a dataset similar in structure to CLEVR (with masks) but of robotic images using RAVENS. Our modified version of RAVENS used to generate the dataset is HHousen/ravens.
- The dataset contains 85,002 images, split into 70,002 train and 15K validation/test.
- Sketchy: Download and process by following the directions in applied-ai-lab/genesis / Download Our Processed Version.
- Dataset details are in the paper Scaling data-driven robotics with reward sketching and batch reinforcement learning.
- ClevrTex: Download by executing `download_clevrtex.sh`. Our dataloader needs to index the entire dataset before training can begin, which can take around 2 hours. Thus, it is recommended to download our pre-made index from this Hugging Face folder and put it in `./data/cache/`.
- Tetrominoes: Original TFRecords Download / Our HDF5 PyTorch Version.
- There are 1,000,000 samples in the dataset. However, following the Slot Attention paper, we only use the first 60K samples for training.
- We convert the original TFRecords dataset to HDF5 for easy use with PyTorch. This was done using the `data_scripts/preprocess_tetrominoes.py` script, which takes approximately 2 hours to execute depending on your machine.
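As referenced above, here is a minimal sketch of reading one of the converted HDF5 datasets; the file path and key names are placeholders, since the actual layout is determined by the `data_scripts/preprocess_*.py` scripts and `object_discovery/data.py`:

```python
# A minimal sketch, assuming the converted datasets are plain HDF5 files
# holding image (and, for CLEVR with masks, mask) arrays. The path and the
# "image" key below are placeholders; inspect f.keys() for the real layout.
import h5py
import torch

with h5py.File("./data/clevr_with_masks.h5", "r") as f:  # placeholder path
    print(list(f.keys()))             # inspect the actual dataset keys
    images = f["image"][:8]           # hypothetical key: first 8 images
    batch = torch.from_numpy(images)  # hand off to PyTorch
    print(batch.shape, batch.dtype)
```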
To log outputs to wandb, run `wandb login YOUR_API_KEY` and set `is_logging_enabled=True` in `SlotAttentionParams`.
If you use a dataset with ground-truth segmentation masks, then the Adjusted Rand Index (ARI), a clustering similarity score, will be logged for each validation loop. We convert the implementation from deepmind/multi_object_datasets to PyTorch in `object_discovery/segmentation_metrics.py`.
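For intuition, ARI can be illustrated with scikit-learn on flattened masks; this is not the repo's implementation (which is batched and in PyTorch), just a minimal example of the metric itself:

```python
# Illustrative only: ARI between ground-truth and predicted per-pixel labels,
# computed with scikit-learn on flattened masks. The repo's own implementation
# lives in object_discovery/segmentation_metrics.py.
import numpy as np
from sklearn.metrics import adjusted_rand_score

true_ids = np.array([0, 0, 1, 1, 2, 2])  # ground-truth object id per pixel
pred_ids = np.array([1, 1, 0, 0, 2, 2])  # predicted cluster id per pixel

# ARI is invariant to label permutation, so this scores a perfect 1.0.
print(adjusted_rand_score(true_ids, pred_ids))
```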
| Slot Attention CLEVR10 | Slot Attention Sketchy |
|---|---|
| ![]() | ![]() |
Visualizations (above) for a model trained on CLEVR6 predicting on CLEVR10 (with no increase in the number of slots) and a model trained and predicting on Sketchy. From left to right, the images are: original, reconstruction, raw predicted segmentation mask, processed segmentation mask, and then the slots.
| Slot Attention ClevrTex6 | GNM ClevrTex6 |
|---|---|
| ![]() | ![]() |
The Slot Attention visualization image order is the same as in the above visualizations. For GNM, the order is original, reconstruction, ground-truth segmentation mask, and predicted segmentation mask (repeated 4 times).
| SLATE CLEVR6 | GNM CLEVR6 |
|---|---|
| ![]() | ![]() |
For SLATE, the image order is original, dVAE reconstruction, autoregressive reconstruction, and then the pixels each slot pays attention to.
- untitled-ai/slot_attention: An unofficial implementation of Slot Attention from which this repo was forked.
- Slot Attention: Official Code / "Object-Centric Learning with Slot Attention".
- SLATE: Official Code / "Illiterate DALL-E Learns to Compose".
- IODINE: Official Code / "Multi-Object Representation Learning with Iterative Variational Inference". In the Slot Attention paper, IODINE was frequently used for comparison. The IODINE code was helpful to create this repo.
- Multi-Object Datasets: deepmind/multi_object_datasets. This is the original source of the CLEVR (with masks) dataset.
- Implicit Slot Attention: "Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation". This paper explains a one-line change that improves the optimization of Slot Attention while simultaneously making backpropagation have constant space and time complexity.
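For reference, a minimal sketch of that change is below; `slot_attention_step` is a hypothetical stand-in for one iteration of the real update (normalized cross-attention followed by a GRU/MLP refinement):

```python
# Minimal sketch of the implicit-differentiation trick described above:
# iterate to an approximate fixed point without tracking gradients, then
# backpropagate through a single final step applied to the detached slots.
import torch

def slot_attention_step(slots, inputs):
    # Hypothetical stand-in for one Slot Attention iteration; a real
    # implementation uses normalized cross-attention plus a GRU/MLP update.
    attn = torch.softmax(slots @ inputs.transpose(1, 2), dim=-1)  # (B, S, N)
    return attn @ inputs  # (B, S, D)

num_iterations = 3
inputs = torch.rand(2, 64, 32)  # (batch, input tokens, dim)
slots = torch.rand(2, 5, 32)    # (batch, slots, dim)

with torch.no_grad():
    for _ in range(num_iterations):
        slots = slot_attention_step(slots, inputs)
# The one-line change: treat the converged slots as a fixed point and
# backpropagate through one final refinement step only.
slots = slot_attention_step(slots.detach(), inputs)
```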







