Closed by alitinet 3 years ago.
(Figures not reproduced: Model 1 latent space; Model 2 latent space, shifted; Model 3 latent space.)
Decided to go with: modality encoders -> mean and variance per modality per group -> joint mean and variance per group -> sample from the modality distributions (z_rna1, z_protein1, z_rna2) and from the joint distribution for the paired data (z_joint1) -> feed z_joint1 + v_rna and z_rna2 + v_rna into mod_dec_rna; feed z_protein1 + v_protein and z_joint1 + v_protein into mod_dec_protein.
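The sampling and decoder-input step of this pipeline can be sketched in a few lines of NumPy. This is a minimal sketch, not the project's implementation: the encoder outputs and the joint mean/variance are random placeholders, sampling uses the standard reparameterization trick (z = mean + sqrt(var) * eps), and v_rna / v_protein are fixed random vectors standing in for what would presumably be learned parameters. The helper name sample and all sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 4   # hypothetical latent size
n1, n2 = 3, 5    # hypothetical number of cells in pair1 and pair2


def sample(mean, var, rng):
    """Reparameterized sample z = mean + sqrt(var) * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mean.shape)
    return mean + np.sqrt(var) * eps


# Placeholder encoder outputs: mean and variance per modality per group.
mean_rna1, var_rna1 = rng.standard_normal((n1, latent_dim)), np.ones((n1, latent_dim))
mean_protein1, var_protein1 = rng.standard_normal((n1, latent_dim)), np.ones((n1, latent_dim))
mean_rna2, var_rna2 = rng.standard_normal((n2, latent_dim)), np.ones((n2, latent_dim))
# Placeholder joint distribution for the paired group (pair1).
mean_joint1, var_joint1 = rng.standard_normal((n1, latent_dim)), 0.5 * np.ones((n1, latent_dim))

# z_rna1 is sampled too, but in this pipeline it is not fed to a decoder.
z_rna1 = sample(mean_rna1, var_rna1, rng)
z_protein1 = sample(mean_protein1, var_protein1, rng)
z_rna2 = sample(mean_rna2, var_rna2, rng)
z_joint1 = sample(mean_joint1, var_joint1, rng)

# Modality vectors (assumed learnable in a real model; fixed random here).
v_rna = rng.standard_normal(latent_dim)
v_protein = rng.standard_normal(latent_dim)

# Decoder inputs; each modality vector broadcasts over the batch dimension.
rna_decoder_inputs = [z_joint1 + v_rna, z_rna2 + v_rna]                  # -> mod_dec_rna
protein_decoder_inputs = [z_protein1 + v_protein, z_joint1 + v_protein]  # -> mod_dec_protein
```

Because v_rna and v_protein have shape (latent_dim,), NumPy broadcasting adds the same modality vector to every cell in the batch.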
For the integration loss, calculate MMD(z_joint1, z_rna2).
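One common way to compute this MMD term is a Gaussian-kernel estimate. The issue does not specify the kernel, bandwidth or estimator, so the sketch below assumes an RBF kernel with a single hypothetical bandwidth and the biased (V-statistic) estimator of squared MMD; z_rna2_aligned and z_rna2_shifted are synthetic stand-ins for z_rna2.

```python
import numpy as np


def gaussian_kernel(x, y, bandwidth=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2)) for all pairs of rows."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))


def mmd(x, y, bandwidth=1.0):
    """Biased estimate of squared MMD between samples x (n, d) and y (m, d)."""
    k_xx = gaussian_kernel(x, x, bandwidth).mean()
    k_yy = gaussian_kernel(y, y, bandwidth).mean()
    k_xy = gaussian_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy


rng = np.random.default_rng(0)
z_joint1 = rng.standard_normal((100, 4))
z_rna2_aligned = rng.standard_normal((120, 4))        # same distribution as z_joint1
z_rna2_shifted = rng.standard_normal((120, 4)) + 3.0  # batch-shifted distribution

print(mmd(z_joint1, z_rna2_aligned))  # should be close to 0
print(mmd(z_joint1, z_rna2_shifted))  # should be noticeably larger
```

Minimizing this term pushes the unpaired RNA embeddings toward the joint embeddings of the paired data.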
Add product of experts models (see http://proceedings.mlr.press/v130/lee21a/lee21a.pdf).
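For diagonal Gaussian experts, the product of experts has a closed form: the joint precision is the sum of the expert precisions, and the joint mean is the precision-weighted average of the expert means. A minimal NumPy sketch of this combination for the paired group (pair1) follows; whether to include a standard-normal prior expert is an assumption exposed through the hypothetical include_prior flag, and product_of_experts is an illustrative name, not a function from the linked paper.

```python
import numpy as np


def product_of_experts(means, variances, include_prior=True):
    """Combine diagonal Gaussian experts N(mean_i, var_i) into one Gaussian.

    The product of Gaussian densities is (up to normalization) Gaussian, with
    precision equal to the sum of expert precisions and a precision-weighted mean.
    """
    means = list(means)          # copy so the caller's lists are not modified
    variances = list(variances)
    if include_prior:
        # Optional standard-normal prior expert N(0, I).
        means.append(np.zeros_like(means[0]))
        variances.append(np.ones_like(variances[0]))
    precisions = [1.0 / v for v in variances]
    joint_var = 1.0 / np.sum(precisions, axis=0)
    joint_mean = joint_var * np.sum([m * p for m, p in zip(means, precisions)], axis=0)
    return joint_mean, joint_var


# Toy RNA and protein experts for a single paired cell with a 2-d latent space.
mean_rna1, var_rna1 = np.array([[1.0, -2.0]]), np.array([[1.0, 4.0]])
mean_protein1, var_protein1 = np.array([[3.0, 0.0]]), np.array([[1.0, 4.0]])

mean_joint1, var_joint1 = product_of_experts(
    [mean_rna1, mean_protein1], [var_rna1, var_protein1], include_prior=False
)
# mean_joint1 -> [[2.0, -1.0]], var_joint1 -> [[0.5, 2.0]]
```

With include_prior=True the extra N(0, I) expert pulls the joint mean toward zero and shrinks the variance; for the same toy inputs the result is mean [[4/3, -1/3]] and variance [[1/3, 2/3]].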
Compare the architectures on the following example (same as in #38): two datasets, one paired CITE-seq dataset (pair1: x_rna1, x_protein1) and one RNA-only dataset (pair2: x_rna2). RNA is mod1 and protein is mod2. The modality encoders output z_rna1 and z_protein1 for pair1 and z_rna2 for pair2. The modality vectors are denoted v_rna and v_protein.
[x] Model 1: modality encoders -> mean and variance per modality per group -> joint mean and variance per group -> sample from the modality distributions (z_rna1, z_protein1, z_rna2) and from the joint distribution for the paired data (z_joint1) -> feed z_rna1, z_joint1 and z_rna2 into mod_dec_rna; feed z_protein1 and z_joint1 into mod_dec_protein -> compute the reconstruction loss as the sum recon(mod_dec_rna(z_rna1), x_rna1) + recon(mod_dec_rna(z_joint1), x_rna1) + recon(mod_dec_rna(z_rna2), x_rna2) + recon(mod_dec_protein(z_protein1), x_protein1) + recon(mod_dec_protein(z_joint1), x_protein1). For the integration loss, calculate MMD(z_joint1, z_rna2).
[x] Model 2: same as Model 1 to get z_rna1, z_protein1, z_rna2 and z_joint1. Then use vector arithmetic: feed z_joint1 - v_protein and z_rna2 into mod_dec_rna, and feed z_joint1 - v_rna into mod_dec_protein. Compute the reconstruction loss as recon(mod_dec_rna(z_joint1 - v_protein), x_rna1) + recon(mod_dec_rna(z_rna2), x_rna2) + recon(mod_dec_protein(z_joint1 - v_rna), x_protein1). MMD integration as above.
[x] Model 3: use modality labels instead of vector arithmetic. Same as above to get z_rna1, z_protein1, z_rna2 and z_joint1. Then concatenate z_joint1 and z_rna2 with modality labels, where (1, 0) marks RNA and (0, 1) marks protein as the target modality, and feed (z_joint1, 1, 0), (z_rna2, 1, 0) and (z_joint1, 0, 1) into a shared decoder; then feed the first two outputs into mod_dec_rna and the last into mod_dec_protein. Reconstruction and integration losses as above.
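The three set-ups differ only in how the target modality is signalled to the decoders. The sketch below contrasts them under loudly stated assumptions: the decoders are hypothetical random linear maps (linear), recon is a placeholder MSE because the issue does not say which likelihood recon uses, the shared decoder's output size is arbitrary, and all data and latent samples are random. Model 2 reuses Model 1's toy decoders for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n1, n2 = 4, 3, 5        # hypothetical latent size and group sizes
g_rna, g_protein = 6, 2    # hypothetical feature counts

# Toy data and latent samples (placeholders for encoder / joint outputs).
x_rna1, x_protein1 = rng.standard_normal((n1, g_rna)), rng.standard_normal((n1, g_protein))
x_rna2 = rng.standard_normal((n2, g_rna))
z_rna1, z_protein1, z_joint1 = (rng.standard_normal((n1, d)) for _ in range(3))
z_rna2 = rng.standard_normal((n2, d))
v_rna, v_protein = rng.standard_normal(d), rng.standard_normal(d)


def linear(n_in, n_out):
    """Hypothetical stand-in for a decoder network: a random linear map."""
    w = rng.standard_normal((n_in, n_out))
    return lambda z: z @ w


def recon(x_hat, x):
    """Placeholder reconstruction loss (MSE)."""
    return float(np.mean((x_hat - x) ** 2))


# Model 1: modality decoders, no modality conditioning.
mod_dec_rna, mod_dec_protein = linear(d, g_rna), linear(d, g_protein)
loss_m1 = (recon(mod_dec_rna(z_rna1), x_rna1) + recon(mod_dec_rna(z_joint1), x_rna1)
           + recon(mod_dec_rna(z_rna2), x_rna2)
           + recon(mod_dec_protein(z_protein1), x_protein1)
           + recon(mod_dec_protein(z_joint1), x_protein1))

# Model 2: vector arithmetic, subtracting the vector of the modality not decoded.
loss_m2 = (recon(mod_dec_rna(z_joint1 - v_protein), x_rna1)
           + recon(mod_dec_rna(z_rna2), x_rna2)
           + recon(mod_dec_protein(z_joint1 - v_rna), x_protein1))


# Model 3: append a one-hot modality label, then pass through a shared decoder.
def with_label(z, label):
    return np.hstack([z, np.tile(label, (z.shape[0], 1))])


rna_label, protein_label = np.array([1.0, 0.0]), np.array([0.0, 1.0])
shared_dec = linear(d + 2, d)  # hypothetical output size for the shared decoder
loss_m3 = (recon(mod_dec_rna(shared_dec(with_label(z_joint1, rna_label))), x_rna1)
           + recon(mod_dec_rna(shared_dec(with_label(z_rna2, rna_label))), x_rna2)
           + recon(mod_dec_protein(shared_dec(with_label(z_joint1, protein_label))), x_protein1))
```

In all three models the MMD integration term between z_joint1 and z_rna2 would be added to these reconstruction losses.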