scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

adding validation ELBO loss for pyro models #1073

Open vitkl opened 3 years ago

vitkl commented 3 years ago

It would be great to have a validation-set ELBO loss for Pyro models, as discussed with Adam here: https://github.com/YosefLab/scvi-tools/pull/1059#discussion_r633935024 @adamgayoso I can implement this, but could you point me to where it belongs?
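For context, a validation-set ELBO is the same negative-ELBO estimate used as the training loss, but evaluated on held-out cells with the guide (encoder) parameters frozen and no optimizer step. In Pyro, calling pyro.infer.SVI.evaluate_loss on a validation minibatch gives this kind of estimate without taking a gradient step. The stdlib-only sketch below illustrates the idea on a toy model that is purely an assumption for illustration (z ~ N(0, 1), x | z ~ N(z, 1), amortized guide q(z | x) = N(a*x + b, s^2)). The function names are hypothetical and not part of scvi-tools or Pyro.

```python
import math
import random
from typing import Optional, Sequence


def log_normal(x: float, mean: float, std: float) -> float:
    """Log density of a univariate normal N(mean, std^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def negative_elbo(
    xs: Sequence[float],
    a: float,
    b: float,
    s: float,
    num_samples: int = 1,
    rng: Optional[random.Random] = None,
) -> float:
    """Monte Carlo estimate of -ELBO, summed over the observations in xs.

    Toy model (assumption): z ~ N(0, 1), x | z ~ N(z, 1).
    Amortized guide (assumption): q(z | x) = N(a * x + b, s^2). The guide
    parameters a, b, s are frozen: fitted on training data, only evaluated here.
    """
    if rng is None:
        rng = random.Random(0)  # fixed seed so the estimate is reproducible
    total = 0.0
    for x in xs:
        loc = a * x + b  # the "encoder" maps each held-out x to a guide mean
        estimate = 0.0
        for _ in range(num_samples):
            z = rng.gauss(loc, s)  # sample the latent from the frozen guide
            log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
            estimate += log_joint - log_normal(z, loc, s)
        total += estimate / num_samples  # per-observation ELBO estimate
    return -total  # negate: the loss is the negative ELBO
```

For this toy model the exact posterior is N(x/2, 1/2), so with a = 0.5, b = 0, s = sqrt(0.5) every sample yields exactly -log p(x) and the estimate has zero variance. Any other guide gives a larger value on average, because the ELBO is a lower bound on log p(x).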

vitkl commented 3 years ago

A related warning:

  /Users/vk7/anaconda3/envs/scvi-tools-dev/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: you passed in a val_dataloader but have no validation_step. Skipping val loop
    warnings.warn(*args, **kwargs)
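This warning comes from PyTorch Lightning. When a val_dataloader is supplied but the LightningModule does not override validation_step, Lightning skips the validation loop entirely, so no validation loss is computed or logged. Roughly speaking, Lightning makes this decision with an is_overridden-style check on the module's class. The stdlib-only mimic below illustrates that mechanism. The class and function names are hypothetical, and this is not Lightning's actual code.

```python
import warnings


class BasePlan:
    """Hypothetical stand-in for a LightningModule base class."""

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        # Default hook: does nothing until a subclass overrides it.
        return None


def overrides(instance, method_name: str, parent=BasePlan) -> bool:
    """Simplified check: does the instance's class replace the parent's method?"""
    return getattr(type(instance), method_name) is not getattr(parent, method_name)


def should_run_val_loop(instance, has_val_dataloader: bool) -> bool:
    """Mimic the decision behind the 'Skipping val loop' warning."""
    if has_val_dataloader and not overrides(instance, "validation_step"):
        warnings.warn(
            "you passed in a val_dataloader but have no validation_step. "
            "Skipping val loop",
            UserWarning,
        )
        return False
    return has_val_dataloader


class PlanWithoutValidation(BasePlan):
    # Like a training plan that only implements training_step.
    def training_step(self, batch, batch_idx):
        return 0.0


class PlanWithValidation(PlanWithoutValidation):
    # Adding validation_step is what re-enables the validation loop.
    def validation_step(self, batch, batch_idx):
        return {"val_loss": 0.0}
```

So adding a validation_step to the Pyro training plan is what removes the warning and lets the validation dataloader actually be used.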
adamgayoso commented 3 years ago

We need to add a method to this class:

https://github.com/YosefLab/scvi-tools/blob/14ac97718c7d50470bf0db25aa7dd6d4f4a245c6/scvi/train/_trainingplans.py#L581

with this signature:

https://github.com/YosefLab/scvi-tools/blob/14ac97718c7d50470bf0db25aa7dd6d4f4a245c6/scvi/train/_trainingplans.py#L517

that either

  1. runs evaluate_loss and then logs the result with self.log(...), or
  2. returns the loss evaluation, in which case we additionally write a method with this signature:

https://github.com/YosefLab/scvi-tools/blob/14ac97718c7d50470bf0db25aa7dd6d4f4a245c6/scvi/train/_trainingplans.py#L563

that sums the accumulated per-batch losses. This would be analogous to what we discussed regarding averaging vs. summing the loss during minibatching.
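The second option can be sketched without any framework. Below, validation_step and validation_epoch_end are free functions named after the Lightning hooks; in the real training plan they would be methods. evaluate_loss is a stand-in for pyro.infer.SVI.evaluate_loss and is assumed to return the negative ELBO summed (not averaged) over the cells in the minibatch. Returning the batch size alongside the loss is also an assumption. It makes the summing-vs-averaging choice explicit at the epoch-end step.

```python
from typing import Callable, Dict, List, Sequence

Batch = Sequence[float]


def validation_step(
    evaluate_loss: Callable[[Batch], float], batch: Batch
) -> Dict[str, float]:
    """Option 2: return the per-batch loss instead of logging it directly.

    evaluate_loss is a hypothetical stand-in assumed to return the negative
    ELBO summed over the observations in the batch.
    """
    return {"loss": evaluate_loss(batch), "n_obs": float(len(batch))}


def validation_epoch_end(outputs: List[Dict[str, float]]) -> Dict[str, float]:
    """Accumulate per-batch outputs into epoch-level validation metrics."""
    total = sum(o["loss"] for o in outputs)  # full validation-set -ELBO
    n_obs = sum(o["n_obs"] for o in outputs)
    return {
        "elbo_validation": total,
        # Per-observation average: divide by the total cell count,
        # not by the number of batches.
        "elbo_validation_per_obs": total / n_obs,
    }
```

Take two validation batches with per-cell losses [1, 2, 3] and [4]. Summing gives 10, the full-set value. Dividing by the 4 cells gives 2.5 per cell, whereas averaging the two batch means gives 3.0. Because the batches differ in size, a mean of batch means is not the per-cell average. The same averaging/summing issue arises for per-step values averaged under option 1.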