Open sjfleming opened 1 year ago
I have added adversarial loss for covariates. Currently, it only supports one extra covariate.
I think your comment above is about #34 rather than this issue. #34 is about changing the loss function by adding adversarial losses for each covariate. This issue is about changing the optimizer itself.
What we are doing here with the adversarial classifier is really a multi-objective optimization problem.
How well do we want to do on the reconstruction task? How well do we want to do at stripping away perturbation information from `z_basal`?

We could write the loss function as
$$ \mathcal{L} = \mathcal{L}_r + \lambda \mathcal{L}_c $$
where $\mathcal{L}_r$ is the reconstruction loss and $\mathcal{L}_c$ is the classification loss.
These two terms in the loss function are fundamentally separate and competing tasks.
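
To make the competition concrete, here is a toy sketch (not the model's actual losses). I'm assuming two stand-in quadratic objectives on a single scalar `z`: a hypothetical `loss_r` that wants `z = 1` and a hypothetical `loss_c` that wants `z = -1`. Minimizing the weighted sum for different values of $\lambda$ lands on different compromise points, so the choice of $\lambda$ quietly decides which task wins.

```python
# Toy illustration only: both losses are made-up quadratics, not the real
# reconstruction or classification losses.

def loss_r(z):
    """Stand-in for the reconstruction loss; minimized at z = 1."""
    return (z - 1.0) ** 2


def loss_c(z):
    """Stand-in for the classification loss; minimized at z = -1."""
    return (z + 1.0) ** 2


def minimize_weighted_sum(lam, lr=0.05, steps=2000):
    """Plain gradient descent on loss_r(z) + lam * loss_c(z), starting at z = 0."""
    z = 0.0
    for _ in range(steps):
        grad = 2.0 * (z - 1.0) + lam * 2.0 * (z + 1.0)
        z -= lr * grad
    return z


if __name__ == "__main__":
    for lam in (0.0, 0.5, 1.0, 4.0):
        z = minimize_weighted_sum(lam)
        print(f"lambda={lam}: z={z:.4f}, L_r={loss_r(z):.4f}, L_c={loss_c(z):.4f}")
```

For these two quadratics the minimizer works out to $z^* = (1 - \lambda) / (1 + \lambda)$, so sweeping $\lambda$ from 0 upward moves the solution from the optimum of one task toward the optimum of the other. Neither task is ever fully satisfied unless the other is ignored.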
There are probably many ways to handle this. Let's explore them.
Methods:
`TrainingPlan` within the context of scvi-tools. But it is complicated, and it would take some finesse. Actually, it might already be implemented here: https://github.com/hav4ik/Hydra/blob/master/src/applications/trainers/mgda.py. I just don't know if it's correct.
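
The file name in that link (`mgda.py`) suggests the multiple-gradient descent algorithm (MGDA). Assuming that is the method in question, here is a minimal sketch of its two-objective special case, where the minimum-norm convex combination of the two task gradients has a closed form. The helper names (`min_norm_weight`, `common_direction`) are hypothetical, and gradients are plain Python lists rather than model parameters. This is not the linked repository's code and not an scvi-tools API.

```python
# Sketch of the two-task min-norm step, under the assumption that MGDA is the
# method being discussed. All names here are illustrative.

def dot(a, b):
    """Dot product of two equal-length lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def min_norm_weight(g1, g2):
    """Return alpha in [0, 1] minimizing ||alpha * g1 + (1 - alpha) * g2||^2."""
    diff = [a - b for a, b in zip(g1, g2)]
    denom = dot(diff, diff)
    if denom == 0.0:
        # Identical gradients: every alpha gives the same vector.
        return 0.5
    # Unconstrained minimizer: alpha = g2 . (g2 - g1) / ||g1 - g2||^2
    alpha = dot(g2, [b - a for a, b in zip(g1, g2)]) / denom
    # Clip to the valid range for a convex combination.
    return min(1.0, max(0.0, alpha))


def common_direction(g1, g2):
    """Min-norm convex combination of the two task gradients."""
    alpha = min_norm_weight(g1, g2)
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(g1, g2)]


if __name__ == "__main__":
    print(common_direction([1.0, 0.0], [0.0, 1.0]))   # orthogonal gradients
    print(common_direction([1.0, 0.0], [-1.0, 0.0]))  # directly opposed gradients
```

When the combined vector is nonzero, stepping against it decreases both losses to first order. When it is zero (as with directly opposed gradients), no direction improves both, which is the Pareto-stationary case. Compared with a fixed $\lambda$, the weighting here is recomputed from the gradients at every step.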