BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

inconsistent tensor dimensions when giving exact cell numbers #103

Closed · carloelle closed this issue 2 years ago

carloelle commented 2 years ago

Dear developers,

I ran into an issue when trying to give cell2location (C2L) the exact number of cells per spot (calculated by segmentation of each spot). When I run:

modA1fil = cell2location.models.Cell2location(
    A1_fil, cell_state_df=inf_aver_filraw_A1,
    N_cells_per_location=numcellsA['N_cells'],
    detection_alpha=200
)

modA1fil.train(
    max_epochs=15000,
    batch_size=None,
    train_size=1,
    use_gpu=True,
)

where A1_fil is the spatial dataset (an AnnData object set up with scvi-tools), inf_aver_filraw_A1 is the dataframe coming out of the regression model, and numcellsA['N_cells'] is an int64 vector with the same length as the number of spots in A1_fil, containing the exact number of cells in each spot.
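
For reference, a quick check (sketch, using the same variable names as above) confirming that this vector lines up with the spatial object:

numcellsA['N_cells'].shape                 # (3274,) in my case, matching A1_fil.n_obs
numcellsA['N_cells'].dtype                 # int64
assert len(numcellsA['N_cells']) == A1_fil.n_obs
assert (numcellsA['N_cells'] > 0).all()    # no spot has zero cells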

The error I get is the following:


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/cell2location/models/_cell2location_module.py in forward(self, x_data, idx, batch_index)
    279         shape = self.ones_1_n_groups * b_s_groups_per_location / self.n_groups_tensor
--> 280         rate = self.ones_1_n_groups / (n_s_cells_per_location / b_s_groups_per_location)
    281         with obs_plate:

RuntimeError: The size of tensor a (50) must match the size of tensor b (3274) at non-singleton dimension 1

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/beegfs/scratch/tmp/ipykernel_70698/2057392656.py in <module>
     10 )
     11 
---> 12 modA1fil.train(max_epochs=15000,
     13           # train using full data (batch_size=None)
     14           batch_size=None,

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/cell2location/models/_cell2location_model.py in train(self, max_epochs, batch_size, train_size, lr, **kwargs)
    181         kwargs["lr"] = lr
    182 
--> 183         super().train(**kwargs)
    184 
    185     def export_posterior(

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py in train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, lr, plan_kwargs, **trainer_kwargs)
    143             **trainer_kwargs,
    144         )
--> 145         return runner()
    146 
    147 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/scvi/train/_trainrunner.py in __call__(self)
     70             self.training_plan.n_obs_training = self.data_splitter.n_train
     71 
---> 72         self.trainer.fit(self.training_plan, self.data_splitter)
     73         self._update_history()
     74 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    175                     message="`LightningModule.configure_optimizers` returned `None`",
    176                 )
--> 177             super().fit(*args, **kwargs)

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    458         )
    459 
--> 460         self._run(model)
    461 
    462         assert self.state.stopped

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    756 
    757         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758         self.dispatch()
    759 
    760         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    797             self.accelerator.start_predicting(self)
    798         else:
--> 799             self.accelerator.start_training(self)
    800 
    801     def run_stage(self):

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    807         if self.predicting:
    808             return self.run_predict()
--> 809         return self.run_train()
    810 
    811     def _pre_training_routine(self):

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    855 
    856         # hook
--> 857         self.train_loop.on_train_start()
    858 
    859         try:

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py in on_train_start(self)
     99     def on_train_start(self):
    100         # hook
--> 101         self.trainer.call_hook("on_train_start")
    102 
    103     def on_train_end(self):

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py in call_hook(self, hook_name, *args, **kwargs)
   1226             if hasattr(self, hook_name):
   1227                 trainer_hook = getattr(self, hook_name)
-> 1228                 trainer_hook(*args, **kwargs)
   1229 
   1230             # next call hook in lightningModule

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_start(self)
    150         """Called when the train begins."""
    151         for callback in self.callbacks:
--> 152             callback.on_train_start(self, self.lightning_module)
    153 
    154     def on_train_end(self):

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/scvi/model/base/_pyromixin.py in on_train_start(self, trainer, pl_module)
     45             tens = {k: t.to(pl_module.device) for k, t in tensors.items()}
     46             args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
---> 47             pyro_guide(*args, **kwargs)
     48             break
     49 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/infer/autoguide/guides.py in forward(self, *args, **kwargs)
    529         # if we've never run the model before, do so now so we can inspect the model structure
    530         if self.prototype_trace is None:
--> 531             self._setup_prototype(*args, **kwargs)
    532 
    533         plates = self._create_plates(*args, **kwargs)

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/infer/autoguide/guides.py in _setup_prototype(self, *args, **kwargs)
    479 
    480     def _setup_prototype(self, *args, **kwargs):
--> 481         super()._setup_prototype(*args, **kwargs)
    482 
    483         self._event_dims = {}

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/infer/autoguide/guides.py in _setup_prototype(self, *args, **kwargs)
    169         # run the model so we can inspect its structure
    170         model = poutine.block(self.model, prototype_hide_fn)
--> 171         self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
    172             *args, **kwargs
    173         )

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    178                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    179                 exc = exc.with_traceback(traceback)
--> 180                 raise exc from e
    181             self.msngr.trace.add_node(
    182                 "_RETURN", name="_RETURN", type="return", value=ret

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/opt/common/tools/ric.tiget/anaconda3/envs/Ostuni_cell2loc2/lib/python3.9/site-packages/cell2location/models/_cell2location_module.py in forward(self, x_data, idx, batch_index)
    278         # cell group loadings
    279         shape = self.ones_1_n_groups * b_s_groups_per_location / self.n_groups_tensor
--> 280         rate = self.ones_1_n_groups / (n_s_cells_per_location / b_s_groups_per_location)
    281         with obs_plate:
    282             z_sr_groups_factors = pyro.sample(

RuntimeError: The size of tensor a (50) must match the size of tensor b (3274) at non-singleton dimension 1
               Trace Shapes:                    
                Param Sites:                    
               Sample Sites:                    
               m_g_mean dist           | 1     1
                       value           | 1     1
        m_g_alpha_e_inv dist           | 1     1
                       value           | 1     1
                    m_g dist           | 1 13301
                       value           | 1 13301
 n_s_cells_per_location dist 3274 3274 |        
                       value 3274 3274 |        
b_s_groups_per_location dist 3274    1 |        
                       value 3274    1 |   

This error does not appear if I specify N_cells_per_location=8 (or any other single number). How can I proceed? I noticed that providing the exact number of cells per spot is what you call 'advanced mode', where you also recommend adding 0.1 as a pseudocount (although here numcellsA['N_cells'] is not zero in any spot) and modifying vn to make the prior more informative. How can I do that as well?

Thanks a lot, Best, Carlo

vitkl commented 2 years ago

Hi Carlo

We never saw a benefit from using the exact cell numbers in the current version of cell2location, because this value is used as a prior and has a weak effect on the results: there are 2 layers of variables between this value (n_s and z_sr) and the estimated cell abundance (w_sf). This might change in future versions. Another issue with cell numbers, especially when used directly rather than as a prior, is that a discrete cell count does not account for the fraction of each cell that is actually captured in the 2D section. This also means that this option is not fully supported and not tested.

That said, feel free to compare the results obtained with a single value versus the exact cell numbers. The N_cells_per_location input needs to be a scalar or an np.array with shape (n_location, 1) and data type 'float32'.
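
For example, a minimal sketch (not from the original post) of converting a per-spot count Series such as numcellsA['N_cells'] into that format, reusing the variable names from the question above; the 0.1 pseudocount follows the 'advanced mode' suggestion and is optional:

import numpy as np

# Convert the per-spot counts (pandas int64 Series) into the expected
# format: a float32 np.array of shape (n_locations, 1).
n_cells = (
    numcellsA['N_cells']
    .to_numpy()              # pandas Series -> 1D numpy array
    .astype(np.float32)      # int64 -> float32
    .reshape(-1, 1)          # (n_locations,) -> (n_locations, 1)
)
n_cells = n_cells + 0.1      # optional pseudocount ('advanced mode' suggestion)

mod = cell2location.models.Cell2location(
    A1_fil,
    cell_state_df=inf_aver_filraw_A1,
    N_cells_per_location=n_cells,
    detection_alpha=200,
)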