marvinquiet opened this issue 5 days ago
@marvinquiet,
I think it might have something to do with your counts being double instead of float. Can you check their dtype?
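One way to check is sketched below. A small hypothetical NumPy matrix stands in for adata.layers["count"], and its .dtype attribute reports the storage type (SciPy sparse matrices expose .dtype as well). The snippet also shows how double precision can appear without anyone asking for it: NumPy's default float is float64, which PyTorch calls Double, and torch.from_numpy keeps the NumPy dtype. This does not confirm that sctm converts its inputs this way.

```python
import numpy as np

# Hypothetical stand-in for adata.layers["count"]: raw integer counts.
counts = np.array([[0, 3, 1], [2, 0, 5]], dtype=np.int64)
print(counts.dtype)  # int64

# Casting with Python's built-in float gives float64 ("Double" in PyTorch) ...
as_double = counts.astype(float)
print(as_double.dtype)  # float64

# ... whereas an explicit float32 cast gives PyTorch's default "Float".
as_float = counts.astype(np.float32)
print(as_float.dtype)  # float32
```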
Dear Dr. Zhong,
Thank you for your reply. I have checked my input, and the data type is numpy int64 before it is passed to the model. Do I need to perform some preprocessing (e.g. log1p normalization, centering and scaling) to convert it to float?
Sincerely, Wenjing
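On the question of converting to float: a dtype cast by itself leaves the count values unchanged and involves no normalization. The sketch below uses hypothetical counts to check that a float32 cast round-trips exactly. float32 represents every integer up to 2**24 (16,777,216) exactly, which should comfortably cover typical per-gene counts.

```python
import numpy as np

# Hypothetical raw counts, including 2**24 to probe the edge of exact float32 integers.
counts = np.array([0, 1, 7, 250_000, 2**24], dtype=np.int64)

# A plain dtype cast: no log1p, centering or scaling is applied.
counts_f32 = counts.astype(np.float32)

# Converting back recovers every original integer, so no count was altered.
round_trip = counts_f32.astype(np.int64)
print(np.array_equal(round_trip, counts))  # True
print(counts_f32.dtype)                    # float32
```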
@marvinquiet you don't need to scale and centre. Can you post the whole error message here?
Sure. Here is my AnnData object:
AnnData object with n_obs × n_vars = 8062 × 22496
obs: 'barcodes', 'nCount_RNA', 'nFeature_RNA', 'annotation', 'pixel_x', 'pixel_y', 'timepoint_numeric',
var: 'genes'
uns: 'spatial_neighbors'
obsm: 'spatial'
layers: 'count'
obsp: 'spatial_connectivities', 'spatial_distances'
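I am using the following code to process the data:
```python
import squidpy as sq
import sctm

# --- run STAMP (adata is the AnnData object shown above)
# number of spatial neighbours = 0.1% of the cells (8 for 8062 cells)
sq.gr.spatial_neighbors(adata, n_neighs=round(1 / 1000 * adata.n_obs))
sctm.seed.seed_everything(0)
model = sctm.stamp.STAMP(
    adata,
    n_topics=40,
    layer="count",
    time_covariate_keys="timepoint_numeric",
)
model.train(device="cuda:0")
```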
I then encountered the following error:
Computing background frequencies
Computing background frequencies
Computing background frequencies
0%| | 0/800 [00:05<?, ?it/s]
Traceback (most recent call last):
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 191, in __call__
ret = self.fn(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/sctm/model.py", line 496, in model
beta_gp = beta_gp_chole.matmul(beta_gp.unsqueeze(-1)).squeeze(
RuntimeError: expected scalar type Float but found Double
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "run_STAMP.py", line 91, in <module>
model.train(device="cuda:0")
File "/envs/SCTM_py38/lib/python3.8/site-packages/sctm/stamp.py", line 380, in train
batch_loss = svi.step(
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/elbo.py", line 237, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/trace_mean_field_elbo.py", line 82, in _get_trace
model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/trace_elbo.py", line 57, in _get_trace
model_trace, guide_trace = get_importance_trace(
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/infer/enum.py", line 65, in get_importance_trace
model_trace = poutine.trace(
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 216, in get_trace
self(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 198, in __call__
raise exc from e
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py", line 191, in __call__
ret = self.fn(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
File "/envs/SCTM_py38/lib/python3.8/site-packages/sctm/model.py", line 496, in model
beta_gp = beta_gp_chole.matmul(beta_gp.unsqueeze(-1)).squeeze(
RuntimeError: expected scalar type Float but found Double
Trace Shapes:
Param Sites:
stamp$$$caux_loc_unconstrained 1
stamp$$$caux_scale_unconstrained 1
stamp$$$z_topic_lr_loc_unconstrained 40 40
stamp$$$z_topic_lr_scale_unconstrained 40 40
stamp$$$z_topic_diag_loc_unconstrained 1
stamp$$$z_topic_diag_scale_unconstrained 1
stamp$$$delta_loc_unconstrained 22496
stamp$$$delta_scale_unconstrained 22496
stamp$$$bg_loc_unconstrained 22496 3
stamp$$$bg_scale_unconstrained 22496 3
stamp$$$tau_loc_unconstrained 40 1
stamp$$$tau_scale_unconstrained 40 1
stamp$$$lambda_loc_unconstrained 40 22496
stamp$$$lambda_scale_unconstrained 40 22496
stamp$$$batch_tau_loc_unconstrained 40 1
stamp$$$batch_tau_scale_unconstrained 40 1
stamp$$$batch_delta_loc_unconstrained 1 22496
stamp$$$batch_delta_scale_unconstrained 1 22496
stamp$$$beta_loc_unconstrained 40 22496
stamp$$$beta_scale_unconstrained 40 22496
stamp$$$beta_scale_loc_unconstrained 40 22496
stamp$$$beta_scale_scale_unconstrained 40 22496
stamp$$$disp_loc_unconstrained 22496
stamp$$$disp_scale_unconstrained 22496
stamp$$$z_topic_time_loc_unconstrained 40 3
stamp$$$z_topic_time_scale_unconstrained 40 3
stamp$$$z_topic_lr_timescale_loc_unconstrained 40 40
stamp$$$z_topic_lr_timescale_scale_unconstrained 40 40
stamp$$$z_topic_lr_lengthscale_loc_unconstrained 40 1
stamp$$$z_topic_lr_lengthscale_scale_unconstrained 40 1
stamp$$$beta_gp_lengthscale_loc_unconstrained 40 22496
stamp$$$beta_gp_lengthscale_scale_unconstrained 40 22496
stamp$$$beta_gp_mu_loc_unconstrained 40 22496
stamp$$$beta_gp_mu_scale_unconstrained 40 22496
stamp$$$beta_gp_loc_unconstrained 40 22496 3
stamp$$$beta_gp_scale_unconstrained 40 22496 3 3
stamp$$$encoder.base.1.weight 128 44992
stamp$$$encoder.base.1.bias 128
stamp$$$encoder.base.2.weight 128
stamp$$$encoder.base.2.bias 128
stamp$$$encoder.mu_topic.weight 40 128
stamp$$$encoder.mu_topic.bias 40
stamp$$$encoder.norm_topic.0.bias 40
stamp$$$encoder.norm_topic.1.bias 40
stamp$$$encoder.norm_topic.2.bias 40
stamp$$$encoder.diag_topic.weight 40 128
stamp$$$encoder.diag_topic.bias 40
Sample Sites:
sample dist |
value 256 |
caux dist 1 |
value 1 |
delta dist 22496 |
value 22496 |
bg dist 22496 | 3
value 22496 | 3
tau dist 40 1 |
value 40 1 |
lambda_ dist 40 22496 |
value 40 22496 |
z_topic_diag dist 1 |
value 1 |
z_topic_lr dist 40 40 |
value 40 40 |
beta_gp_lengthscale dist 40 22496 |
value 40 22496 |
beta_gp dist 40 22496 | 3
value 40 22496 | 3
Please let me know if additional information is needed.
Thank you so much for your time!
Sincerely, Wenjing
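The failing line in the traceback (sctm/model.py, line 496) multiplies beta_gp_chole by beta_gp, so one of those tensors is float32 and the other float64. Unlike NumPy, PyTorch's matmul does not promote mixed dtypes; it raises instead. The trace does not show which input feeds the Double tensor. As a hedged debugging aid, the sketch below flags floating-point inputs that are not float32. The dictionary keys and values are hypothetical stand-ins for the count layer, obs["timepoint_numeric"] and obsm["spatial"]; treating the time covariate as a suspect is only a guess, prompted by the trailing dimension of 3 in the beta_gp shapes.

```python
import numpy as np

def non_float32_inputs(arrays):
    """Return the sorted names of floating-point arrays whose dtype is not float32.

    Integer arrays are skipped: they are not floating point yet, so they only
    become Float or Double once something casts them.
    """
    flagged = []
    for name, values in arrays.items():
        dtype = np.asarray(values).dtype
        if np.issubdtype(dtype, np.floating) and dtype != np.float32:
            flagged.append(name)
    return sorted(flagged)

# Hypothetical stand-ins for the model inputs; names and values are illustrative only.
inputs = {
    "count": np.array([[0, 3], [2, 5]], dtype=np.int64),
    "timepoint_numeric": np.array([0.0, 1.0, 2.0]),  # NumPy defaults to float64
    "spatial": np.array([[1.5, 2.5], [3.0, 4.0]], dtype=np.float32),
}

print(non_float32_inputs(inputs))  # ['timepoint_numeric']
```

If a real input is flagged, casting it with .astype(np.float32) before constructing the model would be one thing to try; this has not been tested against sctm.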
Dear Dr. Zhong,
Thank you for providing this great tool! I am very interested in your work and am trying it out myself. However, I encountered the error
RuntimeError: expected scalar type Float but found Double
when running model.train(device="cuda:0"). Have you encountered this problem before? Any suggestions would be greatly appreciated!
Sincerely, Wenjing