explainingai-code / DDPM-Pytorch

This repo implements Denoising Diffusion Probabilistic Models (DDPM) in PyTorch

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) #5

Closed by exponentialXP 4 months ago

exponentialXP commented 4 months ago

I get this error when using this repository. This seems to fix it, but it's probably not the most efficient way to do it:

    def add_noise(self, original, noise, t):
        original_shape = original.shape
        batch_size = original_shape[0]

        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod[t.cpu()].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod[t.cpu()].reshape(batch_size)

        for _ in range(len(original_shape)-1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)

        for _ in range(len(original_shape)-1):
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        return (sqrt_alpha_cum_prod.to(original.device) * original + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)

    def backward(self, xt, noise_pred, t):
        x0 = (xt - (self.sqrt_one_minus_alpha_cum_prod[t.cpu()].to(noise_pred.device) * noise_pred)) / torch.sqrt(self.alpha_cum_prod[t.cpu()].to(noise_pred.device))
        x0 = torch.clamp(x0, -1., 1.)

        mean = xt - ((self.betas[t.cpu()]).to(noise_pred.device)*noise_pred).to(noise_pred.device)/(self.sqrt_one_minus_alpha_cum_prod[t.cpu()]).to(noise_pred.device)
        mean = mean / torch.sqrt(self.alphas[t.cpu()].to(noise_pred.device))

        if t == 0:
            return mean, x0
        else:
            variance = (1-self.alpha_cum_prod[(t-1).cpu()]).to(noise_pred.device) / (1. - self.alpha_cum_prod[t.cpu()]).to(noise_pred.device)
            variance = variance * self.betas[t.cpu()].to(noise_pred.device)
            sigma = variance ** 0.5
            z = torch.randn(xt.shape).to(xt.device)

            return mean + sigma*z, x0

explainingai-code commented 4 months ago

Hello, just wanted to check whether you are getting this error after installing the required versions listed in requirements.txt. Specifically, what PyTorch version do you have?

exponentialXP commented 4 months ago

I have PyTorch version 2.2.0+cu121 and have all the libraries installed.
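
For context, the error is raised when a tensor that lives on the CPU is indexed with an index tensor that lives on another device. Below is a minimal sketch of that situation. `schedule` stands in for the scheduler's precomputed tensors and `lookup` is a hypothetical helper; neither name comes from the repository.

```python
import torch

# Stand-in for a precomputed noise schedule that was created on the CPU.
schedule = torch.linspace(0.0, 1.0, 1000)

# Timesteps usually live wherever the training batch lives.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t = torch.randint(0, 1000, (4,), device=device)


def lookup(table, index):
    """Index `table` with `index`, reporting a device mismatch if it fails.

    Hypothetical helper for illustration only.
    """
    try:
        return table[index]
    except RuntimeError as err:
        # Reached on PyTorch builds that require matching devices
        # when `table` is on the CPU and `index` is on a GPU.
        print(f"table on {table.device}, index on {index.device}: {err}")
        # Moving the index to the table's device makes the lookup valid.
        return table[index.to(table.device)]


values = lookup(schedule, t)
print(values.shape)  # one schedule value per timestep in the batch
```

On a CPU-only machine both tensors share a device, so the fallback branch is never taken.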

explainingai-code commented 4 months ago

Ah, got it. So that is the cause of the issue: the different PyTorch version. This should not happen with the PyTorch version listed in requirements.txt. The LDM repo moves the alphas to the same device before indexing with the timestep: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/scheduler/linear_noise_scheduler.py#L34 I will make a similar change here, which basically does the same thing as your fix.

A better change would be to move the tensors once by passing the device at scheduler initialization, but I will take that up later.
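
A rough sketch of what passing the device at initialization could look like is given below. It is not the repository's code: the constructor signature, the `device` argument, and the schedule values in the usage lines are assumptions made for illustration. Only the attribute names mirror the snippet posted in this issue.

```python
import torch


class LinearNoiseScheduler:
    """Sketch: linear beta schedule whose tensors are created on `device`.

    The signature is an assumption, not the repository's actual API.
    """

    def __init__(self, num_timesteps, beta_start, beta_end, device="cpu"):
        self.device = torch.device(device)
        # Everything derived from betas inherits the same device.
        self.betas = torch.linspace(
            beta_start, beta_end, num_timesteps, device=self.device
        )
        self.alphas = 1.0 - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1.0 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        # (batch,) -> (batch, 1, 1, ...) so the factors broadcast per sample.
        view_shape = (original.shape[0],) + (1,) * (original.dim() - 1)
        t = t.to(self.device)  # no-op when t is already on the right device
        sqrt_acp = self.sqrt_alpha_cum_prod[t].reshape(view_shape)
        sqrt_omacp = self.sqrt_one_minus_alpha_cum_prod[t].reshape(view_shape)
        return sqrt_acp * original + sqrt_omacp * noise


# Usage with illustrative values (1000 steps, betas from 1e-4 to 0.02).
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler = LinearNoiseScheduler(1000, 1e-4, 0.02, device=device)
x0 = torch.randn(4, 1, 28, 28, device=device)
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (4,), device=device)
xt = scheduler.add_noise(x0, noise, t)
```

With the schedule created on the target device up front, the per-call `.cpu()` and `.to(...)` transfers in the workaround above are no longer needed, provided the inputs are on that same device.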

exponentialXP commented 4 months ago

Ahhh, yeah that's way better. Also in line 58 it says return mean, mean instead of return mean, x0. Also should I close the issue now?

explainingai-code commented 4 months ago

Yes. Thank you, @exponentialXP. I have pushed the changes for both of these. Feel free to re-open if you find any issue with this.