yang-song / score_sde_pytorch

PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)
https://arxiv.org/abs/2011.13456
Apache License 2.0

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) #43

Open dorazhiyuyang opened 1 year ago

dorazhiyuyang commented 1 year ago

RuntimeError                              Traceback (most recent call last)
Cell In[29], line 1
----> 1 x, n = sampling_fn(score_model)
      2 show_samples(x)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:407, in get_pc_sampler.<locals>.pc_sampler(model)
    405 vec_t = torch.ones(shape[0], device=t.device) * t
    406 x, x_mean = corrector_update_fn(x, vec_t, model=model)
--> 407 x, x_mean = predictor_update_fn(x, vec_t, model=model)
    409 return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:341, in shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous)
    339 else:
    340   predictor_obj = predictor(sde, score_fn, probability_flow)
--> 341 return predictor_obj.update_fn(x, t)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:196, in ReverseDiffusionPredictor.update_fn(self, x, t)
    195 def update_fn(self, x, t):
--> 196   f, G = self.rsde.discretize(x, t)
    197   z = torch.randn_like(x)
    198   x_mean = x - f

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:104, in SDE.reverse.<locals>.RSDE.discretize(self, x, t)
    102 def discretize(self, x, t):
    103   """Create discretized iteration rules for the reverse diffusion sampler."""
--> 104   f, G = discretize_fn(x, t)
    105   rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
    106   rev_G = torch.zeros_like(G) if self.probability_flow else G

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:251, in VESDE.discretize(self, x, t)
    248 timestep = (t * (self.N - 1) / self.T).long()
    249 sigma = self.discrete_sigmas.to(t.device)[timestep]
    250 adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
--> 251                              self.discrete_sigmas[timestep - 1].to(t.device))
    252 f = torch.zeros_like(x)
    253 G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
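A minimal sketch of the device mismatch behind this error, using toy tensors instead of the repository's classes. Here table stands in for discrete_sigmas (a CPU tensor) and idx for timestep (derived from t); the helper names lookup_index_first and lookup_move_first are hypothetical. The GPU branch only runs when CUDA is available, and whether the index-first order raises there may depend on the PyTorch version, so treat this as an illustration under those assumptions.

```python
import torch


# Hypothetical helper: the pattern from line 251 (index first, then move the result).
def lookup_index_first(table, idx):
    return table[idx].to(idx.device)


# Hypothetical helper: the pattern from line 249 and the suggested fix
# (move the whole table to the index's device, then index it).
def lookup_move_first(table, idx):
    return table.to(idx.device)[idx]


table = torch.arange(10.0)      # stands in for discrete_sigmas, created on the CPU
idx = torch.tensor([1, 4, 7])   # stands in for timestep

# On the CPU both orderings give the same values.
print(lookup_index_first(table, idx), lookup_move_first(table, idx))

if torch.cuda.is_available():
    idx_gpu = idx.cuda()        # like timestep when t lives on the GPU
    print(lookup_move_first(table, idx_gpu))  # table is copied to the GPU first
    try:
        lookup_index_first(table, idx_gpu)    # CUDA indices into a CPU tensor
    except RuntimeError as err:
        # Recent PyTorch releases reject this combination with the error above.
        print("index-first failed:", err)
```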

pace577 commented 1 year ago

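Changing self.discrete_sigmas[timestep - 1].to(t.device) to self.discrete_sigmas.to(t.device)[timestep - 1] on line 251 of sde_lib.py (in VESDE.discretize) seems to fix the problem. Line 249 already moves discrete_sigmas to t.device before indexing it, but line 251 indexes first: timestep is computed from t, so during GPU sampling it is a CUDA tensor used to index a CPU tensor, which is what raises the error.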
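A sketch of VESDE.discretize with the suggested change applied, written as a standalone function so it can be tried outside the repository. The name ve_discretize and its explicit discrete_sigmas, N and T arguments are hypothetical stand-ins for the VESDE attributes; the body mirrors lines 248 to 253 from the traceback, except that the lookup table is moved to t.device once, before both indexing operations.

```python
import torch


def ve_discretize(x, t, discrete_sigmas, N, T=1.0):
    """Hypothetical standalone version of VESDE.discretize with the device fix.

    x: batch of samples, t: 1-D tensor of times (one per sample),
    discrete_sigmas: 1-D tensor of N increasing noise levels on any device.
    """
    # Move the lookup table to t's device before any indexing, so the
    # index tensors below always live on the same device as the table.
    sigmas = discrete_sigmas.to(t.device)
    timestep = (t * (N - 1) / T).long()
    sigma = sigmas[timestep]
    # For timestep == 0, index -1 wraps to the last entry, but torch.where
    # discards that value and uses zero instead.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                                 sigmas[timestep - 1])
    f = torch.zeros_like(x)
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G
```

Copying the table on every call adds a small host-to-device transfer per sampling step; keeping discrete_sigmas on the sampling device from the start would avoid it, but that goes beyond the one-line change suggested here.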