lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License
8.31k stars, 1.03k forks

ValueError: array must not contain infs or NaNs #191

Closed (chiukit closed this issue 1 year ago)

chiukit commented 1 year ago

I have tried 2 times with the 17 flowers dataset: https://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 100,           # number of steps
    sampling_timesteps = 25,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
)

trainer = Trainer(
    diffusion,
    '/path/folder',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.train()

After 5 epochs, it always raises ValueError: array must not contain infs or NaNs. The output is as follows.

sampling loop time step: 100%
25/25 [00:01<00:00, 24.90it/s]
fid_score: 0.2590540265451704
sampling loop time step: 100%
25/25 [00:00<00:00, 27.61it/s]
fid_score: 0.28387210408463437
sampling loop time step: 100%
25/25 [00:00<00:00, 27.56it/s]
fid_score: 0.2948531148483391
sampling loop time step: 100%
25/25 [00:00<00:00, 25.65it/s]
fid_score: 0.259936938675557
sampling loop time step: 100%
25/25 [00:00<00:00, 27.92it/s]
fid_score: 0.2560371539941957
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>                                                                                      │
│                                                                                                  │
│   24 │   calculate_fid = True              # whether to calculate fid during training            │
│   25 )                                                                                           │
│   26                                                                                             │
│ ❱ 27 trainer.train()                                                                             │
│   28                                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/denoising_diffusion_pytorch/denoising_diffusion_pytorch.p │
│ y:1023 in train                                                                                  │
│                                                                                                  │
│   1020 │   │   │   │   │   │   # whether to calculate fid                                        │
│   1021 │   │   │   │   │   │                                                                     │
│   1022 │   │   │   │   │   │   if exists(self.inception_v3):                                     │
│ ❱ 1023 │   │   │   │   │   │   │   fid_score = self.fid_score(real_samples = data, fake_samples  │
│   1024 │   │   │   │   │   │   │   accelerator.print(f'fid_score: {fid_score}')                  │
│   1025 │   │   │   │                                                                             │
│   1026 │   │   │   │   pbar.update(1)                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/denoising_diffusion_pytorch/denoising_diffusion_pytorch.p │
│ y:970 in fid_score                                                                               │
│                                                                                                  │
│    967 │   │   m1, s1 = self.calculate_activation_statistics(real_samples)                       │
│    968 │   │   m2, s2 = self.calculate_activation_statistics(fake_samples)                       │
│    969 │   │                                                                                     │
│ ❱  970 │   │   fid_value = calculate_frechet_distance(m1, s1, m2, s2)                            │
│    971 │   │   return fid_value                                                                  │
│    972 │                                                                                         │
│    973 │   def train(self):                                                                      │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/pytorch_fid/fid_score.py:188 in                           │
│ calculate_frechet_distance                                                                       │
│                                                                                                  │
│   185 │   diff = mu1 - mu2                                                                       │
│   186 │                                                                                          │
│   187 │   # Product might be almost singular                                                     │
│ ❱ 188 │   covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)                              │
│   189 │   if not np.isfinite(covmean).all():                                                     │
│   190 │   │   msg = ('fid calculation produces singular product; '                               │
│   191 │   │   │      'adding %s to diagonal of cov estimates') % eps                             │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/scipy/linalg/_matfuncs_sqrtm.py:161 in sqrtm              │
│                                                                                                  │
│   158 │   │      [ 1.,  4.]])                                                                    │
│   159 │                                                                                          │
│   160 │   """                                                                                    │
│ ❱ 161 │   A = _asarray_validated(A, check_finite=True, as_inexact=True)                          │
│   162 │   if len(A.shape) != 2:                                                                  │
│   163 │   │   raise ValueError("Non-matrix input to matrix function.")                           │
│   164 │   if blocksize < 1:                                                                      │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/scipy/_lib/_util.py:287 in _asarray_validated             │
│                                                                                                  │
│   284 │   │   if np.ma.isMaskedArray(a):                                                         │
│   285 │   │   │   raise ValueError('masked arrays are not supported')                            │
│   286 │   toarray = np.asarray_chkfinite if check_finite else np.asarray                         │
│ ❱ 287 │   a = toarray(a)                                                                         │
│   288 │   if not objects_ok:                                                                     │
│   289 │   │   if a.dtype is np.dtype('O'):                                                       │
│   290 │   │   │   raise ValueError('object arrays are not supported')                            │
│                                                                                                  │
│ /usr/local/lib/python3.9/dist-packages/numpy/lib/function_base.py:627 in asarray_chkfinite       │
│                                                                                                  │
│    624 │   """                                                                                   │
│    625 │   a = asarray(a, dtype=dtype, order=order)                                              │
│    626 │   if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all():                │
│ ❱  627 │   │   raise ValueError(                                                                 │
│    628 │   │   │   "array must not contain infs or NaNs")                                        │
│    629 │   return a                                                                              │
│    630                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: array must not contain infs or NaNs

Is there any setup issue here?
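The last frame of the traceback is NumPy's finiteness check, which rejects any array holding an inf or NaN. The sketch below reproduces that check in isolation. It assumes, without confirmation from the traceback, that a value overflowing half precision is one way such an entry could arise; `is_all_finite` and `activations` are hypothetical names, not part of denoising-diffusion-pytorch or pytorch_fid.

```python
import warnings

import numpy as np


def is_all_finite(arr):
    """Return True when the array has no inf or NaN entries (hypothetical helper)."""
    return bool(np.isfinite(np.asarray(arr)).all())


# float16 cannot represent values much above 65504, so casting a larger
# float32 value down to float16 yields inf.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # silence the overflow-in-cast warning
    activations = np.array([1.0, 70000.0], dtype=np.float32).astype(np.float16)

# np.asarray_chkfinite is the call that raises in the traceback above.
try:
    np.asarray_chkfinite(activations)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(is_all_finite(activations), error_message)
```

A check like `is_all_finite` could be run on the statistics passed to the FID computation to find out which input first becomes non-finite.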

lucidrains commented 1 year ago

@chiukit try turning off the amp

lucidrains commented 1 year ago

@chiukit what kind of samples do you see before it NaNs out? Actually, those FID scores don't look that bad.

chiukit commented 1 year ago

@lucidrains It works now when I turn off the amp. Thanks!
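For reference, the change that resolved this comes down to one keyword argument. The sketch below collects the Trainer arguments from the original report into a plain dict with `amp` set to False. `trainer_kwargs` is a hypothetical name used only for illustration, and the commented-out lines assume the library is installed and `diffusion` is built as in the report.

```python
# Keyword arguments copied from the report above; only `amp` differs.
# `trainer_kwargs` is a hypothetical name used for this sketch.
trainer_kwargs = dict(
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = False,                      # mixed precision turned off
    calculate_fid = True              # whether to calculate fid during training
)

# With the library installed and `diffusion` built as in the report:
# trainer = Trainer(diffusion, '/path/folder', **trainer_kwargs)
# trainer.train()
```

Training in full precision is typically slower and uses more memory, so this trades speed for numerical stability.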