[Open] Rayucas21 opened this issue 3 hours ago
I tried filtering each dataset separately, taking the intersection of their genes, and then merging them; I also lowered the learning rate to 0.001. It still gives an error, as shown below:
model.train(device="cpu", sampler = "W",learning_rate=0.001)
ValueError Traceback (most recent call last) File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.call(self, *args, *kwargs) 190 try: --> 191 ret = self.fn(args, **kwargs) 192 except (ValueError, RuntimeError) as e:
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/messenger.py:32, in _context_wrap(context, fn, *args, *kwargs) 31 with context: ---> 32 return fn(args, **kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/sctm/model.py:704, in spatialLDAModel.guide(self, x, sgc_x, categorical_covariate_code, time_covariate_code, not_cov, sample_idx, mask) 701 with poutine.scale(scale=kl_weight): 702 pyro.sample( 703 "z_topic", --> 704 dist.Normal(z_topic_loc, z_topic_scale).to_event(1), 705 ) 706 else:
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.call(cls, *args, *kwargs) 25 return result ---> 26 return super().call(args, **kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/normal.py:57, in Normal.init(self, loc, scale, validate_args) 56 batch_shape = self.loc.size() ---> 57 super().init(batch_shape, validate_args=validate_args)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/distribution.py:70, in Distribution.init(self, batch_shape, event_shape, validate_args) 69 if not valid.all(): ---> 70 raise ValueError( 71 f"Expected parameter {param} " 72 f"({type(value).name} of shape {tuple(value.shape)}) " 73 f"of distribution {repr(self)} " 74 f"to satisfy the constraint {repr(constraint)}, " 75 f"but found invalid values:\n{value}" 76 ) 77 super().init()
ValueError: Expected parameter loc (Tensor of shape (256, 18)) of distribution Normal(loc: torch.Size([256, 18]), scale: torch.Size([256, 18])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last) Cell In[197], line 2 1 # We used a weighted sampler here as that the first timepoint is extremly small compared to the last timepoint. ----> 2 model.train(device="cpu", sampler = "W",learning_rate=0.001)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/sctm/stamp.py:380, in STAMP.train(self, max_epochs, min_epochs, learning_rate, betas, not_cov_epochs, device, batch_size, sampler, weight_decay, iterations_to_anneal, min_kl, max_kl, early_stop, patience, shuffle, num_particles) 378 # optimizer.zero_grad() 379 for _, batch in enumerate(self.dataloader): --> 380 batch_loss = svi.step( 381 batch["x"].to(device), 382 batch["sgc_x"].to(device), 383 batch["categorical_covariate_codes"].to(device), 384 ( 385 batch["time_covariate_codes"].to(device) 386 if self.n_time >= 2 387 else None 388 ), 389 not_cov, 390 batch["sample_idx"], 391 True, 392 ) 393 losses.append(float(batch_loss)) 394 # iteration += 1 395 # print(f"Full time{end - start}
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, *kwargs) 143 # get loss and compute gradients 144 with poutine.trace(param_only=True) as param_capture: --> 145 loss = self.loss_and_grads(self.model, self.guide, args, **kwargs) 147 params = set( 148 site["value"].unconstrained() for site in param_capture.trace.nodes.values() 149 ) 151 # actually perform gradient steps 152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs) 138 loss = 0.0 139 # grab a trace from the generator --> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): 141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle( 142 model_trace, guide_trace 143 ) 144 loss += loss_particle / self.num_particles
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs) 235 else: 236 for i in range(self.num_particles): --> 237 yield self._get_trace(model, guide, args, kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/trace_mean_field_elbo.py:82, in TraceMeanField_ELBO._get_trace(self, model, guide, args, kwargs) 81 def _get_trace(self, model, guide, args, kwargs): ---> 82 model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs) 83 if is_validation_enabled(): 84 _check_mean_field_requirement(model_trace, guide_trace)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs) 52 def _get_trace(self, model, guide, args, kwargs): 53 """ 54 Returns a single trace from the guide, and the model that is run 55 against it. 56 """ ---> 57 model_trace, guide_trace = get_importance_trace( 58 "flat", self.max_plate_nesting, model, guide, args, kwargs 59 ) 60 if is_validation_enabled(): 61 check_if_enumerated(guide_trace)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/enum.py:60, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach) 58 model_trace, guide_trace = unwrapped_guide.get_traces() 59 else: ---> 60 guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace( 61 *args, **kwargs 62 ) 63 if detach: 64 guide_trace.detach()
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:216, in TraceHandler.get_trace(self, *args, kwargs) 208 def get_trace(self, *args, *kwargs) -> Trace: 209 """ 210 :returns: data structure 211 :rtype: pyro.poutine.Trace (...) 214 Calls this poutine and returns its trace instead of the function's return value. 215 """ --> 216 self(args, kwargs) 217 return self.msngr.get_trace()
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.call(self, *args, **kwargs) 196 exc = exc_type("{}\n{}".format(exc_value, shapes)) 197 exc = exc.with_traceback(traceback) --> 198 raise exc from e 199 self.msngr.trace.add_node( 200 "_RETURN", name="_RETURN", type="return", value=ret 201 ) 202 return ret
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.call(self, *args, *kwargs) 187 self.msngr.trace.add_node( 188 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs 189 ) 190 try: --> 191 ret = self.fn(args, **kwargs) 192 except (ValueError, RuntimeError) as e: 193 exc_type, exc_value, traceback = sys.exc_info()
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/messenger.py:32, in _context_wrap(context, fn, *args, kwargs) 25 def _context_wrap( 26 context: "Messenger", 27 fn: Callable, 28 *args: Any, 29 *kwargs: Any, 30 ) -> Any: 31 with context: ---> 32 return fn(args, kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/sctm/model.py:704, in spatialLDAModel.guide(self, x, sgc_x, categorical_covariate_code, time_covariate_code, not_cov, sample_idx, mask) 698 z_topic_loc, z_topic_scale = self.encoder( 699 sgc_x, categorical_covariate_code 700 ) 701 with poutine.scale(scale=kl_weight): 702 pyro.sample( 703 "z_topic", --> 704 dist.Normal(z_topic_loc, z_topic_scale).to_event(1), 705 ) 706 else: 707 z_topic_concent = self.encoder(sgc_x, categorical_covariate_code)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.call(cls, *args, *kwargs) 24 if result is not None: 25 return result ---> 26 return super().call(args, **kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/normal.py:57, in Normal.init(self, loc, scale, validate_args) 55 else: 56 batch_shape = self.loc.size() ---> 57 super().init(batch_shape, validate_args=validate_args)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/distribution.py:70, in Distribution.init(self, batch_shape, event_shape, validate_args) 68 valid = constraint.check(value) 69 if not valid.all(): ---> 70 raise ValueError( 71 f"Expected parameter {param} " 72 f"({type(value).name} of shape {tuple(value.shape)}) " 73 f"of distribution {repr(self)} " 74 f"to satisfy the constraint {repr(constraint)}, " 75 f"but found invalid values:\n{value}" 76 ) 77 super().init()
ValueError: Expected parameter loc (Tensor of shape (256, 18)) of distribution Normal(loc: torch.Size([256, 18]), scale: torch.Size([256, 18])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=
Param Sites:
Sample Sites:
sample dist |
value 256 |
caux dist 1 |
value 1 |
delta dist 1968 |
value 1968 |
bg dist 1968 | 4
value 1968 | 4
tau dist 18 1 |
value 18 1 |
lambda
value 18 1968 |
z_topic_diag dist 1 |
value 1 |
z_topic_lr dist 18 18 |
value 18 18 |
beta_gp_lengthscale dist 18 1968 |
value 18 1968 |
beta_gp dist 18 1968 | 4
value 18 1968 | 4
disp dist 1968 |
value 1968 |
Thank you for developing STAMP.
I have data from 4 Stereo-seq chips and want to do a time-series analysis. However, when I reach this step, model.train(device="cpu", sampler = "W") keeps raising an error, shown below. I checked my data and found no NA values, so I would like to ask how to solve this. Thanks!
0%| | 0/800 [00:00<?, ?it/s]
ValueError Traceback (most recent call last) File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.call(self, *args, *kwargs) 190 try: --> 191 ret = self.fn(args, **kwargs) 192 except (ValueError, RuntimeError) as e:
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/messenger.py:32, in _context_wrap(context, fn, *args, *kwargs) 31 with context: ---> 32 return fn(args, **kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/sctm/model.py:704, in spatialLDAModel.guide(self, x, sgc_x, categorical_covariate_code, time_covariate_code, not_cov, sample_idx, mask) 701 with poutine.scale(scale=kl_weight): 702 pyro.sample( 703 "z_topic", --> 704 dist.Normal(z_topic_loc, z_topic_scale).to_event(1), 705 ) 706 else:
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.call(cls, *args, *kwargs) 25 return result ---> 26 return super().call(args, **kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/normal.py:57, in Normal.init(self, loc, scale, validate_args) 56 batch_shape = self.loc.size() ---> 57 super().init(batch_shape, validate_args=validate_args)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/distribution.py:70, in Distribution.init(self, batch_shape, event_shape, validate_args) 69 if not valid.all(): ---> 70 raise ValueError( 71 f"Expected parameter {param} " 72 f"({type(value).name} of shape {tuple(value.shape)}) " 73 f"of distribution {repr(self)} " 74 f"to satisfy the constraint {repr(constraint)}, " 75 f"but found invalid values:\n{value}" 76 ) 77 super().init()
ValueError: Expected parameter loc (Tensor of shape (256, 18)) of distribution Normal(loc: torch.Size([256, 18]), scale: torch.Size([256, 18])) to satisfy the constraint Real(), but found invalid values: tensor([[nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], ..., [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan]], grad_fn=)
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last) Cell In[153], line 2 1 # We used a weighted sampler here as that the first timepoint is extremly small compared to the last timepoint. ----> 2 model.train(device="cpu", sampler = "W")
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/sctm/stamp.py:380, in STAMP.train(self, max_epochs, min_epochs, learning_rate, betas, not_cov_epochs, device, batch_size, sampler, weight_decay, iterations_to_anneal, min_kl, max_kl, early_stop, patience, shuffle, num_particles) 378 # optimizer.zero_grad() 379 for _, batch in enumerate(self.dataloader): --> 380 batch_loss = svi.step( 381 batch["x"].to(device), 382 batch["sgc_x"].to(device), 383 batch["categorical_covariate_codes"].to(device), 384 ( 385 batch["time_covariate_codes"].to(device) 386 if self.n_time >= 2 387 else None 388 ), 389 not_cov, 390 batch["sample_idx"], 391 True, 392 ) 393 losses.append(float(batch_loss)) 394 # iteration += 1 395 # print(f"Full time{end - start}
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, *kwargs) 143 # get loss and compute gradients 144 with poutine.trace(param_only=True) as param_capture: --> 145 loss = self.loss_and_grads(self.model, self.guide, args, **kwargs) 147 params = set( 148 site["value"].unconstrained() for site in param_capture.trace.nodes.values() 149 ) 151 # actually perform gradient steps 152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs) 138 loss = 0.0 139 # grab a trace from the generator --> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs): 141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle( 142 model_trace, guide_trace 143 ) 144 loss += loss_particle / self.num_particles
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs) 235 else: 236 for i in range(self.num_particles): --> 237 yield self._get_trace(model, guide, args, kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/trace_mean_field_elbo.py:82, in TraceMeanField_ELBO._get_trace(self, model, guide, args, kwargs) 81 def _get_trace(self, model, guide, args, kwargs): ---> 82 model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs) 83 if is_validation_enabled(): 84 _check_mean_field_requirement(model_trace, guide_trace)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs) 52 def _get_trace(self, model, guide, args, kwargs): 53 """ 54 Returns a single trace from the guide, and the model that is run 55 against it. 56 """ ---> 57 model_trace, guide_trace = get_importance_trace( 58 "flat", self.max_plate_nesting, model, guide, args, kwargs 59 ) 60 if is_validation_enabled(): 61 check_if_enumerated(guide_trace)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/infer/enum.py:60, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach) 58 model_trace, guide_trace = unwrapped_guide.get_traces() 59 else: ---> 60 guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace( 61 *args, **kwargs 62 ) 63 if detach: 64 guide_trace.detach()
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:216, in TraceHandler.get_trace(self, *args, kwargs) 208 def get_trace(self, *args, *kwargs) -> Trace: 209 """ 210 :returns: data structure 211 :rtype: pyro.poutine.Trace (...) 214 Calls this poutine and returns its trace instead of the function's return value. 215 """ --> 216 self(args, kwargs) 217 return self.msngr.get_trace()
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.call(self, *args, **kwargs) 196 exc = exc_type("{}\n{}".format(exc_value, shapes)) 197 exc = exc.with_traceback(traceback) --> 198 raise exc from e 199 self.msngr.trace.add_node( 200 "_RETURN", name="_RETURN", type="return", value=ret 201 ) 202 return ret
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py:191, in TraceHandler.call(self, *args, *kwargs) 187 self.msngr.trace.add_node( 188 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs 189 ) 190 try: --> 191 ret = self.fn(args, **kwargs) 192 except (ValueError, RuntimeError) as e: 193 exc_type, exc_value, traceback = sys.exc_info()
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/poutine/messenger.py:32, in _context_wrap(context, fn, *args, kwargs) 25 def _context_wrap( 26 context: "Messenger", 27 fn: Callable, 28 *args: Any, 29 *kwargs: Any, 30 ) -> Any: 31 with context: ---> 32 return fn(args, kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/sctm/model.py:704, in spatialLDAModel.guide(self, x, sgc_x, categorical_covariate_code, time_covariate_code, not_cov, sample_idx, mask) 698 z_topic_loc, z_topic_scale = self.encoder( 699 sgc_x, categorical_covariate_code 700 ) 701 with poutine.scale(scale=kl_weight): 702 pyro.sample( 703 "z_topic", --> 704 dist.Normal(z_topic_loc, z_topic_scale).to_event(1), 705 ) 706 else: 707 z_topic_concent = self.encoder(sgc_x, categorical_covariate_code)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.call(cls, *args, *kwargs) 24 if result is not None: 25 return result ---> 26 return super().call(args, **kwargs)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/normal.py:57, in Normal.init(self, loc, scale, validate_args) 55 else: 56 batch_shape = self.loc.size() ---> 57 super().init(batch_shape, validate_args=validate_args)
File ~/miniconda3/envs/sctm/lib/python3.8/site-packages/torch/distributions/distribution.py:70, in Distribution.init(self, batch_shape, event_shape, validate_args) 68 valid = constraint.check(value) 69 if not valid.all(): ---> 70 raise ValueError( 71 f"Expected parameter {param} " 72 f"({type(value).name} of shape {tuple(value.shape)}) " 73 f"of distribution {repr(self)} " 74 f"to satisfy the constraint {repr(constraint)}, " 75 f"but found invalid values:\n{value}" 76 ) 77 super().init()
ValueError: Expected parameter loc (Tensor of shape (256, 18)) of distribution Normal(loc: torch.Size([256, 18]), scale: torch.Size([256, 18])) to satisfy the constraint Real(), but found invalid values: tensor([[nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], ..., [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan], [nan, nan, nan, ..., nan, nan, nan]], grad_fn=)
Trace Shapes: dist 18 400 |
Param Sites:
Sample Sites:
sample dist |
value 256 |
caux dist 1 |
value 1 |
delta dist 400 |
value 400 |
bg dist 400 | 4
value 400 | 4
tau dist 18 1 |
value 18 1 |
lambda
value 18 400 |
z_topic_diag dist 1 |
value 1 |
z_topic_lr dist 18 18 |
value 18 18 |
beta_gp_lengthscale dist 18 400 |
value 18 400 |
beta_gp dist 18 400 | 4
value 18 400 | 4
disp dist 400 |
value 400 |