lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License

Problem When Computing FID on One-Channel Images #188

Closed Derick317 closed 1 year ago

Derick317 commented 1 year ago

I would like to train a diffusion model on the MNIST dataset. I have set channels = 1. Training seemed to run correctly, but an error occurred when the trainer sampled images and calculated FID:

D:\projects\diffusion_model_pytorch>python test_mnist.py
sampling loop time step: 100%|████████████████████████████████████████████████████████████████████| 250/250 [01:33<00:00,  2.67it/s]
loss: 0.3010:   0%|                                                                         | 49/100000 [04:32<154:19:33,  5.56s/it] 
Traceback (most recent call last):
  File "D:\projects\diffusion_model_pytorch\test_mnist.py", line 32, in <module>
    trainer.train()
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\denoising_diffusion_pytorch\denoising_diffusion_pytorch.py", line 1002, in train
    fid_score = self.fid_score(real_samples = data, fake_samples = all_images)
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\denoising_diffusion_pytorch\denoising_diffusion_pytorch.py", line 946, in fid_score
    m1, s1 = self.calculate_activation_statistics(real_samples)
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return forward_call(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\torchvision\models\inception.py", line 405, in forward
    x = self.conv(x)
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "C:\ProgramData\Anaconda3\envs\diffusion\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[25, 1, 299, 299] to have 3 channels, but got 1 channels instead

My code is:

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

if __name__ == "__main__":
    model = Unet(
        dim = 64,
        dim_mults = (1, 2),
        channels = 1
    )

    diffusion = GaussianDiffusion(
        model,
        image_size = 28,
        timesteps = 1000,           # number of steps
        sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
        loss_type = 'l1'            # L1 or L2
    )

    trainer = Trainer(
        diffusion,
        './mnist_test_set',
        train_batch_size = 32,
        train_lr = 8e-5,
        train_num_steps = 100000,         # total training steps
        gradient_accumulate_every = 2,    # gradient accumulation steps
        save_and_sample_every = 50,
        ema_decay = 0.995,                # exponential moving average decay
        amp = False,                      # turn on mixed precision
        calculate_fid = True              # whether to calculate fid during training
    )

    trainer.train()

Also, will a TensorBoard log file be generated at every checkpoint?

lhaippp commented 1 year ago

In my opinion, calculating FID requires 3-channel input, normally RGB. In your case, would it be possible to convert the grayscale images to 3 channels before they are fed into self.fid_score? For example:

# data [N, 1, H, W] ->  [N, 3, H, W]
data = data.repeat(1, 3, 1, 1)
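
The one-liner above can be wrapped in a small guard so that batches which already have 3 channels pass through unchanged. This is only a sketch of the idea: the helper name `to_three_channels` is hypothetical, it is not part of denoising_diffusion_pytorch, and it is not necessarily how the library later handled the case.

```python
import torch


def to_three_channels(images: torch.Tensor) -> torch.Tensor:
    """Expand a grayscale batch [N, 1, H, W] to [N, 3, H, W].

    Hypothetical helper for illustration only. Batches that already
    have 3 channels are returned unchanged.
    """
    if images.ndim != 4:
        raise ValueError(f"expected a 4D batch [N, C, H, W], got {images.ndim}D")

    channels = images.shape[1]
    if channels == 3:
        return images
    if channels == 1:
        # Tile the single channel three times along the channel axis.
        return images.repeat(1, 3, 1, 1)
    raise ValueError(f"expected 1 or 3 channels, got {channels}")


# Example: a batch shaped like the MNIST samples in the traceback above.
batch = torch.rand(25, 1, 28, 28)
rgb_like = to_three_channels(batch)
print(rgb_like.shape)
```

Both the real and the generated batches would need the same treatment before being passed to the Inception network, since the first convolution in the traceback expects 3 input channels.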

Echo-jyt commented 1 year ago

Hello, friend. I am also training on the MNIST dataset. Do you know how to generate new images?

lucidrains commented 1 year ago

@Derick317 hey Deming, i incorporated Li's suggestion

do you want to see if 1.4.6 works?

Derick317 commented 1 year ago

@lucidrains Hi, since I installed your code with pip, how can I update it?

Derick317 commented 1 year ago

I managed to update the repository to 1.5.3. It works!
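
For anyone with the same question about updating a pip-installed package, the sketch below checks whether the installed version is at least the 1.4.6 release mentioned above. The helper names (`installed_version`, `parse_release`, `needs_upgrade`) are made up for this example, and the version parsing assumes plain dotted numbers such as 1.4.6.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def parse_release(ver: str) -> tuple:
    """Parse a plain dotted release such as '1.4.6' into a tuple of ints.

    Assumption: no pre-release or post-release suffixes are present.
    """
    return tuple(int(part) for part in ver.split("."))


def needs_upgrade(installed, minimum: str) -> bool:
    """True if the package is missing or older than the minimum release."""
    if installed is None:
        return True
    return parse_release(installed) < parse_release(minimum)


if __name__ == "__main__":
    current = installed_version("denoising-diffusion-pytorch")
    if needs_upgrade(current, "1.4.6"):
        # Upgrading is done from the shell with pip, not from Python.
        print("Run: pip install --upgrade denoising-diffusion-pytorch")
    else:
        print(f"Installed version {current} is recent enough.")
```

Tuples of integers are compared here because comparing version strings directly would order "1.10.0" before "1.4.6".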