BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

Error in the model when providing array to N_cells_per_location argument #301

Open Tang-YH opened 1 year ago

Tang-YH commented 1 year ago

Hi Vitalii,

I hit an error after passing an array as the prior estimate of cell abundance per spot. The model seems to try to match the number of spots against 50 (which, if I understand correctly, is the default number of cell type groups) and raises an error when they don't match.

# Use array of cell counts from segmentation as prior abundances
abundances_prior = np.array(adata_vis.obsm['features']['segmentation_label']) + 0.1

cell2location.models.Cell2location.setup_anndata(adata=adata_vis)

mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    N_cells_per_location=abundances_prior, # Use segmentation output
    detection_alpha=200
)

Traceback (most recent call last):
  File "./annot_v1_vis.py", line 69, in <module>
    mod.train(max_epochs=10000,
  File "./python3.10/site-packages/cell2location/models/_cell2location_model.py", line 209, in train
    super().train(**kwargs)
  File "./python3.10/site-packages/scvi/model/base/_pyromixin.py", line 146, in train
    return runner()
  File "./python3.10/site-packages/scvi/train/_trainrunner.py", line 81, in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
  File "./python3.10/site-packages/scvi/train/_trainer.py", line 188, in fit
    super().fit(*args, **kwargs)
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
    self.fit_loop.run()
  File "./python3.10/site-packages/pytorch_lightning/loops/loop.py", line 195, in run
    self.on_run_start(*args, **kwargs)
  File "./python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 222, in on_run_start
    self.trainer._call_callback_hooks("on_train_start")
  File "./python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "./python3.10/site-packages/scvi/model/base/_pyromixin.py", line 45, in on_train_start
    pyro_guide(*args, **kwargs)
  File "./python3.10/site-packages/pyro/nn/module.py", line 427, in __call__
    return super().__call__(*args, **kwargs)
  File "./python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "./python3.10/site-packages/pyro/infer/autoguide/guides.py", line 510, in forward
    self._setup_prototype(*args, **kwargs)
  File "./python3.10/site-packages/pyro/infer/autoguide/guides.py", line 460, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "./python3.10/site-packages/pyro/infer/autoguide/guides.py", line 157, in _setup_prototype
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
  File "./python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "./python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "./python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "./python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "./python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "./python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "./python3.10/site-packages/pyro/nn/module.py", line 427, in __call__
    return super().__call__(*args, **kwargs)
  File "./python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "./python3.10/site-packages/cell2location/models/_cell2location_module.py", line 321, in forward
    rate = self.ones_1_n_groups / (n_s_cells_per_location / b_s_groups_per_location)
RuntimeError: The size of tensor a (50) must match the size of tensor b (11712) at non-singleton dimension 1
               Trace Shapes:
                Param Sites:
               Sample Sites:
               m_g_mean dist             | 1     1
                       value             | 1     1
        m_g_alpha_e_inv dist             | 1     1
                       value             | 1     1
                    m_g dist             | 1 13597
                       value             | 1 13597
 n_s_cells_per_location dist 11712 11712 |
                       value 11712 11712 |
b_s_groups_per_location dist 11712     1 |
                       value 11712     1 |

I set N_cells_per_location to the array of cell segmentation label counts, a NumPy array of length 11712, equal to the total number of spots. The code works if I replace the array with a single number.

Is there a different way of providing per-spot cell abundance estimates?

Thank you! Yiheng

vitkl commented 1 year ago

Please try providing the array as [n_obs, 1] rather than [n_obs].
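
A minimal NumPy sketch of why the shape matters, assuming NumPy broadcasting mirrors the PyTorch broadcasting in the failing line of the traceback (`self.ones_1_n_groups / (n_s_cells_per_location / b_s_groups_per_location)`). The sizes are scaled down, and the variable names are illustrative stand-ins rather than the model's actual tensors:

```python
import numpy as np

n_obs, n_groups = 6, 3  # scaled-down stand-ins for 11712 spots and 50 groups

ones_1_n_groups = np.ones((1, n_groups))
b_s = np.ones((n_obs, 1))  # per-spot quantity with a trailing singleton axis

# A flat [n_obs] prior broadcasts against [n_obs, 1] into an [n_obs, n_obs] matrix.
flat_prior = np.arange(n_obs, dtype="float32") + 0.1
ratio_flat = flat_prior / b_s
print(ratio_flat.shape)  # (6, 6)

try:
    ones_1_n_groups / ratio_flat  # (1, 3) against (6, 6) cannot broadcast
    flat_ok = True
except ValueError:
    flat_ok = False
print(flat_ok)  # False

# Reshaping to [n_obs, 1] keeps one value per spot, so this division broadcasts.
col_prior = flat_prior.reshape(-1, 1)
rate = ones_1_n_groups / (col_prior / b_s)
print(rate.shape)  # (6, 3)
```

This matches the first trace, where `n_s_cells_per_location` appears with shape 11712 x 11712 rather than 11712 x 1.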

Tang-YH commented 1 year ago

Thanks! Though I just encountered another error after I expanded the dimensions:

  File "./python3.10/site-packages/torch/distributions/gamma.py", line 57, in expand
    new.concentration = self.concentration.expand(batch_shape)
RuntimeError: The expanded size of the tensor (1) must match the existing size (11712) at non-singleton dimension 0.  Target sizes: [1, 1].  Tensor sizes: [11712, 1]
               Trace Shapes:
                Param Sites:
               Sample Sites:
               m_g_mean dist          |  1     1
                       value          |  1     1
        m_g_alpha_e_inv dist          |  1     1
                       value          |  1     1
                    m_g dist          |  1 13597
                       value          |  1 13597
 n_s_cells_per_location dist 11712  1 |
                       value 11712  1 |
b_s_groups_per_location dist 11712  1 |
                       value 11712  1 |
    z_sr_groups_factors dist 11712 50 |
                       value 11712 50 |
 k_r_factors_per_groups dist          | 50     1
                       value          | 50     1
        x_fr_group2fact dist          | 50     9
                       value          | 50     9
                   w_sf dist 11712  9 |
                       value 11712  9 |
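
The message above comes from `Gamma.expand` calling `self.concentration.expand(batch_shape)` with target sizes `[1, 1]` on a `[11712, 1]` tensor. Expanding can stretch a size-1 axis but cannot shrink a larger one. A small NumPy analogue, assuming `np.broadcast_to` behaves like the tensor `expand` call in this respect (sizes scaled down, values arbitrary):

```python
import numpy as np

n_obs = 6  # scaled-down stand-in for 11712 spots

per_spot = np.full((n_obs, 1), 5.1, dtype="float32")  # shape [n_obs, 1]
scalar_like = np.full((1, 1), 5.1, dtype="float32")   # shape [1, 1]

# Stretching a [1, 1] array to [n_obs, 1] is allowed.
stretched = np.broadcast_to(scalar_like, (n_obs, 1))
print(stretched.shape)  # (6, 1)

# Shrinking [n_obs, 1] down to [1, 1] is not.
try:
    np.broadcast_to(per_spot, (1, 1))
    shrink_ok = True
except ValueError:
    shrink_ok = False
print(shrink_ok)  # False
```

Reading the trace this way suggests that some step expects a `[1, 1]` batch shape for this prior; that is an inference from the error message, not something confirmed in this thread.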

vitkl commented 1 year ago

Ok, this is unexpected. Let me check what's going on next week.

In general, providing this information has only a limited effect in our experience. It's useful for interpreting the cell abundance results though.

lgaspardboulinc31 commented 1 year ago

Dear Vitalii, dear Tang-YH, I have encountered exactly the same error. Because cell density varies across my samples, I would like to pass in the number of cells per spot recovered by segmentation.

Here is how I prepare my model, following issue #103 for the data format of the segmentation prior:

# prepare anndata for cell2location model
cell2location.models.Cell2location.setup_anndata(adata=spatial)

# segmentation prior

seg_prior = np.array(spatial.obs["Nuclei_count"]) + 0.1  # add pseudocount
seg_prior = seg_prior.reshape(seg_prior.shape[0], 1)
seg_prior = seg_prior.astype('float32')
print(seg_prior.shape)
print(seg_prior.dtype)

# create and train the model
mod = cell2location.models.Cell2location(
    spatial, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=seg_prior,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection:
    detection_alpha=20
)

And here is the same error:

RuntimeError: The expanded size of the tensor (1) must match the existing size (2129) at non-singleton dimension 0.  Target sizes: [1, 1].  Tensor sizes: [2129, 1]
               Trace Shapes:                   
                Param Sites:                   
               Sample Sites:                   
               m_g_mean dist         |  1     1
                       value         |  1     1
        m_g_alpha_e_inv dist         |  1     1
                       value         |  1     1
                    m_g dist         |  1 17817
                       value         |  1 17817
 n_s_cells_per_location dist 2129  1 |         
                       value 2129  1 |         
b_s_groups_per_location dist 2129  1 |         
                       value 2129  1 |         
    z_sr_groups_factors dist 2129 50 |         
                       value 2129 50 |         
 k_r_factors_per_groups dist         | 50     1
                       value         | 50     1
        x_fr_group2fact dist         | 50    18
                       value         | 50    18
                   w_sf dist 2129 18 |         
                       value 2129 18 |         
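
Earlier in this thread, passing a single number for N_cells_per_location is reported to work. One possible stopgap while the array case is unresolved is to collapse the segmentation counts to their average, which fits the "expected average cell abundance" wording in the code comment above. This is only a sketch: `scalar_prior_from_counts` is a hypothetical helper, not part of cell2location, and the choice of the mean and the 0.1 pseudocount are assumptions.

```python
import numpy as np


def scalar_prior_from_counts(counts, pseudocount=0.1):
    """Hypothetical helper: collapse per-spot nuclei counts to one average value."""
    counts = np.asarray(counts, dtype="float64")
    if counts.ndim != 1 or counts.size == 0:
        raise ValueError("expected a non-empty 1-D array of per-spot counts")
    if not np.all(np.isfinite(counts)) or np.any(counts < 0):
        raise ValueError("counts must be finite and non-negative")
    # Mean count across spots, shifted by the same pseudocount used above.
    return float(counts.mean() + pseudocount)


# Toy counts standing in for spatial.obs["Nuclei_count"].
nuclei_count = np.array([3, 8, 0, 12, 5])
n_cells = scalar_prior_from_counts(nuclei_count)
print(n_cells)
```

The resulting `n_cells` could then be passed as `N_cells_per_location=n_cells`. This discards the per-spot variation, so it does not replace a fix for the array input.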

Have you been able to investigate or troubleshoot the issue yet?

Thanks a lot for your help and for the great tool!

Best