Genera1Z / SAVi-PyTorch

SAVi -- unofficial, but with 3x faster training and better performance. An implementation of the ICLR 2022 paper "Conditional Object-Centric Learning from Video".

Slot attention module question #1

Open · MLDeS opened this issue 11 months ago

MLDeS commented 11 months ago

Hey, thanks for the great implementation!

I have two questions for you.

1) Since you have implemented slot attention, could you say what technical differences you see between slot attention and Transformers? I am asking mainly about the technical differences (e.g., the use of a GRU cell in slot attention vs. no GRU cell in Transformers), and how these differences could affect the results and their conceptual interpretation. Or, in theory, if you simply replaced slot attention with a Transformer, would it still work similarly?

2) How is your training faster? What did you change from the original slot attention implementation?

Thanks a lot!

Genera1Z commented 10 months ago
  1. SlotAttention (SA) can be seen as a variant of a TransformerBlock (TFB) / TransformerEncoderLayer. But in the SAVi model, SA is used only as the object feature extractor/aggregator, while the TFB works as the dynamics/transition model, like the world models in Reinforcement Learning: it takes the current state (a set of slots here) and predicts the future state (a new set of slots). In other words, at least in SAVi, SA and TFB are two sequential modules with different functionalities. Still, if you really want to compare them: SA is MultiheadAttention + GRU applied iteratively, where key = value = image features and query = slots (object features), whereas TFB is MultiheadAttention + FFN/MLP, where query = key = value = slots; see the sketch after this list. So if you really wanted to replace SA with a TFB, you would have to add a GRU to the TFB to carry the intermediate state across iterations, and reorganize the query/key/value inputs as in SA.
  2. AMP -> larger batch size -> higher actual GPU FLOPS; high dataset compression -> 5x less disk I/O overhead; and the PyTorch DataLoader is faster than TensorFlow Datasets (tfds) (actually, by stripping the cruft out of tfds, TF Dataset itself can be made much faster). I didn't change anything in the SA implementation itself -- I just rewrote the original TensorFlow implementation in PyTorch.
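
For concreteness on point 1, here is a minimal PyTorch sketch, not this repo's exact code; the module names, dimensions, and the 2-iteration count are illustrative assumptions. It contrasts SA (cross-attention from slots to image features, with a GRU carrying state across iterations) with a TFB-style predictor (plain self-attention over the slots):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SlotAttentionSketch(nn.Module):
    """SA-style module: query = slots, key = value = image features, GRU over iterations."""

    def __init__(self, slot_dim=128, feat_dim=128, n_iter=2):
        super().__init__()
        self.n_iter = n_iter
        self.norm_feat = nn.LayerNorm(feat_dim)
        self.norm_slot = nn.LayerNorm(slot_dim)
        self.to_q = nn.Linear(slot_dim, slot_dim, bias=False)
        self.to_k = nn.Linear(feat_dim, slot_dim, bias=False)
        self.to_v = nn.Linear(feat_dim, slot_dim, bias=False)
        self.gru = nn.GRUCell(slot_dim, slot_dim)
        self.mlp = nn.Sequential(
            nn.LayerNorm(slot_dim), nn.Linear(slot_dim, slot_dim * 2),
            nn.ReLU(), nn.Linear(slot_dim * 2, slot_dim))

    def forward(self, feats, slots):              # feats: (B, N, D), slots: (B, K, D)
        feats = self.norm_feat(feats)
        k, v = self.to_k(feats), self.to_v(feats)
        for _ in range(self.n_iter):              # iterative refinement, absent in a TFB
            q = self.to_q(self.norm_slot(slots))
            # softmax over slots, so slots compete for image features
            attn = F.softmax(q @ k.transpose(1, 2) / q.shape[-1] ** 0.5, dim=1)
            attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)   # weighted mean over features
            updates = attn @ v                                       # (B, K, D)
            # the GRU carries the intermediate state across iterations
            slots = self.gru(updates.flatten(0, 1), slots.flatten(0, 1)).view_as(slots)
            slots = slots + self.mlp(slots)
        return slots


# TFB-style predictor: self-attention with query = key = value = slots, FFN/MLP instead of a GRU.
predictor = nn.TransformerEncoderLayer(d_model=128, nhead=4, dim_feedforward=256, batch_first=True)

B, N, K, D = 2, 64 * 64, 11, 128
feats = torch.randn(B, N, D)                      # flattened CNN feature map of the current frame
slots = torch.randn(B, K, D)                      # (conditionally initialized) slots
slots_t = SlotAttentionSketch()(feats, slots)     # SA: aggregate object features from the frame
slots_t1 = predictor(slots_t)                     # TFB: predict the slots for the next frame
```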
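
And for the AMP point, a minimal training-step sketch assuming a generic `model`, `optimizer`, `video`, and MSE `target` (all hypothetical placeholders, not this repo's actual trainer):

```python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, video, target):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():       # mixed-precision forward -> less memory, so larger batches fit
        pred = model(video)
        loss = F.mse_loss(pred, target)
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```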