JinmiaoChenLab / scTM

scTM: A package for topic modelling in transcriptomics data
https://JinmiaoChenLab.github.io/scTM/

RuntimeError: expected scalar type Float but found Double #5

Status: Open (opened by marvinquiet 5 days ago)

marvinquiet commented 5 days ago

Dear Dr. Zhong,

Thank you for providing this great tool! I am very interested in your work and am trying it out myself. However, I encountered "RuntimeError: expected scalar type Float but found Double" when running model.train(device="cuda:0").

Have you encountered this problem before? Any suggestions would be greatly appreciated!

Sincerely, Wenjing

Chengwei94 commented 5 days ago

@marvinquiet,

I think it might have something to do with your counts being double (float64) instead of float (float32). Can you check their dtype?
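PyTorch's "Float" and "Double" correspond to NumPy's float32 and float64, so the check suggested here amounts to inspecting the dtype of the counts array. A minimal sketch, assuming the counts are a NumPy array or a SciPy sparse matrix (the usual contents of an AnnData layer); the helper `to_float32` is hypothetical and not part of scTM:

```python
import numpy as np


def to_float32(x):
    """Cast x to float32 (PyTorch "Float"); return it unchanged if it already is.

    Only .dtype and .astype are used, so this works for NumPy arrays and
    SciPy sparse matrices alike.
    """
    if x.dtype == np.float32:
        return x
    return x.astype(np.float32)


# Integer counts, as later reported in this thread (numpy int64)
counts = np.array([[0, 3, 1], [5, 0, 2]], dtype=np.int64)
print(counts.dtype)      # int64
counts32 = to_float32(counts)
print(counts32.dtype)    # float32
```

Applied to the layer from this thread, this would look like `adata.layers["count"] = to_float32(adata.layers["count"])` before constructing the model; whether scTM preserves that dtype internally is not confirmed here.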

marvinquiet commented 4 days ago

Dear Dr. Zhong,

Thank you for your reply. I checked my input, and the counts are numpy int64 before they are passed to the model. Do I need to perform some preprocessing (e.g. log1p normalization, or centering and scaling) to convert them to Float?

Sincerely, Wenjing

Chengwei94 commented 3 days ago

@marvinquiet you don't need to center and scale the data. Could you post the full error message here?
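Changing the dtype alone does not require any normalization: float32 has a 24-bit significand, so every integer up to 2**24 (16,777,216) is represented exactly and raw counts survive the cast unchanged. A minimal sketch of a lossless cast with that bound checked; `cast_counts_lossless` is a hypothetical helper, not an scTM function:

```python
import numpy as np

# float32 has a 24-bit significand: every integer in [0, 2**24] is exactly representable
FLOAT32_EXACT_INT_LIMIT = 2 ** 24


def cast_counts_lossless(counts):
    """Cast non-negative integer counts to float32, refusing values that could be rounded."""
    max_count = int(counts.max())
    if max_count > FLOAT32_EXACT_INT_LIMIT:
        raise ValueError(f"max count {max_count} exceeds 2**24; float32 may round it")
    return counts.astype(np.float32)


counts = np.array([[0, 12, 3], [7, 0, 250]], dtype=np.int64)
counts32 = cast_counts_lossless(counts)
print(counts32.dtype)                      # float32
print(np.array_equal(counts32, counts))    # True
```

The bound is conservative: some larger integers are also exact in float32, but checking against 2**24 guarantees that none of the counts are rounded.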

marvinquiet commented 3 days ago

Sure, here is my AnnData object:

AnnData object with n_obs × n_vars = 8062 × 22496
    obs: 'barcodes', 'nCount_RNA', 'nFeature_RNA', 'annotation', 'pixel_x', 'pixel_y', 'timepoint_numeric',
    var: 'genes'
    uns: 'spatial_neighbors'
    obsm: 'spatial'
    layers: 'count'
    obsp: 'spatial_connectivities', 'spatial_distances'

I am using the following code to process the data:

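import squidpy as sq
import sctm

# --- run STAMP
# adata is the AnnData object summarized above (loading code not shown)
# build the spatial graph with roughly 1 neighbour per 1,000 cells
sq.gr.spatial_neighbors(adata, n_neighs=round(1 / 1000 * adata.n_obs))
sctm.seed.seed_everything(0)

model = sctm.stamp.STAMP(
    adata,
    n_topics=40,
    layer="count",
    time_covariate_keys="timepoint_numeric",
)

model.train(device="cuda:0")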
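For reference, the `n_neighs` expression in this snippet works out to round(1 / 1000 * 8062) = round(8.062) = 8 neighbours for the 8,062 observations above. Because Python's `round` returns 0 for anything below 0.5, small datasets (for example, fewer than 500 observations) would request 0 neighbours. The guarded variant below is a hypothetical adjustment, not something suggested in the thread:

```python
def n_neighs_for(n_obs: int) -> int:
    """Neighbour count from the snippet above: about 1 neighbour per 1,000 observations."""
    return round(1 / 1000 * n_obs)


def n_neighs_guarded(n_obs: int) -> int:
    """Hypothetical variant that never requests fewer than 1 neighbour."""
    return max(1, n_neighs_for(n_obs))


print(n_neighs_for(8062))      # 8
print(n_neighs_for(300))       # 0
print(n_neighs_guarded(300))   # 1
```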

I then encountered the following error:

Computing background frequencies
Computing background frequencies
Computing background frequencies
  0%|                                                                                                                  | 0/800 [00:05<?, ?it/s]
Traceback (most recent call last):
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 191, in __call__
    ret = self.fn(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/sctm/model.py", line 496, in model
    beta_gp = beta_gp_chole.matmul(beta_gp.unsqueeze(-1)).squeeze(
RuntimeError: expected scalar type Float but found Double

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "run_STAMP.py", line 91, in <module>
    model.train(device="cuda:0")
  File "/envs/SCTM_py38/lib/python3.8/site-packages/sctm/stamp.py", line 380, in train
    batch_loss = svi.step(
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/elbo.py", line 237, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/trace_mean_field_elbo.py", line 82, in _get_trace
    model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/trace_elbo.py", line 57, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/enum.py", line 65, in get_importance_trace
    model_trace = poutine.trace(
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 216, in get_trace
    self(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 198, in __call__
    raise exc from e
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 191, in __call__
    ret = self.fn(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
  File "/envs/SCTM_py38/lib/python3.8/site-packages/sctm/model.py", line 496, in model
    beta_gp = beta_gp_chole.matmul(beta_gp.unsqueeze(-1)).squeeze(
RuntimeError: expected scalar type Float but found Double
                                     Trace Shapes:
                                      Param Sites:
                    stamp$$$caux_loc_unconstrained                    1
                  stamp$$$caux_scale_unconstrained                    1
              stamp$$$z_topic_lr_loc_unconstrained             40    40
            stamp$$$z_topic_lr_scale_unconstrained             40    40
            stamp$$$z_topic_diag_loc_unconstrained                    1
          stamp$$$z_topic_diag_scale_unconstrained                    1
                   stamp$$$delta_loc_unconstrained                22496
                 stamp$$$delta_scale_unconstrained                22496
                      stamp$$$bg_loc_unconstrained          22496     3
                    stamp$$$bg_scale_unconstrained          22496     3
                     stamp$$$tau_loc_unconstrained             40     1
                   stamp$$$tau_scale_unconstrained             40     1
                  stamp$$$lambda_loc_unconstrained             40 22496
                stamp$$$lambda_scale_unconstrained             40 22496
               stamp$$$batch_tau_loc_unconstrained             40     1
             stamp$$$batch_tau_scale_unconstrained             40     1
             stamp$$$batch_delta_loc_unconstrained              1 22496
           stamp$$$batch_delta_scale_unconstrained              1 22496
                    stamp$$$beta_loc_unconstrained             40 22496
                  stamp$$$beta_scale_unconstrained             40 22496
              stamp$$$beta_scale_loc_unconstrained             40 22496
            stamp$$$beta_scale_scale_unconstrained             40 22496
                    stamp$$$disp_loc_unconstrained                22496
                  stamp$$$disp_scale_unconstrained                22496
            stamp$$$z_topic_time_loc_unconstrained             40     3
          stamp$$$z_topic_time_scale_unconstrained             40     3
    stamp$$$z_topic_lr_timescale_loc_unconstrained             40    40
  stamp$$$z_topic_lr_timescale_scale_unconstrained             40    40
  stamp$$$z_topic_lr_lengthscale_loc_unconstrained             40     1
stamp$$$z_topic_lr_lengthscale_scale_unconstrained             40     1
     stamp$$$beta_gp_lengthscale_loc_unconstrained             40 22496
   stamp$$$beta_gp_lengthscale_scale_unconstrained             40 22496
              stamp$$$beta_gp_mu_loc_unconstrained             40 22496
            stamp$$$beta_gp_mu_scale_unconstrained             40 22496
                 stamp$$$beta_gp_loc_unconstrained       40 22496     3
               stamp$$$beta_gp_scale_unconstrained 40 22496     3     3
                     stamp$$$encoder.base.1.weight            128 44992
                       stamp$$$encoder.base.1.bias                  128
                     stamp$$$encoder.base.2.weight                  128
                       stamp$$$encoder.base.2.bias                  128
                  stamp$$$encoder.mu_topic.weight             40   128
                     stamp$$$encoder.mu_topic.bias                   40
                 stamp$$$encoder.norm_topic.0.bias                   40
                 stamp$$$encoder.norm_topic.1.bias                   40
                 stamp$$$encoder.norm_topic.2.bias                   40
                 stamp$$$encoder.diag_topic.weight             40   128
                   stamp$$$encoder.diag_topic.bias                   40
                                     Sample Sites:
                                       sample dist                    |
                                             value            256     |
                                         caux dist              1     |
                                             value              1     |
                                        delta dist          22496     |
                                             value          22496     |
                                           bg dist          22496     | 3
                                             value          22496     | 3
                                          tau dist       40     1     |
                                             value       40     1     |
                                      lambda_ dist       40 22496     |
                                             value       40 22496     |
                                 z_topic_diag dist              1     |
                                             value              1     |
                                   z_topic_lr dist       40    40     |
                                             value       40    40     |
                          beta_gp_lengthscale dist       40 22496     |
                                             value       40 22496     |
                                      beta_gp dist       40 22496     | 3
                                             value       40 22496     | 3

Please let me know if additional information is needed.

Thank you so much for your time!

Sincerely, Wenjing
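Both tracebacks in the previous comment end at the same line of sctm/model.py, beta_gp_chole.matmul(beta_gp.unsqueeze(-1)), and PyTorch's matmul refuses to mix float32 and float64 operands, which is what the error reports. The thread does not establish which operand is float64. One plausible candidate, offered as an assumption rather than a confirmed diagnosis, is the time covariate passed via time_covariate_keys: pandas stores float columns as float64 by default, and beta_gp_chole may be derived from it. A minimal NumPy-only sketch of a pre-flight check that flags float64 inputs; `find_double_inputs`, `cast_doubles_to_float32`, and the example arrays are hypothetical stand-ins for the real AnnData fields:

```python
import numpy as np


def find_double_inputs(inputs):
    """Return the names of inputs whose dtype is float64 (PyTorch "Double")."""
    return [name for name, arr in inputs.items() if np.asarray(arr).dtype == np.float64]


def cast_doubles_to_float32(inputs):
    """Return a copy of inputs with every float64 array cast to float32."""
    return {
        name: np.asarray(arr).astype(np.float32) if np.asarray(arr).dtype == np.float64 else arr
        for name, arr in inputs.items()
    }


# Hypothetical stand-ins for the two inputs discussed in this thread
inputs = {
    "layers['count']": np.array([[0, 3], [1, 4]], dtype=np.int64),  # int64, as reported
    "obs['timepoint_numeric']": np.array([0.0, 1.0, 2.0]),          # float64 by default
}
print(find_double_inputs(inputs))   # ["obs['timepoint_numeric']"]
fixed = cast_doubles_to_float32(inputs)
print(find_double_inputs(fixed))    # []
```

On the real object, the equivalent cast would be something like `adata.obs["timepoint_numeric"] = adata.obs["timepoint_numeric"].astype("float32")`; this is untested against scTM, which may convert the column internally regardless.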