ShaochongJia opened 1 year ago
`discrete_diffusion_predict_fn()` references `self.device`, but `self` is not defined there because the function is not a method. This is the snippet where `self.device` raises the error:
```python
if predict_x0:
    init_state = SamplingState(x, x, torch.tensor([num_steps], device=self.device))
else:
    init_state = SamplingState(x, None, torch.tensor([num_steps], device=self.device))
```
I tried passing `device` in as a function argument and moving the variables here to that device, but I couldn't get it to work.
If possible, could you provide an updated `discrete_diffusion_predict_fn()` that fixes this device inconsistency?
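One possible workaround, sketched under assumptions: instead of reading the device from `self`, take it from an optional `device` argument and fall back to `x.device`, so the step-counter tensor is always created on the same device as the input. The `SamplingState` below is a hypothetical stand-in (its field names are guesses, not the repo's definition), and `make_init_state` is a hypothetical helper, not part of the original code.

```python
from typing import NamedTuple, Optional

import torch


class SamplingState(NamedTuple):
    # Hypothetical stand-in for the repo's SamplingState; field names are assumed.
    x: torch.Tensor
    x0: Optional[torch.Tensor]
    t: torch.Tensor


def make_init_state(x: torch.Tensor, num_steps: int, predict_x0: bool,
                    device: Optional[torch.device] = None) -> SamplingState:
    """Build the initial sampling state without relying on `self`.

    If no device is given, use the device of the input tensor `x`, so the
    step counter always lives on the same device as the data.
    """
    if device is None:
        device = x.device
    x = x.to(device)  # no-op if x is already on `device`
    t = torch.tensor([num_steps], device=device)
    return SamplingState(x, x if predict_x0 else None, t)


state = make_init_state(torch.zeros(2, 3, dtype=torch.long), num_steps=10, predict_x0=True)
print(state.t.device, state.t.item())
```

Inside `discrete_diffusion_predict_fn()`, the two `SamplingState(...)` calls would then use `device=x.device` (or an explicitly passed `device`) in place of `self.device`. Any other tensors the function creates would need the same treatment to avoid mixing CPU and GPU tensors.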
I have the same issue. Were you able to solve it?