broadinstitute / CellCap

Interpret perturbation responses from scRNA-seq perturbation experiments
BSD 3-Clause "New" or "Revised" License

Adversarial classification task as multi-objective optimization #16

Open sjfleming opened 1 year ago

sjfleming commented 1 year ago

What we are doing here with the adversarial classifier is really a multi-objective optimization problem.

How well do we want to do on the reconstruction task? How well do we want to do at stripping away perturbation information from z_basal?

We could write the loss function as

$$ \mathcal{L} = \mathcal{L}_r + \lambda \mathcal{L}_c $$

where

$\mathcal{L}_r$ is the reconstruction loss and $\mathcal{L}_c$ is the classification loss.

These two terms in the loss function are fundamentally separate and competing tasks.
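As a baseline, the weighted-sum objective above is a one-liner. A minimal sketch (the function and argument names are placeholders, not anything in the CellCap codebase):

```python
def combined_loss(recon_loss: float, class_loss: float, lam: float = 1.0) -> float:
    """Weighted sum L = L_r + lambda * L_c of the two competing objectives.

    lam trades off reconstruction quality against how aggressively
    perturbation information is stripped from z_basal.
    """
    return recon_loss + lam * class_loss
```

The whole difficulty is that no single fixed `lam` need exist that achieves a good trade-off at every point during training, which motivates the alternatives below.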

There are probably many ways to handle this. Let's explore them.

Methods:

  1. See what the Theis group did in their paper...
  2. Use the above combined loss function, try to tune the hyperparameter $\lambda$, and hope for the best
  3. Set an explicit target for the classification task, and train in a way that meets that threshold at each "full" gradient update step. I think Luca knew of some framework for doing this. Would be nice to see the paper. I believe the idea was similar to (1), but $\lambda$ becomes a learnable parameter whose value is zero when the classification target is met and quickly grows large when the threshold is crossed. Again, need to find that paper...
  4. Use a fancy approach from the literature. The approach here (https://proceedings.neurips.cc/paper/2018/file/432aca3a1e345e339f35a30c8f65edce-Paper.pdf) seems nice in that it claims to guarantee a Pareto-optimal solution, and is not (much?) more computationally expensive than what we are already doing. See "Algorithm 2", and the code here (https://github.com/isl-org/MultiObjectiveOptimization/blob/d45eb262ec61c0dafecebfb69027ff6de280dbb3/multi_task/train_multi_task.py#L113-L184). This could be implemented as a TrainingPlan within the context of scvi-tools. But it is complicated, and it would take some finesse. Oh, actually, it might already be implemented here! (https://github.com/hav4ik/Hydra/blob/master/src/applications/trainers/mgda.py). I just don't know if it's correct.
  5. Another literature approach could be GradNorm (https://arxiv.org/abs/1711.02257), but it doesn't look as comprehensive as (3), and it doesn't seem to provide the same kinds of guarantees. Actually, I think GradNorm can be used in conjunction with (3).
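For reference, in our two-task case the min-norm subproblem at the heart of "Algorithm 2" from the Sener & Koltun paper linked in (4) has a closed-form solution: find $\gamma \in [0, 1]$ minimizing $\|\gamma g_1 + (1-\gamma) g_2\|^2$, where $g_1, g_2$ are the per-task gradients. A minimal sketch (the function name is mine, and gradients are plain lists here; a real TrainingPlan would use flattened parameter gradients):

```python
def min_norm_coeff(g1, g2):
    """Closed-form min-norm weight for two task gradients (MGDA, two-task case).

    Returns gamma in [0, 1] such that gamma * g1 + (1 - gamma) * g2 has
    minimum norm. That combination is a common descent direction for both
    tasks; if its norm is zero, the current point is Pareto-stationary.
    """
    # gamma* = ((g2 - g1) . g2) / ||g1 - g2||^2, clipped to [0, 1]
    diff = [a - b for a, b in zip(g1, g2)]
    denom = sum(d * d for d in diff)
    if denom == 0.0:
        return 0.5  # gradients are identical; any convex weight works
    num = sum((b - a) * b for a, b in zip(g1, g2))
    return min(1.0, max(0.0, num / denom))
```

The per-step cost is one extra backward pass per task plus a dot product, which supports the claim in (4) that this is not much more expensive than what we are already doing.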
ImXman commented 7 months ago

I have added adversarial loss for covariates. Currently, it only supports one extra covariate.

sjfleming commented 7 months ago

I think your comment above is about #34 rather than this issue. #34 is about changing the loss function by adding adversarial losses for each covariate. This issue is about changing the optimizer itself.