w86763777 / pytorch-ddpm

Unofficial PyTorch implementation of Denoising Diffusion Probabilistic Models
Do What The F*ck You Want To Public License
506 stars 62 forks

About memory usage #13

Open Arksyd96 opened 1 year ago

Arksyd96 commented 1 year ago

Hello, I'm having issues with memory usage. Is it normal that even with 48 GB of VRAM I cannot run the reverse process for generation with a batch size as small as 2? What are your specs?

w86763777 commented 1 year ago

No, that is abnormal. To train on CIFAR-10, a GPU with 11 GB of VRAM, such as the 2080 Ti, is sufficient. However, if you use a larger model, the VRAM requirements will increase.

Arksyd96 commented 1 year ago

Yeah, problem fixed. I'm actually training on 1x128x128 BraTS images, and I had forgotten to wrap the reverse process in torch.no_grad().
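For context, the memory blow-up happens because without `torch.no_grad()` autograd records the computation graph for every one of the `T` denoising steps. A minimal sketch (the step function here is a stand-in, not the repo's model):

```python
import torch

# Stand-in for a model parameter; any op involving it would normally be
# recorded by autograd, keeping intermediate activations alive in VRAM.
scale = torch.tensor(0.99, requires_grad=True)

x_t = torch.randn(2, 1, 128, 128)  # batch of 2 single-channel 128x128 images
with torch.no_grad():
    for _ in range(10):            # a few steps instead of the full T
        x_t = x_t * scale          # stand-in for one reverse-diffusion step

print(x_t.requires_grad)  # False: no graph was recorded, activations are freed
```

Without the `torch.no_grad()` context, each of the `T` (often 1000) U-Net forward passes would keep its activations in memory, which easily exhausts even 48 GB.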

However, I still have an issue with the reverse process. During training, the MSE loss is optimized as expected, but sampling only produces noise. Here's my sampling code; could you take a look and tell me if it's OK?

    def q_mean_variance(self, x_0, x_t, t):
        posterior_mean = (
            self.posterior_mean_c1[t, None, None, None].to(device) * x_0 + 
            self.posterior_mean_c2[t, None, None, None].to(device) * x_t
        )
        posterior_log_var = self.posterior_log_var[t, None, None, None]
        return posterior_mean, posterior_log_var

    def p_mean_variance(self, x_t, t):
        model_logvar = torch.log(torch.cat([self.posterior_var[1: 2], self.betas[1:]])).to(device)
        model_logvar = model_logvar[t, None, None, None]

        eps = self.model(x_t, t.to(device))
        x_0 = self.predict_x_start_from_eps(x_t, t, eps)
        model_mean, _ = self.q_mean_variance(x_0, x_t, t)

        return model_mean, model_logvar

    def predict_x_start_from_eps(self, x_t, t, eps):
        # Invert x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps for x_0:
        # x_0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * eps
        return (
            torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device)) * x_t -
            torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device) - 1.) * eps
        )

    def forward(self, x_T):
        x_t = x_T
        for timestep in reversed(range(self.T)):
            t = torch.full((x_T.shape[0],), fill_value=timestep, dtype=torch.long)
            mean, logvar = self.p_mean_variance(x_t, t)
            if timestep > 0:
                noise = torch.randn_like(x_T)
            else:
                noise = 0
            x_t = mean + torch.exp(0.5 * logvar) * noise
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
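For reference, the ε-prediction inversion that `predict_x_start_from_eps` should implement follows from the forward process x_t = sqrt(ᾱ_t)·x_0 + sqrt(1−ᾱ_t)·ε, giving x_0 = sqrt(1/ᾱ_t)·x_t − sqrt(1/ᾱ_t − 1)·ε. A small NumPy sketch (the schedule values are illustrative, not the repo's exact config) checks that this inversion recovers x_0 exactly:

```python
import numpy as np

# Illustrative linear beta schedule (assumed values, not necessarily the repo's)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_prods = np.cumprod(1.0 - betas)   # cumulative product, i.e. abar_t

rng = np.random.default_rng(0)
x0 = rng.standard_normal((1, 128, 128))
eps = rng.standard_normal((1, 128, 128))
t = 500

# Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
x_t = np.sqrt(alpha_prods[t]) * x0 + np.sqrt(1.0 - alpha_prods[t]) * eps

# Inversion: x_0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * eps
x0_hat = np.sqrt(1.0 / alpha_prods[t]) * x_t \
       - np.sqrt(1.0 / alpha_prods[t] - 1.0) * eps

print(np.allclose(x0_hat, x0))  # True
```

Note the reciprocal in the first term and the minus sign on the ε term; using `sqrt(1 - abar_t)` for the first coefficient or a `+` sign breaks the inversion and makes every predicted x_0 (and hence every posterior mean) wrong, which would produce exactly the "only noise" symptom.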
w86763777 commented 1 year ago

Apologies for the delayed response.

To the best of my recollection, you do not need to update the GaussianDiffusionTrainer and GaussianDiffusionSampler when training with images of different sizes. These components are capable of adapting to different image dimensions.

However, you will need to modify the model and data-related code, including the UNet, dataset, and dataloader, to accommodate the new image sizes.
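To illustrate the data-side changes, a hedged sketch of a dataset and dataloader for single-channel 128x128 slices (class and variable names here are illustrative, not from the repo; the UNet would likewise need to be constructed with 1 input/output channel):

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Hypothetical dataset of 1x128x128 slices, scaled to [-1, 1] as DDPM expects.
class SliceDataset(Dataset):
    def __init__(self, volumes):
        self.volumes = volumes          # list of (1, 128, 128) tensors

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        return self.volumes[idx]

data = [torch.rand(1, 128, 128) * 2 - 1 for _ in range(8)]
loader = DataLoader(SliceDataset(data), batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 1, 128, 128])
```

The sampler itself is shape-agnostic because all its schedule tensors broadcast over the `[None, None, None]` spatial dimensions; only the model and the data pipeline need to know the actual image size and channel count.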