theislab / multimil

Multimodal weakly supervised learning to identify disease-specific changes in single-cell atlases
https://multimil.rtfd.io/
BSD 3-Clause "New" or "Revised" License

Add PoE models #39

Closed: alitinet closed this issue 3 years ago

alitinet commented 3 years ago

Add product-of-experts (PoE) models (see http://proceedings.mlr.press/v130/lee21a/lee21a.pdf).

Compare the architectures on the following example (the same as in #38): two datasets, one paired CITE-seq (pair1: x_rna1, x_protein1) and one RNA-only (pair2: x_rna2). RNA is mod1 and protein is mod2. The modality encoders output z_rna1 and z_protein1 for pair1 and z_rna2 for pair2. The modality vectors are denoted v_rna and v_protein.
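For Gaussian experts, the product of experts has a closed form. Per latent dimension, if expert m has mean μ_m and variance σ_m², the normalized product of the expert densities is again Gaussian, with precisions adding and the mean being the precision-weighted average of the expert means. For pair1 the sum runs over m ∈ {rna, protein}; for pair2 only the RNA expert is present, so the "joint" equals the RNA distribution. Whether a standard-normal prior expert is also included (as in MVAE-style models, which adds 1 to the precision sum) is not stated in this issue; the formula below leaves it out.

```latex
\sigma_{\text{joint}}^{2} = \Big( \sum_{m} \sigma_{m}^{-2} \Big)^{-1},
\qquad
\mu_{\text{joint}} = \sigma_{\text{joint}}^{2} \sum_{m} \sigma_{m}^{-2} \, \mu_{m}
```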

alitinet commented 3 years ago

Model 1: [figures: latent space, joint_poe]

alitinet commented 3 years ago

Model 2: [figures: latent space, joint_poe_mv, shifted, corrected_poe_mv]

alitinet commented 3 years ago

Model 3: [figures: latent, Screenshot 2021-05-03 at 13 15 20]

alitinet commented 3 years ago

Decided to go for:
1. Modality encoders output a mean and variance per modality per group.
2. Calculate the joint mean and variance per group.
3. Sample from the modality distributions (z_rna1, z_protein1, z_rna2) and, for the paired data, from the joint distribution (z_joint1).
4. Feed z_joint1 + v_rna and z_rna2 + v_rna into mod_dec_rna; feed z_protein1 + v_protein and z_joint1 + v_protein into mod_dec_protein.
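The chosen design can be traced end to end with a toy NumPy sketch. This is not the multimil implementation: the encoders and decoders are hypothetical random linear maps (`make_linear`, `make_encoder`), all sizes are made up, one encoder per modality is assumed to be shared across groups, the PoE omits a prior expert, and "z + v" is read as concatenating the modality vector onto the latent sample (it could equally mean element-wise addition).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes; the issue does not specify any dimensions.
N_CELLS, N_GENES, N_PROTEINS, LATENT, V_DIM = 8, 50, 10, 4, 2


def make_linear(in_dim, out_dim):
    """Hypothetical stand-in for a trained network: a fixed random linear map."""
    w = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda x: x @ w


def make_encoder(in_dim, latent):
    """Hypothetical modality encoder: returns a mean and a variance per cell."""
    to_mu, to_logvar = make_linear(in_dim, latent), make_linear(in_dim, latent)
    return lambda x: (to_mu(x), np.exp(to_logvar(x)))


def poe(mus, variances):
    """Gaussian product of experts (no prior expert): precisions add,
    the joint mean is the precision-weighted average of the expert means."""
    precisions = [1.0 / v for v in variances]
    joint_var = 1.0 / sum(precisions)
    joint_mu = joint_var * sum(m * p for m, p in zip(mus, precisions))
    return joint_mu, joint_var


def sample(mu, var):
    """Reparameterised Gaussian sample: mu + sqrt(var) * eps, eps ~ N(0, I)."""
    return mu + np.sqrt(var) * rng.standard_normal(mu.shape)


def with_vector(z, v):
    """Append modality vector v to every row of z ('+' read as concatenation)."""
    return np.concatenate([z, np.tile(v, (z.shape[0], 1))], axis=1)


# Toy data: pair1 is paired CITE-seq (RNA + protein), pair2 is RNA only.
x_rna1 = rng.normal(size=(N_CELLS, N_GENES))
x_protein1 = rng.normal(size=(N_CELLS, N_PROTEINS))
x_rna2 = rng.normal(size=(N_CELLS, N_GENES))

# Assumption: one encoder per modality, shared across groups.
enc_rna = make_encoder(N_GENES, LATENT)
enc_protein = make_encoder(N_PROTEINS, LATENT)
mod_dec_rna = make_linear(LATENT + V_DIM, N_GENES)
mod_dec_protein = make_linear(LATENT + V_DIM, N_PROTEINS)
v_rna, v_protein = rng.normal(size=V_DIM), rng.normal(size=V_DIM)  # modality vectors

# 1) Mean and variance per modality per group.
mu_rna1, var_rna1 = enc_rna(x_rna1)
mu_protein1, var_protein1 = enc_protein(x_protein1)
mu_rna2, var_rna2 = enc_rna(x_rna2)

# 2) Joint mean and variance; pair2 has one modality, so its joint is just the RNA expert.
mu_joint1, var_joint1 = poe([mu_rna1, mu_protein1], [var_rna1, var_protein1])

# 3) Samples from the modality distributions and the paired joint distribution.
z_rna1 = sample(mu_rna1, var_rna1)  # sampled, but not in the decoder inputs listed above
z_protein1 = sample(mu_protein1, var_protein1)
z_rna2 = sample(mu_rna2, var_rna2)
z_joint1 = sample(mu_joint1, var_joint1)

# 4) Decoder inputs as listed in the issue comment.
rna_from_joint1 = mod_dec_rna(with_vector(z_joint1, v_rna))
rna_from_rna2 = mod_dec_rna(with_vector(z_rna2, v_rna))
protein_from_protein1 = mod_dec_protein(with_vector(z_protein1, v_protein))
protein_from_joint1 = mod_dec_protein(with_vector(z_joint1, v_protein))
```

Because precisions add, the joint variance for pair1 is always smaller than either single-modality variance, which is the usual motivation for using PoE on paired cells.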

For the integration loss, calculate MMD(z_joint1, z_rna2).
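The issue does not say which kernel or estimator the MMD uses. A minimal sketch, assuming a single Gaussian (RBF) kernel with a hypothetical bandwidth parameter `gamma` and the biased (V-statistic) estimate of squared MMD, with the two latent samples passed as arrays of shape (cells, latent_dim):

```python
import numpy as np


def rbf_kernel(a, b, gamma):
    """Gaussian kernel matrix k(a_i, b_j) = exp(-gamma * ||a_i - b_j||^2)."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point cancellation.
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def mmd(x, y, gamma=1.0):
    """Biased estimate of squared MMD between samples x and y."""
    k_xx = rbf_kernel(x, x, gamma).mean()
    k_yy = rbf_kernel(y, y, gamma).mean()
    k_xy = rbf_kernel(x, y, gamma).mean()
    return k_xx + k_yy - 2.0 * k_xy


# Usage with toy latent samples standing in for z_joint1 and z_rna2.
rng = np.random.default_rng(0)
z_joint1 = rng.normal(size=(100, 4))
z_rna2_aligned = rng.normal(size=(100, 4))
z_rna2_shifted = rng.normal(loc=3.0, size=(100, 4))
print(mmd(z_joint1, z_rna2_aligned) < mmd(z_joint1, z_rna2_shifted))
```

Minimizing this term pulls the RNA-only latent samples towards the joint latent samples of the paired data. A single `gamma` can be poorly matched to the latent scale, so a sum of RBF kernels over several bandwidths is a common alternative.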