This is a re-implementation of "Object-Centric Learning with Slot Attention" in PyTorch (https://arxiv.org/abs/2006.15055).
Note: the model was run on an NVIDIA Tesla V100 16GB GPU.
Run run.sh to get started. This script will install the dependencies, download the CLEVR dataset, and run the model.
To train the model directly, run python slot_attention/train.py
Modify SlotAttentionParams in slot_attention/train.py to change the hyperparameters. See slot_attention/params.py for the default hyperparameters.
To log outputs to wandb, run wandb login YOUR_API_KEY and set is_logging_enabled=True in SlotAttentionParams.
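As a rough illustration of overriding hyperparameters and enabling logging, the sketch below uses a stand-in SlotAttentionParams defined inline so that it runs on its own. Only is_logging_enabled is named in this README; the other fields (lr, batch_size, num_slots), all default values, and the use of a dataclass are assumptions made for the example, so check slot_attention/params.py for the real definition.

```python
from dataclasses import dataclass, fields


# Stand-in for the repository's SlotAttentionParams (hypothetical definition).
# Only `is_logging_enabled` comes from this README; the remaining fields and
# every default value are illustrative assumptions.
@dataclass(frozen=True)
class SlotAttentionParams:
    is_logging_enabled: bool = False  # assumed default
    lr: float = 4e-4                  # hypothetical field
    batch_size: int = 64              # hypothetical field
    num_slots: int = 7                # hypothetical field


defaults = SlotAttentionParams()

# Override only what you need; everything else keeps its default value.
params = SlotAttentionParams(is_logging_enabled=True, batch_size=32)

# Collect the fields that differ from the defaults, which is handy to record
# alongside a training run.
changed = {
    f.name: getattr(params, f.name)
    for f in fields(params)
    if getattr(params, f.name) != getattr(defaults, f.name)
}
print(changed)
```

In the repository itself, the equivalent change would be made where SlotAttentionParams is constructed in slot_attention/train.py.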
Special thanks to the original authors of the paper: Francesco Locatello, Dirk Weissenborn, Thomas Unterthiner, Aravindh Mahendran, Georg Heigold, Jakob Uszkoreit, Alexey Dosovitskiy, and Thomas Kipf.