Closed chiukit closed 1 year ago
I have tried 2 times with the 17 flowers dataset: https://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz
```python
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 100,            # number of steps
    sampling_timesteps = 25,    # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
)

trainer = Trainer(
    diffusion,
    '/path/folder',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.train()
```
After 5 epochs, it always raises `ValueError: array must not contain infs or NaNs`. The following is the output.
ValueError: array must not contain infs or NaNs
sampling loop time step: 100% 25/25 [00:01<00:00, 24.90it/s] fid_score: 0.2590540265451704 sampling loop time step: 100% 25/25 [00:00<00:00, 27.61it/s] fid_score: 0.28387210408463437 sampling loop time step: 100% 25/25 [00:00<00:00, 27.56it/s] fid_score: 0.2948531148483391 sampling loop time step: 100% 25/25 [00:00<00:00, 25.65it/s] fid_score: 0.259936938675557 sampling loop time step: 100% 25/25 [00:00<00:00, 27.92it/s] fid_score: 0.2560371539941957 Error displaying widget ╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮ │ in <module> │ │ │ │ 24 │ calculate_fid = True # whether to calculate fid during training │ │ 25 ) │ │ 26 │ │ ❱ 27 trainer.train() │ │ 28 │ │ │ │ /usr/local/lib/python3.9/dist-packages/denoising_diffusion_pytorch/denoising_diffusion_pytorch.p │ │ y:1023 in train │ │ │ │ 1020 │ │ │ │ │ │ # whether to calculate fid │ │ 1021 │ │ │ │ │ │ │ │ 1022 │ │ │ │ │ │ if exists(self.inception_v3): │ │ ❱ 1023 │ │ │ │ │ │ │ fid_score = self.fid_score(real_samples = data, fake_samples │ │ 1024 │ │ │ │ │ │ │ accelerator.print(f'fid_score: {fid_score}') │ │ 1025 │ │ │ │ │ │ 1026 │ │ │ │ pbar.update(1) │ │ │ │ /usr/local/lib/python3.9/dist-packages/denoising_diffusion_pytorch/denoising_diffusion_pytorch.p │ │ y:970 in fid_score │ │ │ │ 967 │ │ m1, s1 = self.calculate_activation_statistics(real_samples) │ │ 968 │ │ m2, s2 = self.calculate_activation_statistics(fake_samples) │ │ 969 │ │ │ │ ❱ 970 │ │ fid_value = calculate_frechet_distance(m1, s1, m2, s2) │ │ 971 │ │ return fid_value │ │ 972 │ │ │ 973 │ def train(self): │ │ │ │ /usr/local/lib/python3.9/dist-packages/pytorch_fid/fid_score.py:188 in │ │ calculate_frechet_distance │ │ │ │ 185 │ diff = mu1 - mu2 │ │ 186 │ │ │ 187 │ # Product might be almost singular │ │ ❱ 188 │ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) │ │ 189 │ if not np.isfinite(covmean).all(): │ │ 190 │ │ msg = ('fid calculation produces singular product; ' │ │ 191 │ │ │ 'adding %s to 
diagonal of cov estimates') % eps │ │ │ │ /usr/local/lib/python3.9/dist-packages/scipy/linalg/_matfuncs_sqrtm.py:161 in sqrtm │ │ │ │ 158 │ │ [ 1., 4.]]) │ │ 159 │ │ │ 160 │ """ │ │ ❱ 161 │ A = _asarray_validated(A, check_finite=True, as_inexact=True) │ │ 162 │ if len(A.shape) != 2: │ │ 163 │ │ raise ValueError("Non-matrix input to matrix function.") │ │ 164 │ if blocksize < 1: │ │ │ │ /usr/local/lib/python3.9/dist-packages/scipy/_lib/_util.py:287 in _asarray_validated │ │ │ │ 284 │ │ if np.ma.isMaskedArray(a): │ │ 285 │ │ │ raise ValueError('masked arrays are not supported') │ │ 286 │ toarray = np.asarray_chkfinite if check_finite else np.asarray │ │ ❱ 287 │ a = toarray(a) │ │ 288 │ if not objects_ok: │ │ 289 │ │ if a.dtype is np.dtype('O'): │ │ 290 │ │ │ raise ValueError('object arrays are not supported') │ │ │ │ /usr/local/lib/python3.9/dist-packages/numpy/lib/function_base.py:627 in asarray_chkfinite │ │ │ │ 624 │ """ │ │ 625 │ a = asarray(a, dtype=dtype, order=order) │ │ 626 │ if a.dtype.char in typecodes['AllFloat'] and not np.isfinite(a).all(): │ │ ❱ 627 │ │ raise ValueError( │ │ 628 │ │ │ "array must not contain infs or NaNs") │ │ 629 │ return a │ │ 630 │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ ValueError: array must not contain infs or NaNs
Is there any setup issue here?
@chiukit try turning off amp (i.e. pass `amp = False` to the `Trainer`)
amp
@chiukit what kind of samples do you see before it NaNs out? Actually, those FID scores don't look that bad.
NaN
@lucidrains It works now when I turn off the amp. Thanks!
I have tried 2 times with the 17 flowers dataset: https://www.robots.ox.ac.uk/~vgg/data/flowers/17/17flowers.tgz
After 5 epochs, it always raises `ValueError: array must not contain infs or NaNs`. The following is the output. Is there any setup issue here?