scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

Training fails with mini-batch size of one sample #3035

Closed LeonHafner closed 3 days ago

LeonHafner commented 3 weeks ago

This bug has already been reported in several issues and on Discourse: #2314, #2214, #221, #426. SCVI training fails if ceil(n_cells * 0.9) % 128 == 1, where 0.9 is the training split and 128 the batch size, i.e. whenever the last training mini-batch contains a single cell. I thought the bug would be fixed in 1.2.0, but unfortunately it still occurs.
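The failure condition can be checked before training. The following is a minimal sketch, assuming (as stated above) that the training split contains ceil(n_cells * train_size) cells and is cut into mini-batches of batch_size; `last_batch_size` and `hits_singleton_batch` are hypothetical helper names, not scvi-tools functions:

```python
import math


def last_batch_size(n_cells: int, train_size: float = 0.9, batch_size: int = 128) -> int:
    """Size of the final training mini-batch (hypothetical helper)."""
    # Assumption: the number of training cells is ceil(n_cells * train_size).
    n_train = math.ceil(n_cells * train_size)
    remainder = n_train % batch_size
    # A remainder of 0 means the last batch is a full one.
    return remainder if remainder else batch_size


def hits_singleton_batch(n_cells: int, train_size: float = 0.9, batch_size: int = 128) -> bool:
    """True if training would end with a one-cell batch, which BatchNorm rejects."""
    return last_batch_size(n_cells, train_size, batch_size) == 1


print(math.ceil(143 * 0.9))       # 129 training cells
print(hits_singleton_batch(143))  # True: 129 = 128 + 1
```

With 142 cells the training split has exactly 128 cells, so the last batch is full and no error is expected.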

It would be great if you could find a fix for this, as it would allow users to remove the unnecessary try-except blocks that change the batch size when the error occurs.
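Instead of catching the error and retrying, the batch size could be chosen up front. This is a minimal sketch under the same assumption (training cells = ceil(n_cells * train_size)); `safe_batch_size` is a hypothetical helper, not part of scvi-tools:

```python
import math


def safe_batch_size(n_cells: int, train_size: float = 0.9, batch_size: int = 128) -> int:
    """Return the largest batch size <= batch_size that avoids a 1-cell final batch.

    Hypothetical helper: it searches downward from the requested batch size.
    """
    n_train = math.ceil(n_cells * train_size)
    for candidate in range(batch_size, 1, -1):
        if n_train % candidate != 1:
            return candidate
    # No candidate avoids a remainder of 1 (e.g. n_train == 1); keep the request.
    return batch_size


print(safe_batch_size(143))  # 127
```

The result could then be passed as `model.train(batch_size=safe_batch_size(adata.n_obs))`. For the 143-cell example it returns 127, so the 129 training cells are split into batches of 127 and 2.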

Thanks a lot!

import numpy as np
import anndata as ad
from scvi.model import SCVI

num_cells = 143  # ceil(143 * 0.9) = 129 training cells: one batch of 128 plus one of 1
num_genes = 1000

shape_param = 2.0
scale_param = 1.0

gamma_rates = np.random.gamma(shape=shape_param, scale=scale_param, size=(num_cells, num_genes))
data = np.random.poisson(gamma_rates)

adata = ad.AnnData(X=data)

print(adata.shape)

SCVI.setup_anndata(adata)
model = SCVI(adata)
model.train()  # fails on the final training batch, which holds a single cell
Traceback (most recent call last):
  File "/nfs/home/students/l.hafner/nf-core/scvi_test/test_scvi.py", line 21, in <module>
    model.train()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/model/base/_training_mixin.py", line 145, in train
    return runner()
           ^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainrunner.py", line 96, in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainer.py", line 201, in fit
    super().fit(*args, **kwargs)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/adam.py", line 205, in step
    loss = closure()
           ^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
                  ^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainingplans.py", line 344, in training_step
    _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainingplans.py", line 278, in forward
    return self.module(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 208, in forward
    return _generic_forward(
           ^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 748, in _generic_forward
    inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 312, in inference
    return self._regular_inference(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/_vae.py", line 377, in _regular_inference
    qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/nn/_base_components.py", line 283, in forward
    q = self.encoder(x, *cat_list)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/nn/_base_components.py", line 173, in forward
    x = layer(x)
        ^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
    return F.batch_norm(
           ^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/functional.py", line 2510, in batch_norm
    _verify_batch_size(input.size())
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/functional.py", line 2478, in _verify_batch_size
    raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])

Versions:

scvi: 1.2.0, anndata: 0.10.9, numpy: 1.26.4
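The error comes from batch normalization in the encoder: in training mode it normalizes each feature with the mean and variance computed across the cells of the mini-batch, and with input size torch.Size([1, 128]) there is only one cell. The following pure-Python sketch of that formula for a single channel (not PyTorch's implementation; `batch_norm_train` is a hypothetical name) shows why one sample is useless: every value equals its own mean, so the normalized output collapses to the shift parameter beta regardless of the input, which is why PyTorch raises instead.

```python
def batch_norm_train(values, eps=1e-5, gamma=1.0, beta=0.0):
    """Training-mode batch normalization of one channel across the batch (sketch)."""
    n = len(values)
    mean = sum(values) / n
    # Batch statistics use the biased variance (divide by n).
    var = sum((v - mean) ** 2 for v in values) / n
    return [gamma * (v - mean) / (var + eps) ** 0.5 + beta for v in values]


print(batch_norm_train([2.0, 4.0]))  # two cells: values close to -1 and 1
print(batch_norm_train([3.7]))       # one cell: [0.0], i.e. just beta
```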

ori-kron-wis commented 3 weeks ago

Hi @LeonHafner, thanks for posting this. As it stands, we issue a warning in such cases (see https://github.com/scverse/scvi-tools/pull/2916 and the changelog); there is no logic fix. Would that not be enough for your case?

Otherwise, see my suggested fix here: https://github.com/scverse/scvi-tools/pull/3036. The idea is that, in this case, we artificially add 1 to n_train and subtract 1 from n_val, if possible (i.e. if n_val itself is larger than 2).
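The rebalancing idea can be sketched as follows. This is a hypothetical illustration of the rule described above (`rebalance_split` is an invented name), not the code in PR #3036:

```python
def rebalance_split(n_train: int, n_val: int, batch_size: int = 128) -> tuple[int, int]:
    """Move one cell from validation to training if training would end in a 1-cell batch.

    Hypothetical sketch: only applied when the validation set has more than 2 cells.
    """
    if n_train % batch_size == 1 and n_val > 2:
        return n_train + 1, n_val - 1
    return n_train, n_val


print(rebalance_split(129, 14))  # (130, 13)
```

With the 143-cell example (129 training and 14 validation cells) this yields 130 and 13, so the last training batch holds 2 cells.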

LeonHafner commented 2 weeks ago

Hi @ori-kron-wis, thanks for the quick reply. For me, the quickest fix was to pass datasplitter_kwargs={"drop_last": True} to model.train. This simply drops the cell that remains in the last batch.
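The effect of `drop_last` can be illustrated with a small sketch of DataLoader-style batching, assuming the flag is forwarded to the training data loader and follows the usual PyTorch meaning (an incomplete final batch is discarded). `batch_sizes` is a hypothetical helper:

```python
def batch_sizes(n_train: int, batch_size: int = 128, drop_last: bool = False) -> list[int]:
    """Sizes of the mini-batches one epoch would produce (sketch of DataLoader-style batching)."""
    sizes = [batch_size] * (n_train // batch_size)
    remainder = n_train % batch_size
    # The incomplete last batch is kept unless drop_last is set.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


print(batch_sizes(129))                  # [128, 1]
print(batch_sizes(129, drop_last=True))  # [128]
```

If the training set is reshuffled every epoch, the discarded cell also changes from epoch to epoch, so no single cell is permanently excluded.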

But since this is not a very clean solution, I would appreciate some logic being implemented in scVI to fix this. Your suggested fix is a great idea; I hope you will be able to get it merged!

Best, Leon

canergen commented 2 weeks ago

@ori-kron-wis I didn't get around to fixing it. I think we should move these cells to validation instead (and I would do it whenever fewer than 3 cells remain; that sounds safer). We should set train_size and validation_size to None by default. If they are None, we adjust these small batches. If the user sets a custom value like 0.9 (the old behavior), we don't change the training cells and it still fails.
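This proposal can be sketched as below. It is a hypothetical illustration (`split_sizes` and `min_last_batch` are invented names), and the internal fallback of 0.9 when train_size is None is an assumption based on the old default mentioned above:

```python
import math
from typing import Optional


def split_sizes(n_cells: int, train_size: Optional[float] = None, batch_size: int = 128,
                min_last_batch: int = 3) -> tuple[int, int]:
    """Hypothetical sketch: adjust the split only when the user kept the default."""
    user_set = train_size is not None
    if train_size is None:
        train_size = 0.9  # assumed internal default
    n_train = math.ceil(n_cells * train_size)
    n_val = n_cells - n_train
    remainder = n_train % batch_size
    # Move a tiny leftover batch (fewer than min_last_batch cells) to validation,
    # but only if at least one full training batch remains.
    if not user_set and 0 < remainder < min_last_batch and n_train > remainder:
        n_train -= remainder
        n_val += remainder
    return n_train, n_val


print(split_sizes(143))                  # (128, 15): default, leftover cell moved
print(split_sizes(143, train_size=0.9))  # (129, 14): user value kept, still fails
```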