Mid-Push / Decent

Unpaired Image Translation, NeurIPS 2022

Error while training the network #4

Open Manjuphoenix opened 1 year ago

Manjuphoenix commented 1 year ago

Minimal steps to reproduce the error:

Training on an NVIDIA A6000 GPU
OS: Ubuntu 20.04
CUDA version: 12.0

Pip list:
absl-py 1.4.0
cachetools 4.2.4
certifi 2021.5.30
cffi 1.14.6
charset-normalizer 2.0.12
cycler 0.11.0
dataclasses 0.8
decorator 4.4.2
dominate 2.4.0
google-auth 2.18.1
google-auth-oauthlib 0.4.6
GPUtil 1.4.0
grpcio 1.48.2
idna 3.4
importlib-metadata 4.8.3
importlib-resources 5.4.0
jsonpatch 1.32
jsonpointer 2.3
kiwisolver 1.3.1
Markdown 3.3.7
matplotlib 3.3.4
mkl-fft 1.3.0
mkl-random 1.1.1
mkl-service 2.3.0
networkx 2.5.1
nflows 0.14
numpy 1.16.4
oauthlib 3.2.2
opencv-python 4.7.0.72
packaging 21.3
Pillow 8.4.0
pip 21.2.2
protobuf 3.19.6
pyasn1 0.5.0
pyasn1-modules 0.3.0
pycparser 2.21
pyparsing 3.0.9
python-dateutil 2.8.2
pyzmq 25.0.2
requests 2.27.1
requests-oauthlib 1.3.1
rsa 4.9
scipy 1.5.2
setuptools 58.0.4
six 1.16.0
tensorboard 2.10.1
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
torch 1.4.0
torch-fidelity 0.3.0
torchfile 0.1.0
torchvision 0.5.0
tornado 6.1
tqdm 4.64.1
typing_extensions 4.1.1
urllib3 1.26.16
visdom 0.2.4
websocket-client 1.3.1
Werkzeug 2.0.3
wheel 0.37.1
zipp 3.6.0

Dataset used: horse2zebra
Batch size: 24 (single GPU only)

Command for training: python train.py --dataroot=datasets/horse2zebra/ --gpu=1 --batch_size=24

Error obtained:
learning rate = 0.0002000
(epoch: 10, iters: 120, time: 0.141, data: 0.008) G: nan G_GAN: nan D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan exp_A: nan exp_B: nan
(epoch: 10, iters: 720, time: 0.139, data: 0.010) G: nan G_GAN: nan D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan exp_A: nan exp_B: nan
(epoch: 10, iters: 1320, time: 0.138, data: 0.008) G: nan G_GAN: nan D_real: nan D_fake: nan idt: nan var: nan nll_A: nan nll_B: nan exp_A: nan exp_B: nan
[] start evaluation!
datasets/horse2zebra/testB
./checkpoints/debug/horse2zebra_AtoB/var0.01_np256_nb1_nl0_nd10_lr0.001_ema0.998_var_single/fake
Traceback (most recent call last):
  File "train.py", line 82, in <module>
    eval_dict = eval_loader(model, test_loader_A, test_loader_B, opt.run_dir, opt)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/user/manjunath/GAN/Decent/models/utils.py", line 76, in eval_loader
    return eval_loader_one(model, test_loader_a, test_loader_b, output_directory, opt)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/user/manjunath/GAN/Decent/models/utils.py", line 97, in eval_loader_one
    eval_dict = eval_method_one(real_dir, fake_dir, opt)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/user/manjunath/GAN/Decent/models/utils.py", line 115, in eval_method_one
    metric_dict_AB = torch_fidelity.calculate_metrics(input1=realB_path, input2=fakeB_path, **eval_args)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch_fidelity/metrics.py", line 258, in calculate_metrics
    metric_fid = fid_statistics_to_metric(fid_stats_1, fid_stats_2, get_kwarg('verbose', kwargs))
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/torch_fidelity/metric_fid.py", line 47, in fid_statistics_to_metric
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/scipy/linalg/_matfuncs_sqrtm.py", line 161, in sqrtm
    A = _asarray_validated(A, check_finite=True, as_inexact=True)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/scipy/_lib/_util.py", line 263, in _asarray_validated
    a = toarray(a)
  File "/home/user/anaconda3/envs/decent/lib/python3.6/site-packages/numpy/lib/function_base.py", line 498, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs
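The log shows every loss already NaN at epoch 10, yet training kept going until evaluation, where the matrix passed to `scipy.linalg.sqrtm` contained NaN or inf values and its finiteness check (`np.asarray_chkfinite`) raised the ValueError above. The two are plausibly linked, but the sketch below does not fix the divergence itself; it only makes the failure surface earlier and more legibly. It assumes the losses are available as a name-to-float dict like the one the log prints. All helper names (`non_finite_losses`, `check_losses`, `fid_stats_are_finite`, `chkfinite_message`) are hypothetical and are not part of Decent or torch-fidelity.

```python
import math

import numpy as np


def non_finite_losses(losses):
    """Return the names of logged losses that are NaN or infinite."""
    return [name for name, value in losses.items() if not math.isfinite(value)]


def check_losses(losses, epoch, iters):
    """Hypothetical guard: stop at the first non-finite loss instead of
    training on and only failing later inside the FID evaluation."""
    bad = non_finite_losses(losses)
    if bad:
        raise FloatingPointError(
            f"non-finite losses at epoch {epoch}, iter {iters}: {', '.join(bad)}"
        )


def fid_stats_are_finite(mu, sigma):
    """True if an FID mean vector and covariance matrix contain no NaN/inf."""
    return bool(np.isfinite(mu).all() and np.isfinite(sigma).all())


def chkfinite_message(array):
    """Run the same check that fails in the traceback (np.asarray_chkfinite,
    reached from scipy.linalg.sqrtm) and return its error message, or None."""
    try:
        np.asarray_chkfinite(array)
    except ValueError as err:
        return str(err)
    return None


if __name__ == "__main__":
    # Loss values shaped like the ones printed in the training log.
    logged = {"G": float("nan"), "G_GAN": float("nan"), "D_real": 0.7}
    print(non_finite_losses(logged))  # ['G', 'G_GAN']
    print(chkfinite_message(np.array([[1.0, np.nan], [np.nan, 1.0]])))
```

Calling something like `check_losses` once per logging step would have stopped this run at the first NaN iteration with the offending loss names, rather than after a full evaluation pass. PyTorch's `torch.autograd.set_detect_anomaly(True)` is a heavier alternative for locating the backward op that first produces NaN.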

Also noticed: the learning rate does not change from one epoch to the next.
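The `learning rate = 0.0002000` line resembles the format printed by `update_learning_rate` in the pytorch-CycleGAN-and-pix2pix code base. If this repository follows that code base's default `linear` policy (an assumption), an unchanged rate at epoch 10 would be expected rather than a bug. Under that policy the rate is held constant for the first `n_epochs` epochs (100 by default there) and only then decays linearly to zero over `n_epochs_decay` epochs. The sketch reimplements that decay rule in plain Python so the schedule can be inspected without PyTorch. The option names `epoch_count`, `n_epochs`, and `n_epochs_decay` and their defaults are taken from that code base and may not match this repository.

```python
def linear_decay_factor(epoch, epoch_count=1, n_epochs=100, n_epochs_decay=100):
    """Multiplicative LR factor for a CycleGAN-style 'linear' policy (assumed).

    `epoch` is the number of times the scheduler has been stepped. The factor
    stays at 1.0 for the first `n_epochs` epochs, then falls linearly towards
    0 over the following `n_epochs_decay` epochs.
    """
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)


def learning_rate(epoch, base_lr=0.0002, **kwargs):
    """Learning rate after `epoch` scheduler steps, starting from `base_lr`."""
    return base_lr * linear_decay_factor(epoch, **kwargs)


if __name__ == "__main__":
    # Printed with the same %.7f format as the training log.
    for epoch in (0, 10, 99, 100, 150, 199):
        print(epoch, "%.7f" % learning_rate(epoch))
```

With these assumed defaults, epochs 0, 10, and 99 all print `0.0002000`, and the rate only starts to fall at step 100. A constant rate during the first 10 epochs is therefore consistent with this schedule and, on its own, does not explain the NaN losses.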

Saddy21 commented 1 year ago

Same issue!!