Open Arksyd96 opened 1 year ago
Hello, I'm having an issue with memory usage. Is it normal that even with 48 GB of VRAM I cannot run the reverse process for generation with a small batch size of 2? What are your specs?
No, that is abnormal. To train on CIFAR-10, a GPU with 11 GB of VRAM such as the 2080 Ti is sufficient. However, if you use a larger model, the VRAM requirements may increase.
Yeah, problem fixed. Actually, I'm training on 1x128x128 BraTS images and I had forgotten to wrap the reverse process in torch.no_grad().
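For anyone hitting the same out-of-memory error, a minimal sketch of the fix (the `sampler` name and the shapes here are just illustrative, not the repo's exact API):

```python
# Illustrative sketch: without no_grad, autograd keeps the activations
# of every one of the T denoising steps, which quickly exhausts VRAM.
with torch.no_grad():
    x_T = torch.randn(2, 1, 128, 128, device=device)  # batch of 2 1x128x128 images
    samples = sampler(x_T)  # 'sampler' stands in for the reverse-process module
```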
However, I still have an issue with the reverse process: during training, the MSE loss is optimized well, but sampling only generates noise. Here's my sampling code if you want to take a look and tell me whether it's OK:
```python
def q_mean_variance(self, x_0, x_t, t):
    # Mean and log-variance of the forward posterior q(x_{t-1} | x_t, x_0)
    posterior_mean = (
        self.posterior_mean_c1[t, None, None, None].to(device) * x_0 +
        self.posterior_mean_c2[t, None, None, None].to(device) * x_t
    )
    posterior_log_var = self.posterior_log_var[t, None, None, None].to(device)
    return posterior_mean, posterior_log_var

def p_mean_variance(self, x_t, t):
    # Fixed "large" variance: beta_t for t > 0, posterior variance at t = 0
    model_logvar = torch.log(torch.cat([self.posterior_var[1:2], self.betas[1:]])).to(device)
    model_logvar = model_logvar[t, None, None, None]
    eps = self.model(x_t, t.to(device))
    x_0 = self.predict_x_start_from_eps(x_t, t, eps)
    model_mean, _ = self.q_mean_variance(x_0, x_t, t)
    return model_mean, model_logvar

def predict_x_start_from_eps(self, x_t, t, eps):
    # Invert x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps:
    # x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
    return (
        torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device)) * x_t -
        torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device) - 1.) * eps
    )

@torch.no_grad()
def forward(self, x_T):
    x_t = x_T
    for timestep in reversed(range(self.T)):
        t = torch.full((x_T.shape[0],), fill_value=timestep, dtype=torch.long)
        mean, logvar = self.p_mean_variance(x_t, t)
        # No noise is added on the final denoising step
        noise = torch.randn_like(x_T) if timestep > 0 else 0.
        x_t = mean + torch.exp(0.5 * logvar) * noise
    x_0 = x_t
    return torch.clip(x_0, -1, 1)
```
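For reference, the update I'm trying to implement is the standard DDPM reverse step (Ho et al., 2020): predict $\hat{x}_0$ from the network's noise estimate, then take the mean of the forward posterior $q(x_{t-1} \mid x_t, \hat{x}_0)$:

$$
\hat{x}_0 = \sqrt{\tfrac{1}{\bar{\alpha}_t}}\, x_t - \sqrt{\tfrac{1}{\bar{\alpha}_t} - 1}\,\epsilon_\theta(x_t, t),
\qquad
\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\hat{x}_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t
$$

so `posterior_mean_c1` and `posterior_mean_c2` should be the two coefficients above, and the sample is $x_{t-1} = \tilde{\mu}_t + \sigma_t z$ with $z \sim \mathcal{N}(0, I)$ for $t > 0$.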
Apologies for the delayed response.
To the best of my recollection, you do not need to update the `GaussianDiffusionTrainer` and `GaussianDiffusionSampler` when training with images of different sizes; these components adapt to different image dimensions. However, you will need to modify the model and data-related code, including the `UNet`, dataset, and dataloader, to accommodate the new image size.
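As an illustration only (the constructor arguments and dataset name below are assumptions, not this repo's exact API), the changes would look something like:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

# Hypothetical argument names: the point is that the UNet's input/output
# channels must match the 1-channel 128x128 BraTS slices.
model = UNet(in_ch=1, out_ch=1).to(device)

# Dataset/dataloader: deliver tensors of shape (B, 1, 128, 128),
# normalized to [-1, 1] to match torch.clip(x_0, -1, 1) at sampling time.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),                         # -> [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),   # -> [-1, 1]
])
loader = DataLoader(brats_dataset, batch_size=16, shuffle=True)  # brats_dataset: your dataset
```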