broadinstitute / CellCap

Interpret perturbation responses from scRNA-seq perturbation experiments
BSD 3-Clause "New" or "Revised" License

Unbalanced perturbation classes #28

Closed: sjfleming closed this issue 8 months ago

sjfleming commented 1 year ago

Try a few things to see what helps in the case of very unbalanced datasets (in terms of cells per perturbation):

  1. build a simulated dataset
  2. try re-weighting the loss function for class balance
  3. try torch's WeightedRandomSampler, adding something like sampler=torch.utils.data.WeightedRandomSampler(...) to the DataSplitter() call in .train() (see the sketch below)
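
A minimal sketch of the sampler in item 3, assuming a per-cell label array perturbation_labels (the name and class counts here are made up for illustration):

import numpy as np
import torch

# hypothetical per-cell class labels (one entry per cell); replace with
# the real perturbation labels from the dataset
perturbation_labels = np.array(
    ["control"] * 900 + ["pert_A"] * 80 + ["pert_B"] * 20
)

# weight each cell inversely to its class frequency
classes, counts = np.unique(perturbation_labels, return_counts=True)
class_weight = {c: 1.0 / k for c, k in zip(classes, counts)}
cell_weights = torch.tensor(
    [class_weight[c] for c in perturbation_labels], dtype=torch.double
)

# draw cells with replacement, so rare classes show up in minibatches
sampler = torch.utils.data.WeightedRandomSampler(
    weights=cell_weights,
    num_samples=len(cell_weights),
    replacement=True,
)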
sjfleming commented 9 months ago

The current problem:

If we make a simulated dataset (sim1) that has control cells and two perturbation conditions (with separate, non-overlapping perturbation responses), then we see the following:

sjfleming commented 9 months ago

Current attempt:

Use sklearn compute_sample_weight to compute a value for each cell that is inversely proportional to class frequency: https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_sample_weight.html

Call this value w_n, where n is the cell index.
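
A sketch of that computation (perturbation_labels is a hypothetical per-cell label array):

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

# hypothetical per-cell perturbation labels
perturbation_labels = np.array(
    ["control"] * 900 + ["pert_A"] * 80 + ["pert_B"] * 20
)

# "balanced" gives each cell the weight n_samples / (n_classes * class_count),
# i.e. a value inversely proportional to its class frequency
w_n = compute_sample_weight(class_weight="balanced", y=perturbation_labels)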

Our loss function looks schematically like this:

loss = torch.sum(rec_loss_n + kl_divergence_z_n) + torch.mean(ard_loss_m) + torch.sum(adversarial_loss_n)

where ard_loss_m has length m <= n, with m the number of non-control cells.

NOTE: we might want to take a mean for all terms rather than a sum. With a sum, the loss scale (and hence the effective gradient) changes whenever the batch size changes; see the toy illustration below.
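
A toy illustration of this batch-size dependence:

import torch

per_cell_loss = torch.ones(128)   # a batch of 128 cells
print(per_cell_loss.sum())        # tensor(128.) -- scales with batch size
print(per_cell_loss.mean())       # tensor(1.)

per_cell_loss = torch.ones(256)   # double the batch size
print(per_cell_loss.sum())        # tensor(256.) -- gradient scale doubled
print(per_cell_loss.mean())       # tensor(1.)   -- unchanged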

NOTE: revisit how the adversarial loss is computed. Currently the "coding" of perturbations is several-or-zero-hot: a control cell might be coded as [0, 0], while a double perturbation would be [1, 1]. We then apply a binary cross entropy loss, BCELoss, and sum over all terms, so cells and classes are given equal weight (sketched below). I don't think this is the right way to go.
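
For reference, a sketch of the current scheme, with made-up probabilities:

import torch

# several-or-zero-hot coding over two perturbation classes:
# control -> [0, 0], single -> [1, 0] or [0, 1], double -> [1, 1]
p = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])

# made-up per-class probabilities from the adversarial classifier
prob = torch.tensor([[0.1, 0.2], [0.8, 0.3], [0.7, 0.9]])

# element-wise BCE followed by a plain sum: every cell and every class
# contributes with equal weight, regardless of class frequency
adversarial_loss = torch.nn.BCELoss(reduction="none")(prob, p).sum()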

Currently @ImXman is trying to weight the loss function by multiplying rec_loss_n by w_n.

TODO: try the following for the loss:

import torch
from torch.distributions import Normal, kl_divergence as kl

# per-cell loss terms, each reduced over its trailing dimension
# (genes, latent dimensions, or perturbation classes) to shape [n_cells]
rec_loss_n = -generative_outputs["px"].log_prob(x).sum(-1)
kl_divergence_z_n = kl(
    Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)
).sum(-1)
adv_loss_n = torch.nn.BCELoss(reduction="none")(
    inference_outputs["prob"], p
).sum(-1)

# mask selecting non-control cells: those with at least one perturbation
non_control_logic = p.sum(-1) > 0

# kl_divergence_ard_n is the per-cell ARD KL term (defined elsewhere);
# its weighted average uses sum(x * w) / sum(w) rather than torch.mean()
loss = torch.sum(
    self.rec_weight
    * (rec_loss_n * w_n + self.beta * kl_divergence_z_n * w_n)
    + self.lamda * adv_loss_n * w_n
) + torch.sum(
    self.ard_kl_weight
    * kl_divergence_ard_n[non_control_logic]
    * w_n[non_control_logic]
) / w_n[non_control_logic].sum()

Note the "weighted average", sum(x * w) / sum(w), which is different from torch.mean().

We also don't need 3 free parameters to balance 3 loss terms: we only really need 2 (the ratio of the ARD loss to the reconstruction loss, and the ratio of the adversarial loss to the reconstruction loss). Basically, define self.rec_weight := 1. We can introduce a third free parameter, beta, for the beta-VAE term if we want to; the current default would be self.beta = 0.5. This parameterization is sketched below.
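
Schematically, that parameterization might look like this (ard_ratio and adv_ratio are hypothetical names for the two remaining free parameters, with made-up values):

import torch

# dummy scalar loss terms, for illustration only
rec_term, kl_term, ard_term, adv_term = (torch.tensor(1.0) for _ in range(4))

rec_weight = 1.0   # fixed by definition: no longer a free parameter
beta = 0.5         # optional beta-VAE parameter (current default)
ard_ratio = 0.1    # hypothetical: ARD loss relative to reconstruction
adv_ratio = 0.1    # hypothetical: adversarial loss relative to reconstruction

loss = rec_weight * (rec_term + beta * kl_term) + ard_ratio * ard_term + adv_ratio * adv_term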

sjfleming commented 9 months ago

When you try something new, what are the success criteria?

sjfleming commented 9 months ago

Step 2: further thoughts on WeightedRandomSampler. It could be an alternative to adding the weights w_n as above, or the two approaches could be combined!

Try passing sampler=torch.utils.data.WeightedRandomSampler(...) with the appropriate input arguments on a line here https://github.com/broadinstitute/CellCap/blob/1b7e1728cb7a6adafd38ca69d351c60f01b19ac3/cellcap/scvi_module.py#L302 inside the DataSplitter call, as sketched below.
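
A sketch of the wiring, assuming the sampler keyword gets forwarded to the underlying torch DataLoader (an assumption about scvi-tools internals, not a confirmed API):

import torch

# toy dataset; cell_weights as in the earlier sketch
dataset = torch.utils.data.TensorDataset(torch.randn(1000, 50))
cell_weights = torch.rand(1000, dtype=torch.double)

sampler = torch.utils.data.WeightedRandomSampler(
    weights=cell_weights, num_samples=len(cell_weights), replacement=True
)

# a plain torch DataLoader accepts the sampler like this; the DataSplitter
# call would need to forward the same keyword down to its internal
# DataLoader (unverified for scvi-tools)
loader = torch.utils.data.DataLoader(dataset, batch_size=128, sampler=sampler)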

sjfleming commented 9 months ago

(I actually think it will be important to do both of the above. Weighting the loss function is probably critical, but in the extremely unbalanced case you don't want to end up with many minibatches that lack some classes entirely. Weighting the loss function won't make up for that, but the WeightedRandomSampler could.)

sjfleming commented 9 months ago

If you get WeightedRandomSampler working, you should try to contribute it back to scvi-tools @ImXman :)

ImXman commented 8 months ago

We have WeightedRandomSampler working. It does solve the problem of learning the right programs in extremely unbalanced simulated data.

ImXman commented 8 months ago

This issue will be closed.

sjfleming commented 8 months ago

Say which pull request closed this (i.e., which pull request added WeightedRandomSampler), like "Closed by #...".