BayraktarLab / cell2location

Comprehensive mapping of tissue cell architecture via integrated single cell and spatial transcriptomics (cell2location model)
https://cell2location.readthedocs.io/en/latest/
Apache License 2.0

spatial data model training for bleeding corrected data #287

Open parkjooyoung99 opened 1 year ago

parkjooyoung99 commented 1 year ago

First of all, thank you for making this brilliant tool :)

Our lab applies bleeding correction to its spatial data to capture more precise gene expression signals. I want to run deconvolution on this corrected data and see whether bleeding correction improves the deconvolution results.
However, when I train the Visium model with the following code:

```python
# train visium model
print('train visium model ')
mod.train(max_epochs=30000,
          # train using full data (batch_size=None)
          batch_size=None,
          # use all data points in training because
          # we need to estimate cell abundance at all locations
          train_size=1,
          use_gpu=True,
         )
```

I get the error below. I think it is caused by the corrected data no longer following a Gamma-Poisson distribution (the corrected values are not integers). Is there a way to adapt the code to corrected data?

```
ValueError: Error while computing log_prob at site 'data_target':
Expected value argument (Tensor of shape (1706, 7637)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution GammaPoisson(), but found invalid values:
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  2.5605,  0.0000],
        [ 2.8433,  2.5977,  4.5767,  ...,  0.0000, 12.8026,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [38.0151,  0.0000,  4.5767,  ...,  0.0000,  2.5605,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  2.5605,  0.0000],
        [ 1.0854,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
Trace Shapes:
 Param Sites:
Sample Sites:
m_g_mean                  dist               | 1 1
                          value              | 1 1
                          log_prob           |
m_g_alpha_e_inv           dist               | 1 1
                          value              | 1 1
                          log_prob           |
m_g                       dist               | 1 7637
                          value              | 1 7637
                          log_prob           |
n_s_cells_per_location    dist        1706 1 |
                          value       1706 1 |
                          log_prob    1706 1 |
b_s_groups_per_location   dist        1706 1 |
                          value       1706 1 |
                          log_prob    1706 1 |
z_sr_groups_factors       dist       1706 50 |
                          value      1706 50 |
                          log_prob   1706 50 |
k_r_factors_per_groups    dist               | 50 1
                          value              | 50 1
                          log_prob           |
x_fr_group2fact           dist               | 50 43
                          value              | 50 43
                          log_prob           |
w_sf                      dist       1706 43 |
                          value      1706 43 |
                          log_prob   1706 43 |
detection_mean_y_e        dist               | 1 1
                          value              | 1 1
                          log_prob           |
detection_hyp_prior_alpha dist               | 1 1
                          value              | 1 1
                          log_prob           |
detection_y_s             dist        1706 1 |
                          value       1706 1 |
                          log_prob    1706 1 |
s_g_gene_add_alpha_hyp    dist           1 1 |
                          value          1 1 |
                          log_prob       1 1 |
s_g_gene_add_mean         dist               | 1 1
                          value              | 1 1
                          log_prob           |
s_g_gene_add_alpha_e_inv  dist               | 1 1
                          value              | 1 1
                          log_prob           |
s_g_gene_add              dist               | 1 7637
                          value              | 1 7637
                          log_prob           |
alpha_g_phi_hyp           dist           1 1 |
                          value          1 1 |
                          log_prob       1 1 |
alpha_g_inverse           dist               | 1 7637
                          value              | 1 7637
                          log_prob           |
data_target               dist     1706 7637 |
                          value    1706 7637 |
```
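The error says the observed values must lie in the support `IntegerGreaterThan(lower_bound=0)` of `GammaPoisson`, i.e. they must be non-negative integers, while the corrected matrix contains values such as 2.5605. A quick way to confirm this before training is to count non-integer and negative entries in the matrix. This is a minimal NumPy sketch, assuming a dense matrix (for a SciPy sparse matrix you could pass its `.data` array instead); the helper name `count_support_violations` is hypothetical and not part of cell2location:

```python
import numpy as np


def count_support_violations(X):
    """Return (n_non_integer, n_negative) for a dense count matrix.

    Hypothetical helper: a GammaPoisson likelihood only accepts
    non-negative integers, so both numbers should be 0 for valid input.
    """
    X = np.asarray(X, dtype=np.float64)
    # Entries that change when rounded are not whole numbers (NaN also counts here).
    n_non_integer = int(np.count_nonzero(X != np.round(X)))
    n_negative = int(np.count_nonzero(X < 0))
    return n_non_integer, n_negative


if __name__ == "__main__":
    raw = np.array([[0, 3, 1], [2, 0, 5]])
    corrected = np.array([[0.0, 2.5605, 0.0], [2.8433, 0.0, 12.8026]])
    print(count_support_violations(raw))        # (0, 0)
    print(count_support_violations(corrected))  # (3, 0)
```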

Thank you!!

vitkl commented 1 year ago

Hi @parkjooyoung99

Cell2location expects integer counts, so corrected (non-integer) data cannot be used as input directly. Instead, it is possible to incorporate background signal predictions (the predicted "bleeding" signal, not the corrected data) into the model as follows:

$$
\mu_{s,g} = \Big( m_g \sum_f w_{s,f}\, g_{f,g} + s_{e,g} + \{\text{new correction term}\}_{s,g} \Big)\, y_s
$$
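To make the formula concrete, here is a minimal NumPy sketch of the modified expected expression. It assumes dense arrays with S locations, F reference cell types, G genes and E experiments, a one-hot matrix `obs2sample` (S × E) that picks the experiment e of each location, and a hypothetical precomputed array `correction_sg` holding the predicted bleeding signal. It only illustrates the arithmetic and is not cell2location's own implementation, which is written in PyTorch/Pyro:

```python
import numpy as np


def expected_expression(w_sf, g_fg, m_g, obs2sample, s_eg, y_s, correction_sg=None):
    """Illustrative mu_sg = (m_g * sum_f w_sf g_fg + s_eg + correction_sg) * y_s.

    w_sf: (S, F) cell abundance, g_fg: (F, G) reference signatures,
    m_g: (G,) gene-specific scaling, obs2sample: (S, E) one-hot experiment
    assignment, s_eg: (E, G) additive background, y_s: (S, 1) detection
    efficiency, correction_sg: (S, G) hypothetical bleeding term.
    """
    mu_sg = (w_sf @ g_fg) * m_g          # cell-type contributions, scaled per gene
    mu_sg = mu_sg + obs2sample @ s_eg    # per-experiment additive background s_eg
    if correction_sg is not None:
        mu_sg = mu_sg + correction_sg    # extra term: predicted bleeding signal
    return mu_sg * y_s                   # per-location detection efficiency


if __name__ == "__main__":
    w_sf = np.array([[1.0, 0.0], [0.0, 2.0]])
    g_fg = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    m_g = np.ones(3)
    obs2sample = np.array([[1.0], [1.0]])
    s_eg = np.array([[0.5, 0.5, 0.5]])
    y_s = np.array([[1.0], [2.0]])
    correction_sg = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    print(expected_expression(w_sf, g_fg, m_g, obs2sample, s_eg, y_s, correction_sg))
```

Because the correction term enters inside the brackets, it is scaled by the same detection efficiency y_s as the cell-type and background terms, while the observed data stay raw integer counts.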

This additional data, {new correction term}_sg, would have to be added to `adata.layers`, and `setup_anndata`, `_get_fn_args_from_batch` and `forward` would have to be modified to handle this additional input.
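A rough sketch of the data-loading side of that change, written with plain NumPy arrays instead of PyTorch tensors so it runs on its own. The key names (`"X"`, `"ind_x"`, `"batch"`, `"bleeding_correction"`) and the standalone function are assumptions for illustration only; in cell2location the real method lives on the Pyro model class and the keys come from whatever fields `setup_anndata` registers. The idea is that once the correction layer is registered, every minibatch dictionary carries it, and `_get_fn_args_from_batch` passes it on to `forward` as one extra positional argument:

```python
import numpy as np

# Hypothetical registry keys; the real names depend on how setup_anndata registers fields.
X_KEY = "X"
IND_X_KEY = "ind_x"
BATCH_KEY = "batch"
CORRECTION_KEY = "bleeding_correction"


def get_fn_args_from_batch(tensor_dict):
    """Unpack a minibatch dict into (args, kwargs) for forward (illustrative sketch)."""
    x_data = tensor_dict[X_KEY]                                # raw integer counts (S, G)
    ind_x = tensor_dict[IND_X_KEY].astype(np.int64).squeeze()  # location indices
    batch_index = tensor_dict[BATCH_KEY]                       # experiment/batch labels
    correction_sg = tensor_dict[CORRECTION_KEY]                # extra layer: bleeding signal (S, G)
    # forward would then need a matching extra parameter for correction_sg.
    return (x_data, ind_x, batch_index, correction_sg), {}


if __name__ == "__main__":
    batch = {
        X_KEY: np.array([[0, 3], [2, 1]]),
        IND_X_KEY: np.array([[0.0], [1.0]]),
        BATCH_KEY: np.array([[0], [0]]),
        CORRECTION_KEY: np.array([[0.1, 0.0], [0.0, 0.4]]),
    }
    args, kwargs = get_fn_args_from_batch(batch)
    print(len(args), kwargs)
```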

Apologies for the delayed reply. Please let me know if you would like to contribute the above.