amansouri3476 / OC-CRL

Object-Centric Causal Representation Learning

Slots suffer from a different training objective #1

Closed: franciscocms closed this issue 7 months ago

franciscocms commented 7 months ago

Hi there! I am implementing SA-MESH on another synthetic dataset, but I'm running into difficulties: the slots converge in a way that every slot binds to all objects (though to different regions of them). My training loss is a sum of the log-probabilities of the true latent-variable values under a posterior distribution whose parameters are predicted by the model. Is this behavior expected? Thank you in advance!
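To make the setup concrete, here is a minimal sketch of the kind of loss described above, assuming a diagonal Gaussian posterior per slot and that slot i is always scored against object i. The helper names (`gaussian_log_prob`, `slot_log_prob`, `fixed_order_nll`) and the Gaussian choice are illustrative assumptions, not code from the OC-CRL repository.

```python
import math


def gaussian_log_prob(x, mean, std):
    """Log-density of a univariate Gaussian N(mean, std^2) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def slot_log_prob(true_latents, mean, std):
    """Sum of per-dimension log-probs for one slot/object pair (diagonal posterior)."""
    return sum(gaussian_log_prob(z, m, s) for z, m, s in zip(true_latents, mean, std))


def fixed_order_nll(true_latents, means, stds):
    """Negative log-likelihood with a hard-coded assignment: slot i <-> object i.

    Hypothetical illustration: the loss is NOT invariant to permuting the slots.
    """
    return -sum(
        slot_log_prob(z, m, s) for z, m, s in zip(true_latents, means, stds)
    )


# Two objects with 2-D latents; unit standard deviations for every slot.
true = [[0.0, 0.0], [3.0, 3.0]]
stds = [[1.0, 1.0], [1.0, 1.0]]
in_order = fixed_order_nll(true, [[0.0, 0.0], [3.0, 3.0]], stds)
swapped = fixed_order_nll(true, [[3.0, 3.0], [0.0, 0.0]], stds)
```

Because the slots are exchangeable, a model whose slots describe the right objects in the "wrong" order (`swapped`) is still penalized heavily here, which is the kind of pressure that can push every slot to cover all objects.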

amansouri3476 commented 7 months ago

Hi!

I'm not sure I quite understand. Have you used our code, or is there a reason you can't use it? It closely follows Locatello et al.'s Slot Attention implementation, with MESH regularization from Zhang et al. I've never seen this behaviour when experimenting with 2D and 3D datasets.

franciscocms commented 7 months ago

Hey @amansouri3476, the issue was the permutation symmetry of the loss over the predicted set of posterior distributions of the latent variables. I was forcing a fixed slot-object binding; even though it was consistent throughout training, fixing the slot assignment always ended with the slots spread across all objects, as if stuck in a local minimum. Allowing an arbitrary slot-object binding, chosen by running the Hungarian algorithm over the loss, was enough to make it work nicely! Thank you for your reply!
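A minimal sketch of the permutation-invariant fix, assuming the same diagonal Gaussian posterior per slot: build a cost matrix of per-pair negative log-likelihoods and take the cheapest one-to-one slot-object assignment. For readability this sketch finds the optimal assignment by brute force over all permutations, which is only practical for a handful of slots; `scipy.optimize.linear_sum_assignment(cost)` solves the same assignment problem in polynomial time. All helper names here are hypothetical, not taken from the OC-CRL code.

```python
import itertools
import math


def gaussian_log_prob(x, mean, std):
    """Log-density of a univariate Gaussian N(mean, std^2) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def slot_log_prob(true_latents, mean, std):
    """Sum of per-dimension log-probs for one slot/object pair (diagonal posterior)."""
    return sum(gaussian_log_prob(z, m, s) for z, m, s in zip(true_latents, mean, std))


def matched_nll(true_latents, means, stds):
    """NLL under the best one-to-one slot->object assignment.

    Returns (best_cost, best_perm), where slot i is matched to object best_perm[i].
    Brute force over n! permutations: fine for tiny n; use a Hungarian-style
    solver such as scipy.optimize.linear_sum_assignment for realistic slot counts.
    """
    n = len(true_latents)
    # cost[i][j]: negative log-prob of object j's true latents under slot i's posterior
    cost = [
        [-slot_log_prob(true_latents[j], means[i], stds[i]) for j in range(n)]
        for i in range(n)
    ]
    best_perm, best_cost = None, math.inf
    for perm in itertools.permutations(range(n)):
        total = sum(cost[i][perm[i]] for i in range(n))
        if total < best_cost:
            best_perm, best_cost = perm, total
    return best_cost, best_perm


# Slots predict the right objects, but in the opposite order.
true = [[0.0, 0.0], [3.0, 3.0]]
means = [[3.0, 3.0], [0.0, 0.0]]
stds = [[1.0, 1.0], [1.0, 1.0]]
best_cost, best_perm = matched_nll(true, means, stds)
```

Here the matching assigns slot 0 to object 1 and slot 1 to object 0, so the swapped prediction costs no more than a perfectly ordered one; gradients then flow only through the matched pairs, leaving each slot free to specialize on whichever object it binds to.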