robert1003 / slot-attention-disentanglement

"Exploring the Role of the Bottleneck in Slot-Based Models Through Covariance Regularization", 2023
https://arxiv.org/abs/2306.02577

Encoder+slot attention siamese network SSL training #15

Open robert1003 opened 1 year ago

robert1003 commented 1 year ago

This issue is for discussing an idea I had in mind. Basically, the current slot attention architecture can be decomposed into two parts: the core (encoder + slot attention) and the reconstruction head (decoder). The idea is that with the slot attention feature vectors (slot vectors) as the bottleneck, the reconstruction objective forces the core to encode informative content in the feature vectors, while the slot attention mechanism encourages disentanglement of the features. The combination of the two yields per-object slot vectors, as shown in the paper.
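To make the decomposition concrete, here is a minimal sketch of the core + reconstruction-head split with the L2 objective (the `Encoder`, `SlotAttention`, and `Decoder` modules and their shapes are placeholders for illustration, not the actual code in this repo):

```python
import torch.nn as nn

class SlotAttentionAutoencoder(nn.Module):
    """Minimal sketch: core (encoder + slot attention) plus reconstruction head (decoder)."""
    def __init__(self, encoder, slot_attention, decoder):
        super().__init__()
        self.encoder = encoder                # image -> flattened features (B, N, D)
        self.slot_attention = slot_attention  # features -> slots (B, K, D_slot), the bottleneck
        self.decoder = decoder                # slots -> reconstructed image (B, C, H, W)

    def core(self, images):
        feats = self.encoder(images)
        return self.slot_attention(feats)

    def forward(self, images):
        slots = self.core(images)
        recon = self.decoder(slots)
        recon_loss = ((recon - images) ** 2).mean()  # L2 reconstruction objective
        return recon_loss, slots
```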

What I am thinking is that the L2 reconstruction loss is not the only way to force the core to encode meaningful information in the slot vectors. We can use contrastive methods from SSL instead. For example, we can replace the encoder used in MoCo/SimCLR/SimSiam with our encoder + slot attention (see the sketch below).
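As a rough illustration of "replace the encoder with our core" in a SimSiam-style setup, here is a hedged sketch; mean-pooling the slots into one vector per image, and the `projector`/`predictor` MLPs, are assumptions about how the slots would feed into the SSL loss, not a settled design:

```python
import torch.nn.functional as F

def negative_cosine(p, z):
    # SimSiam loss term: negative cosine similarity with stop-gradient on the target branch.
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()

def siamese_ssl_step(core, projector, predictor, view1, view2):
    # core = encoder + slot attention, returning slots of shape (B, K, D).
    # Mean-pooling the K slots into one vector per image is just one possible choice.
    z1 = projector(core(view1).mean(dim=1))
    z2 = projector(core(view2).mean(dim=1))
    p1, p2 = predictor(z1), predictor(z2)
    return 0.5 * (negative_cosine(p1, z2) + negative_cosine(p2, z1))
```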

I haven't thought through all the details yet, and there are still some open problems on my mind.

Don't hesitate to share thoughts!

robert1003 commented 1 year ago

@as821 can you post the paper that looks at the quality of representations learned by slot attention? Thanks!

abisubramanya27 commented 1 year ago

I had a similar thought at the beginning of the project. But why do you want to differentiate slot vectors from SSL feature vectors? The role of the slot-attention feature vectors is to featurize the different instances in an image; we could aim for the same with contrastive methods, which push apart different views and pull together similar ones.

as821 commented 1 year ago

The main paper that looks at the quality of learned slot representations is https://arxiv.org/abs/2107.00637

We will already be implementing the representation evaluation method from this paper for our feature prediction experiment, so if we do want to experiment with alternative losses/architectures and compare representation quality, we will already have the evaluation code written.

as821 commented 1 year ago

Generally I think this is an interesting direction. There have been some hints that reconstruction is not necessarily the right objective for slot-based models (https://arxiv.org/abs/2107.00637 and https://openreview.net/forum?id=6wbNpKmfEOj demonstrate that low mean squared error correlates well with object disentanglement but that this correlation becomes significantly weaker as scenes become more complex and textured; the feature-prediction objective of DINOSAUR; etc.).

There are already some works applying contrastive learning to slot-based models (https://arxiv.org/abs/2011.10287, https://arxiv.org/pdf/1911.12247.pdf, https://arxiv.org/pdf/2007.09294.pdf), but none of them have been particularly successful. The proposed idea of using multiple views of the same image is a bit different from existing works, which usually rely on contrasting consecutive video frames.

SwAV is an SSL method that effectively applies clustering to each view and then compares the clusters across views (https://arxiv.org/pdf/2006.09882.pdf), which I think can be adapted to a slot-based setting fairly easily. The slot attention module is simply a differentiable clustering mechanism, so we can likely apply the same cluster-comparison loss as SwAV to the slots of two different views of the same image and get some interesting results.
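A very rough sketch of what a SwAV-style swapped-prediction loss on slots could look like (the per-slot prototype assignment, the naive index-based slot correspondence between views, and the hyperparameters are all assumptions to illustrate the idea, not a worked-out method):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sinkhorn(scores, n_iters=3, eps=0.05):
    # Sinkhorn-Knopp iterations as used in SwAV to turn prototype scores
    # into approximately balanced soft assignments. scores: (N, P).
    Q = torch.exp(scores / eps).T    # (P, N)
    Q /= Q.sum()
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True); Q /= Q.shape[0]  # normalize over prototypes
        Q /= Q.sum(dim=0, keepdim=True); Q /= Q.shape[1]  # normalize over samples
    return (Q * Q.shape[1]).T        # (N, P)

def swav_slot_loss(slots1, slots2, prototypes, temp=0.1):
    # slots1, slots2: (B, K, D) from two views of the same images.
    # prototypes: (P, D), L2-normalized. Aligning slot k of view 1 with slot k of
    # view 2 is a naive assumption; in practice some matching would be needed.
    z1 = F.normalize(slots1.flatten(0, 1), dim=-1)
    z2 = F.normalize(slots2.flatten(0, 1), dim=-1)
    s1, s2 = z1 @ prototypes.T, z2 @ prototypes.T
    q1, q2 = sinkhorn(s1), sinkhorn(s2)
    # Swapped prediction: codes from one view supervise the scores of the other.
    return -0.5 * ((q1 * F.log_softmax(s2 / temp, dim=-1)).sum(-1).mean()
                   + (q2 * F.log_softmax(s1 / temp, dim=-1)).sum(-1).mean())
```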

How should we measure the quality of the slot vectors? A common way is to use the classification accuracy of a linear classifier trained on the extracted features. That works, but I am not sure whether there is a better or different task that would let us compare directly with the slot attention paper.
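(For reference, the linear-probe evaluation mentioned here could be as simple as the following sketch; pooling the slots into one vector per image before the linear layer is an assumption.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_linear_probe(core, loader, num_classes, slot_dim, epochs=10, lr=1e-3):
    # Freeze the core (encoder + slot attention) and train only a linear classifier
    # on top of the extracted slot vectors; its accuracy measures representation quality.
    probe = nn.Linear(slot_dim, num_classes)
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        for images, labels in loader:
            with torch.no_grad():
                feats = core(images).mean(dim=1)  # pool slots (B, K, D) -> (B, D)
            loss = F.cross_entropy(probe(feats), labels)
            opt.zero_grad(); loss.backward(); opt.step()
    return probe
```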

One approach is the feature prediction from slot representations presented in the paper from my last comment (https://arxiv.org/abs/2107.00637). Feature prediction, if successful, also lets us confirm that each slot captures a different object in its representation. Another option is to train the core with some SSL method, then freeze it and train a decoder on top of it. Training a decoder would also give us segmentation masks for each slot.
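A hedged sketch of the slot-wise feature/property prediction idea, with Hungarian matching between per-slot predictions and ground-truth objects (the MLP head, the MSE matching cost, and the tensor shapes are assumptions for illustration, not the exact protocol of the cited paper):

```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

def property_prediction_loss(slots, gt_properties, head):
    # slots: (B, K, D) from the frozen core; gt_properties: (B, M, P) ground-truth
    # object properties; head: small MLP mapping each slot to a property vector (P,).
    preds = head(slots)                                   # (B, K, P)
    total = 0.0
    for b in range(slots.shape[0]):
        # Pairwise MSE cost between every slot prediction and every ground-truth object.
        cost = ((preds[b][:, None] - gt_properties[b][None]) ** 2).mean(-1)  # (K, M)
        rows, cols = linear_sum_assignment(cost.detach().cpu().numpy())
        total = total + F.mse_loss(preds[b][rows], gt_properties[b][cols])
    return total / slots.shape[0]
```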

How should we differentiate slot vectors from SSL feature vectors? The slot attention paper distinguishes its slot vectors by demonstrating objectness in the reconstructions; I am not sure whether we can use the same task. This is not a big problem for now, but I think it is a good idea to set expectations for the trained slot vectors.

I think training a slot-based model with SSL has essentially the same value proposition as slot attention itself: it offers a way of generating a separate representation for each object in the image. Training with SSL removes the ability to generate object masks, but we can always train an additional decoder as suggested above, or we can just decide that masks are not really necessary for whatever downstream task the model will be used for.