Closed by alitinet 3 years ago
Old model:
Notes:
New model 1:
New model 2:
Notes:
Hey @M0hammadL, here are some results with the new models. I am not sure why the modality vectors converge to v2 = 1/2 v1 in the first model (old model) and the third model (new model 2), but I think it's because of the way MMD is calculated in the void space.
In new model 1, where the assumption is that z2 incorporates both modalities even though it comes from RNA only, I don't like that z2 is fed directly into mod_dec1. But if we feed e.g. z2 - v2 into mod_dec1 instead, the situation becomes weird when we have 3 modalities, e.g. CITE-seq and ATAC. Then we'd have to assume that data coming from CITE-seq actually has all 3 modalities and subtract e.g. v_atac from z1, which I don't think makes much sense.
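To make the 3-modality concern concrete, here is a minimal sketch of the decoder-input rule that "the latent carries all modalities" would imply: subtract every modality vector except the target's. The function name `decoder_input`, the dict layout, and the toy numbers are all hypothetical and not taken from the actual code.

```python
# Hypothetical sketch: vectors are plain lists of floats, names are made up.

def decoder_input(z, modality_vectors, target):
    """Return z minus the vectors of every modality except `target`."""
    out = list(z)
    for name, v in modality_vectors.items():
        if name != target:
            out = [o - x for o, x in zip(out, v)]
    return out

# Toy modality vectors (illustrative values only).
vectors = {"rna": [1.0, 0.0], "protein": [0.0, 1.0], "atac": [2.0, 2.0]}

# Latent of a CITE-seq cell: RNA and protein were measured, ATAC was not.
z1 = [5.0, 5.0]

# Input to the RNA decoder under this rule: z1 - v_protein - v_atac.
# Note that v_atac is subtracted although ATAC was never measured here.
rna_in = decoder_input(z1, vectors, "rna")
print(rna_in)
```

With two modalities this reduces to feeding z1 - v2 into mod_dec1, but with a third modality the rule has to remove a vector for data that was never observed, which is the part that feels off.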
Decided to go for a PoE architecture with vector arithmetic (#39).
Compare architectures on the following example: 2 datasets, one paired CITE-seq (pair1), one just RNA (pair2). RNA is mod1, protein is mod2. The output of the shared encoder for pair1 is z1, and the output for pair2 is z2. Modality vectors are denoted by v1, v2.
[x] old model: modality encoders -> shared encoders with zero masking -> vector arithmetic -> shared decoder -> modality decoders (here vector arithmetic is relevant only for MMD, not for the reconstruction, so the integration loss is calculated as mmd(z1-v1-v2, z2-v1))
[x] new model 1: modality encoders -> shared encoders with zero masking -> vector arithmetic (mmd(z1, z2)) -> modality decoders (z1-v2 is fed into mod_dec1, z1-v1 is fed into mod_dec2, z2 is fed directly into mod_dec1)
[x] new model 2: modality encoders -> shared encoders with zero masking -> vector arithmetic (mmd(z1-v1-v2, z2-v1)) -> modality decoders (z1-v2 is fed into mod_dec1, z1-v1 is fed into mod_dec2, z2 is fed directly into mod_dec1)
[x] new model 3: same as new model 1, but z2-v2 is fed into mod_dec1, i.e. we assume the latent captures information from all modalities.
[x] for each model, visualize output of modality encoders, shared encoder and shared encoder + vector arithmetic.
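
A rough sketch of how the two integration-loss variants in the list differ: mmd(z1-v1-v2, z2-v1) for the old model and new model 2, versus mmd(z1, z2) for new model 1. The RBF kernel, the biased estimator, the helper names `rbf`, `mmd2`, `sub`, and the toy latents are all assumptions for illustration; the issue does not say which MMD estimator is used. The toy latents are constructed so that the vector arithmetic aligns them exactly, so this is not a result, only a demonstration of what each loss compares.

```python
import math

def rbf(x, y, gamma=1.0):
    """RBF kernel between two vectors (kernel choice is an assumption)."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-gamma * sq)

def mmd2(xs, ys, gamma=1.0):
    """Biased estimate of squared MMD between two sets of vectors."""
    def mean_k(a, b):
        return sum(rbf(p, q, gamma) for p in a for q in b) / (len(a) * len(b))
    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)

def sub(z, *vs):
    """Subtract each vector in vs from z, elementwise."""
    out = list(z)
    for v in vs:
        out = [o - x for o, x in zip(out, v)]
    return out

# Toy modality vectors: v1 for RNA (mod1), v2 for protein (mod2).
v1, v2 = [1.0, 0.0], [0.0, 1.0]
base = [[0.0, 0.0], [0.5, -0.5], [-0.5, 0.5]]

# Toy shared-encoder outputs, built by construction:
# paired CITE-seq cells = base + v1 + v2, RNA-only cells = base + v1.
z1 = [[b[0] + 1.0, b[1] + 1.0] for b in base]
z2 = [[b[0] + 1.0, b[1]] for b in base]

# Old model / new model 2: compare latents after removing modality vectors.
loss_old = mmd2([sub(z, v1, v2) for z in z1], [sub(z, v1) for z in z2])

# New model 1: compare the latents directly.
loss_new1 = mmd2(z1, z2)

print(loss_old, loss_new1)
```

In this toy setup the first loss vanishes because both sets collapse onto the same base points, while the direct comparison stays positive because z1 and z2 are offset by v2.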