HJ-harry / score-MRI

Apache License 2.0
143 stars, 23 forks

A bug in device during sampling #18

Open Z7Gao opened 1 year ago

Z7Gao commented 1 year ago

Hi, I encountered a bug when running the real-image sampling script:

```
  File "xxxxx/score-MRI/sde_lib.py", line 156, in discretize
    self.discrete_sigmas[timestep - 1].to(t.device))
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
```

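A simple fix solves this: convert the index to a Python integer with `timestep.item()`, so the CPU tensor `discrete_sigmas` is no longer indexed by a tensor living on another device (e.g. the GPU). Note that `.item()` only works on one-element tensors, so this assumes `timestep` holds a single value: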

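```python
# sde_lib.py, discretize(): .item() turns the (one-element) index tensor into a
# Python int, which can index the CPU-resident discrete_sigmas from any device.
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                             self.discrete_sigmas[timestep.item() - 1].to(t.device))
```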
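If `timestep` could ever hold more than one element, an alternative is to move the sigma table to `t.device` *before* indexing rather than after, so the index and the indexed tensor always share a device. The sketch below assumes `timestep` and `t` are same-shape tensors on the same device; `adjacent_sigma` is a hypothetical standalone helper for illustration, not a function in score-MRI.

```python
import torch


def adjacent_sigma(discrete_sigmas: torch.Tensor, timestep: torch.Tensor,
                   t: torch.Tensor) -> torch.Tensor:
    """Return sigma[i - 1] for each timestep index i, and 0 where i == 0.

    Hypothetical helper (not part of score-MRI). The sigma table is moved to
    t.device before indexing, so index and table are on the same device and
    a batch of timesteps works (.item() would raise for >1 element).
    """
    sigmas = discrete_sigmas.to(t.device)
    # For timestep == 0 the index -1 wraps to the last entry; torch.where
    # discards that value and substitutes 0 instead.
    return torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sigmas = torch.tensor([0.01, 0.1, 1.0, 10.0])  # stays on CPU, as in the traceback
    t = torch.tensor([0.0, 0.5, 1.0], device=device)
    timestep = torch.tensor([0, 1, 3], device=device)
    print(adjacent_sigma(sigmas, timestep, t))  # values: 0.0, 0.01, 1.0
```

The design choice mirrors the one-line fix: either the index comes to the CPU (as a Python int), or the table goes to the index's device; the second avoids a host sync per call and does not depend on the batch size.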