lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License
7.48k stars 958 forks

Training on Celeba-hq #300

Open moonnnpie opened 3 months ago

moonnnpie commented 3 months ago

Thanks for your work.

I am training on the CelebA-HQ dataset, and after 110k steps the sampled images seem to have a color problem. Is there something I need to do differently with the dataset? 64a5ac5ea03fc669048ef68a3db224f

My settings are as follows:

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 256,
    timesteps = 1000,            # number of steps
    sampling_timesteps = 250,    # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'             # L1 or L2
).cuda()

trainer = Trainer(
    diffusion,
    '/mnt/shared/deepfake/CelebA-HQ/train',
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 7000000,        # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = False,                      # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)
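For context on how far 110k steps goes with these settings, here is a rough back-of-the-envelope sketch. It assumes that one training step consumes train_batch_size * gradient_accumulate_every images, and that the training folder holds about 30,000 images (the size of the full CelebA-HQ set; the actual 'train' folder may hold fewer). The helper names are hypothetical and not part of the library.

```python
def samples_seen(steps: int, batch_size: int, grad_accum: int) -> int:
    """Images consumed after `steps` optimizer updates, assuming each
    update accumulates gradients over `grad_accum` batches."""
    return steps * batch_size * grad_accum


def epochs_seen(steps: int, batch_size: int, grad_accum: int,
                dataset_size: int) -> float:
    """Approximate number of passes over a dataset of `dataset_size` images."""
    return samples_seen(steps, batch_size, grad_accum) / dataset_size


# Settings taken from the post above.
steps, batch_size, grad_accum = 110_000, 32, 2

# ASSUMPTION: 30,000 images (full CelebA-HQ); adjust to the real folder size.
assumed_dataset_size = 30_000

total = samples_seen(steps, batch_size, grad_accum)
epochs = epochs_seen(steps, batch_size, grad_accum, assumed_dataset_size)

print(f"images seen:    {total:,}")
print(f"approx. epochs: {epochs:.1f}")
```

Under those assumptions the model has already made a few hundred passes over the data, so the number of steps alone may not explain the color shift.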

Zhangzeyu7 commented 3 months ago

you can try 'amp = True'

moonnnpie commented 3 months ago

> you can try 'amp = True'

Thanks, but I tried that and the images turned out completely green.

nilsleh commented 3 months ago

I am also currently trying to get reasonable results on the FFHQ dataset, and I also want to try CelebA-HQ.

from glob import glob
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from torchvision.transforms.functional import pil_to_tensor
from PIL import Image

class FFHQDataset(VisionDataset):
    def __init__(self, root: str):
        super().__init__(root)

        self.fpaths = sorted(glob(root + '/**/*.png', recursive=True))
        assert len(self.fpaths) > 0, "File list is empty. Check the root."

    def __len__(self):
        return len(self.fpaths)

    def __getitem__(self, index: int):
        fpath = self.fpaths[index]
        img = Image.open(fpath).convert('RGB')
        # normalize to [0, 1] range
        img = pil_to_tensor(img) / 255.
        return img

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8, 16, 32),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps=500
)

dataset = FFHQDataset(root="/mnt/SSD2/nils/ocean_bench_exps/diffusion/data/ffhq/thumbnails128x128")
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True, pin_memory = True, num_workers = 12)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_lr = 8e-5,
    train_num_steps = 50000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = False,
    num_samples=16,
    save_and_sample_every=10000,
    dl = dataloader,                  # argument added by a local modification to Trainer
)

trainer.train()
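Since a custom dataloader bypasses the Trainer's own dataset code, one thing that may be worth verifying is the value range of the batches. The sketch below assumes the diffusion wrapper expects inputs in [0, 1] and rescales them to [-1, 1] with x * 2 - 1 (I believe this is the library's default auto-normalization, but please check your installed version). All function names here are hypothetical helpers, shown on plain floats; the same arithmetic applies elementwise to tensors.

```python
def to_model_range(x):
    """Map a value from [0, 1] to [-1, 1] (assumed model input range)."""
    return x * 2 - 1


def to_image_range(x):
    """Inverse mapping, from [-1, 1] back to [0, 1]."""
    return (x + 1) * 0.5


def check_unit_range(values, eps=1e-6):
    """Return (min, max) of `values`, raising if anything is outside [0, 1]."""
    lo, hi = min(values), max(values)
    if lo < -eps or hi > 1 + eps:
        raise ValueError(f"expected values in [0, 1], got [{lo}, {hi}]")
    return lo, hi


# Raw 8-bit pixel values, and the same values scaled as in the dataset above.
pixels_uint8 = [0, 128, 255]
pixels_unit = [p / 255.0 for p in pixels_uint8]

print(check_unit_range(pixels_unit))           # scaled pixels pass the check
print([to_model_range(p) for p in pixels_unit])

try:
    check_unit_range(pixels_uint8)             # unscaled pixels are rejected
except ValueError as err:
    print("range check failed:", err)
```

In a real run, a check like this could be applied to the minimum and maximum of one batch from the dataloader before training starts.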

I added a dl argument to the Trainer because I wanted control over the dataloader and its configuration. In effect, I replaced the dataset and dataloader construction inside the Trainer so that it simply uses the dl argument. Below are some samples; the loss is around 0.02-0.03.

It was mentioned here that amp=False helps, but I have tried both and there is no significant change.

Screenshot from 2024-03-22 12-56-48

Overall I would also expect better results, so I am wondering if people have experience and suggestions?

Edit: Training for longer (300,000 training steps) seems to improve results a bit. sample-30

nilsleh commented 2 months ago

These are results on the CelebA-HQ dataset:

sample-29

szh404 commented 3 weeks ago

> you can try 'amp = True'

> thanks but i tried and found that the images turns out total green

Have you solved this problem? I ran into the same issue recently.