The current problem:
If we make a simulated dataset (sim1) that has control cells and two perturbation conditions (with separate, non-overlapping perturbation responses), then we see the following problem: the model does not learn the right perturbation programs when the number of cells per condition is very unbalanced.
Current attempt:
Use sklearn compute_sample_weight
(https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_sample_weight.html)
to compute a value for each cell that is inversely proportional to its class frequency. Call this value w_n, where n is the cell index.
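For example, a toy sketch of that computation (the class sizes here are made up, not from sim1):

```python
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

# Hypothetical per-cell condition labels: mostly control, few perturbed cells
labels = np.array(["control"] * 900 + ["pert1"] * 80 + ["pert2"] * 20)

# "balanced" gives each cell a weight inversely proportional to its class frequency
w_n = compute_sample_weight(class_weight="balanced", y=labels)

print(w_n[0])    # control cell -> ~0.37
print(w_n[950])  # pert1 cell   -> ~4.17
print(w_n[-1])   # pert2 cell   -> ~16.67
```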
Our loss function looks schematically something like this:
loss = torch.sum(rec_loss_n + kl_divergence_z) + torch.mean(ard_loss_m) + torch.sum(adversarial_loss_n)
where ard_loss_m has length m <= n, with m being the number of non-control cells.
NOTE: we might want to take a mean for all terms rather than a sum. If you use sum, then the behavior will change when you change the batch size.
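A tiny illustration of that point (toy numbers, not CellCap code):

```python
import torch

loss_small_batch = torch.ones(64)   # per-cell losses at batch size 64
loss_large_batch = torch.ones(256)  # per-cell losses at batch size 256

# sum scales with batch size, so the term's weight relative to other terms changes...
print(loss_small_batch.sum(), loss_large_batch.sum())    # 64. vs 256.

# ...while mean keeps the term on a batch-size-independent scale
print(loss_small_batch.mean(), loss_large_batch.mean())  # 1. vs 1.
```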
NOTE: look at how the adversarial loss is computed again. Currently the "coding" of perturbations is several-or-zero-hot. A control cell might be coded as [0, 0], while a double-perturbation would be [1, 1]. Then we are using a binary cross entropy loss, BCELoss, and we are summing over all terms (cells and classes given equal weight). But I don't think this is the right way to go.
p_adv_classifier = p.sum(-1) > 0
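A minimal sketch of that idea: collapse the multi-hot code to a single "perturbed vs. control" target before the binary cross entropy. The tensor values are made up, and the one-probability-per-cell classifier output is an assumption, not what the module currently produces:

```python
import torch

# Toy multi-hot perturbation codes p (cells x perturbation classes):
# a control cell is all zeros, a double perturbation has two ones
p = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])

# Collapse to a binary "is this cell perturbed at all?" target
p_adv_classifier = (p.sum(-1) > 0).float()  # -> [0., 1., 1.]

# Assumed adversarial classifier output: one probability per cell
prob = torch.tensor([0.1, 0.8, 0.9])

# Per-cell binary cross entropy against the collapsed target
adv_loss_n = torch.nn.BCELoss(reduction="none")(prob, p_adv_classifier)
```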
Currently @ImXman is trying to weight the loss function by multiplying rec_loss_n by w_n.
TODO: try doing this for the loss:

rec_loss_n = -generative_outputs["px"].log_prob(x)
kl_divergence_z_n = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale))
# if prob and p are (cells x classes), sum over classes to get one adversarial value per cell
adv_loss_n = torch.nn.BCELoss(reduction="none")(inference_outputs["prob"], p).sum(-1)
non_control_logic = p.sum(-1) > 0  # True for cells with at least one perturbation

loss = torch.sum(
    self.rec_weight
    * (rec_loss_n * w_n + self.beta * kl_divergence_z_n * w_n)
    + self.lamda * adv_loss_n * w_n
) + torch.sum(
    self.ard_kl_weight * kl_divergence_ard_n[non_control_logic] * w_n[non_control_logic]
) / w_n[non_control_logic].sum()  # weighted average of the ARD term over non-control cells
Note the "weighted average" that is different than torch.mean().
We also don't need 3 free params to balance 3 loss terms: we only really need 2 [the ratio of the ARD loss to reconstruction, and the ratio of the adversarial loss to reconstruction]. Basically, define self.rec_weight := 1. We can introduce a third free param, beta, for the beta-VAE weight if we want to; the current default would be self.beta = 0.5.
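A minimal sketch of that parameterization (the function and argument names here are hypothetical, not the module's actual attributes):

```python
def total_loss(rec, kl_z, ard, adv, ard_weight, adv_weight, beta=0.5):
    # Reconstruction sets the overall scale (its weight is fixed at 1);
    # ard_weight and adv_weight are the two free ratios relative to
    # reconstruction, and beta is the optional beta-VAE weight.
    return (rec + beta * kl_z) + ard_weight * ard + adv_weight * adv
```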
When you try something new, what are the success criteria?
Step 2:
Further thoughts on WeightedRandomSampler. It could be an alternative to adding the weights w_n as above, or the two could be done together!!
Try putting sampler=torch.utils.data.WeightedRandomSampler() with the appropriate input arguments on a line here
https://github.com/broadinstitute/CellCap/blob/1b7e1728cb7a6adafd38ca69d351c60f01b19ac3/cellcap/scvi_module.py#L302
inside the DataSplitter call.
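A sketch of how that sampler might be set up, assuming the w_n weights from above and assuming a sampler argument can be plumbed through to the underlying DataLoader (the DataSplitter in scvi-tools may need a small patch to forward it):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# w_n: per-cell weights (inverse class frequency) from the sklearn step above;
# a toy stand-in is used here so the snippet runs on its own
w_n = [0.37] * 900 + [4.17] * 80 + [16.67] * 20
weights = torch.as_tensor(w_n, dtype=torch.double)

# Draw cells with probability proportional to their weight, with replacement,
# so that rare perturbations show up in most minibatches
sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=len(weights),
    replacement=True,
)

# Hypothetical call site (argument name assumed, may need plumbing in scvi-tools):
# data_splitter = DataSplitter(..., sampler=sampler)
```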
(I actually think it will be important to do both of the above. Weighting the loss function is probably critical, but in the extremely unbalanced case you don't want to end up with a lot of minibatches that lack some classes entirely. Weighting the loss function won't entirely make up for that, but the WeightedRandomSampler could.)
If you get WeightedRandomSampler working, you should try to contribute it back to scvi-tools, @ImXman :)
We have WeightedRandomSampler working. It does solve the problem of learning the right programs in extremely unbalanced simulated data.
This issue will be closed. Say which pull request closed it (i.e. which pull request added WeightedRandomSampler), like:
Closed by #...
Try a few things to see what helps the case of very unbalanced datasets (in terms of cells per perturbation):
- Add sampler=torch.utils.data.WeightedRandomSampler() to the DataSplitter() call in .train()