phlippe / CITRIS

Code repository of the paper "CITRIS: Causal Identifiability from Temporal Intervened Sequences" and "iCITRIS: Causal Representation Learning for Instantaneous Temporal Effects"
BSD 3-Clause Clear License

Questions about CITRIS-VAE (w/ autoregressive prior) #3

Closed salemohamedo closed 2 years ago

salemohamedo commented 2 years ago

Hi Phillip,

Thanks for sharing the code - I enjoyed the paper! I'm trying to reproduce the CITRIS-VAE (w/ autoregressive prior) experiments as part of a class project and just had a few questions.

  1. Could you please further explain the intuition behind how the autoregressive transition prior works? My understanding is that it takes samples of z_t+1, z_t, I and tries to predict p(z_t+1| z_t, I), which can be factored over the assignment of latent to causal variables. We sample z_t+1 ~ q(z_t+1|x_t+1) (VAE encoder) and then we use these samples, along with z_t, I to model p(z_t+1| z_t, I) and then try to minimize the KL(q(z_t+1|x_t+1), p(z_t+1| z_t, I)). I think the part that confuses me is that we're using z_t+1 samples from the encoder and then feeding them into the autoregressive prior. I would have expected the transition prior to only take z_t, I as input and try to predict z_t+1. I might be misunderstanding something here, so any clarification would be helpful, thanks!

  2. I'm a bit confused about the KL calculation here, which I assume corresponds to the second term in equation 4 of the paper. From looking at the code it seems as though the transition prior predicts the distribution of all latent variables z_t+1 for each causal variable, however, I understood equation 4 as stating that the transition prior should only predict the distribution of the latent variables that have been assigned to a given causal variable.

  3. What is the correct way to configure CITRISVAE to use the autoregressive prior (based on MADE) and not the normalizing flow prior? I saw that the train_vae.py takes a --autoregressive_prior flag, however, in order to get the code to actually use the autoregressive prior I had to modify the code here and set this to False. I'm wondering if I may be misunderstanding how to properly use the code.

Thanks very much! Omar

phlippe commented 2 years ago

Hi @salemohamedo, thanks for your interest in the paper and code! Let me try to answer your questions below:

  1. You are right that the general intuition of the transition prior is to use $z^t$ and $I^{t+1}$ to predict $z^{t+1}$. In most environments, this is also sufficient. The difference occurs when causal variables are multi-dimensional. For instance, in the Temporal Causal3DIdent dataset, the rotation is modeled in 3D (two angles) as well as the position (three dimensions). These dimensions do not necessarily need to be independent, hence the model needs to allow for dependencies between the latent variables that have been assigned to the same causal variable. Note that the autoregressive prior is not a single autoregressive distribution on all latents, i.e. $p(z^{t+1}|z^t,I^{t+1})\neq\prod_i p(z_i^{t+1}|z_{:i-1}^{t+1},z^t,I^{t+1})$, but instead is autoregressive only over the latent variables that have been assigned to the same causal variable, i.e. $p(z_{\psi_j}^{t+1}|z^t,I^{t+1})=\prod_{i\in\psi_j} p(z_i^{t+1}|z_{\psi_j\setminus i,...}^{t+1},z^t,I^{t+1})$. So, intuitively speaking, we allow for dependencies between latent variables within the same causal variable. One example where we found the model to use this extensively is when modeling angles. Since angles are circular values (i.e. $0$ and $2\pi$ are the same), it is difficult to model a distribution over them adequately in a single dimension. Instead, the model learns latents similar to $\sin(\phi)$ and $\sin(\phi+\epsilon)$, i.e. shifted sine waves (at least 2 are needed to reconstruct the angle again). Nonetheless, these two sine waves are highly correlated, and modeling the two dimensions independently would not be sufficient. The autoregressive prior, however, allows for such dependencies.
  2. You are right that in Equation 4, we hard-assign each latent variable to a causal variable. However, in the code, this assignment function $\psi$ also needs to be learned. This can be done in two ways: either using a Gumbel-Softmax approach where we sample an assignment at each iteration, or marginalizing over causal variables, i.e. calculating the loss for every possible assignment and weighting it according to the softmax probabilities of $\psi$. We found the latter to give slightly more stable results for negligible additional computational cost (the heavy part is mostly the encoder and decoder architecture).
  3. The Normalizing Flow prior in the VAE gives the model additional modeling strength, similar to https://arxiv.org/abs/1606.04934. This is orthogonal to the autoregressive prior, which can be applied to both the VAE with NF and the standard VAE. We found the NF to be better overall in the environments we tested, hence it is active by default. If you want to deactivate it, then you are right to set use_flow_prior to False.
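
To make the marginalization in point 2 concrete, here is a minimal NumPy sketch. It is not the repository's implementation: it assumes diagonal Gaussian distributions with a closed-form KL (a simplification), and all names (`gaussian_kl`, `marginalized_kl`, `psi_logits`) are hypothetical. The prior is evaluated for every latent under every candidate causal variable, and the resulting KL terms are weighted by the softmax probabilities of the assignment.

```python
import numpy as np


def gaussian_kl(q_mean, q_logstd, p_mean, p_logstd):
    """Elementwise KL( N(q_mean, q_std^2) || N(p_mean, p_std^2) )."""
    q_var = np.exp(2.0 * q_logstd)
    p_var = np.exp(2.0 * p_logstd)
    return p_logstd - q_logstd + (q_var + (q_mean - p_mean) ** 2) / (2.0 * p_var) - 0.5


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def marginalized_kl(q_mean, q_logstd, p_mean, p_logstd, psi_logits):
    """KL weighted by soft assignment probabilities (hypothetical helper).

    q_mean, q_logstd: shape (K,), encoder posterior per latent dimension.
    p_mean, p_logstd: shape (C, K), prior for every latent under each of the
        C candidate causal variables.
    psi_logits: shape (K, C), unnormalized assignment scores per latent.
    """
    # KL of every latent under every candidate causal variable: shape (C, K).
    kl = gaussian_kl(q_mean[None, :], q_logstd[None, :], p_mean, p_logstd)
    # Assignment probabilities per latent: shape (K, C), rows sum to 1.
    weights = softmax(psi_logits, axis=-1)
    # Expected KL under the assignment distribution, summed over latents.
    return float((weights.T * kl).sum())


# Toy example: 2 latents, 2 causal variables.
q_mean = np.array([0.0, 1.0])
q_logstd = np.zeros(2)
p_mean = np.array([[0.0, 5.0], [3.0, 1.0]])
p_logstd = np.zeros((2, 2))

uniform_loss = marginalized_kl(q_mean, q_logstd, p_mean, p_logstd, np.zeros((2, 2)))
hard_logits = np.array([[50.0, -50.0], [-50.0, 50.0]])
hard_loss = marginalized_kl(q_mean, q_logstd, p_mean, p_logstd, hard_logits)
print(uniform_loss, hard_loss)
```

When the softmax becomes nearly one-hot, the weighted sum reduces to the hard assignment of Equation 4, which is why the prior producing a distribution over all latents for each causal variable is consistent with the paper.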

I hope these answers clarify the intuition behind the VAE prior of CITRIS a bit. Let me know if you have any further questions or if something is unclear!
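
The angle example from point 1 can also be checked numerically. The following is a small illustrative sketch, not taken from the repository: the shift `eps = 0.5` and the uniform distribution over angles are assumptions chosen for the demonstration. It shows that $\sin(\phi)$ and $\sin(\phi+\epsilon)$ are strongly correlated, and that the two together are enough to recover $\phi$.

```python
import numpy as np

eps = 0.5  # hypothetical phase shift between the two latents
rng = np.random.default_rng(0)
phi = rng.uniform(0.0, 2.0 * np.pi, size=10_000)  # assumed uniform angles

# Two latents resembling shifted sine waves of the same angle.
z1 = np.sin(phi)
z2 = np.sin(phi + eps)

# For uniformly distributed angles, the correlation approaches cos(eps).
corr = np.corrcoef(z1, z2)[0, 1]

# Recover the angle using sin(phi + eps) = sin(phi)cos(eps) + cos(phi)sin(eps).
cos_phi = (z2 - z1 * np.cos(eps)) / np.sin(eps)
phi_rec = np.mod(np.arctan2(z1, cos_phi), 2.0 * np.pi)

# Circular difference between the recovered and the true angle.
max_err = np.abs(np.angle(np.exp(1j * (phi_rec - phi)))).max()
print(corr, np.cos(eps), max_err)
```

A prior that factorizes over `z1` and `z2` cannot represent this dependence, whereas conditioning `z2` on `z1` within the same causal variable can.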

Best, Phillip

salemohamedo commented 2 years ago

Hi Phillip,

Thank you very much! I appreciate the quick and thorough responses - that helps clear up my confusion.

Omar