cvg / pixloc

Back to the Feature: Learning Robust Camera Localization from Pixels to Pose (CVPR 2021)
Apache License 2.0

NaN appears during training #10

Closed: loocy3 closed this issue 2 years ago

loocy3 commented 2 years ago

After 21850 training iterations, I got NaN in the features extracted by the UNet. Could you give any advice on where in the source code I should look?
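
One generic way to localize this kind of failure (a sketch, not something pixloc ships) is to register a forward hook on every module and stop at the first one whose output contains a non-finite value; install_nan_hooks below is an illustrative helper, not part of the repository:

```python
import torch

def install_nan_hooks(model):
    """Raise at the first module whose forward output contains NaN/Inf."""
    def make_hook(name):
        def hook(module, inputs, output):
            outputs = output if isinstance(output, (list, tuple)) else [output]
            for t in outputs:
                if (torch.is_tensor(t) and t.is_floating_point()
                        and not torch.isfinite(t).all()):
                    raise RuntimeError(f'Non-finite output in module: {name}')
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
```

Calling install_nan_hooks(model) once before the training loop pins the failure to a specific layer (e.g. somewhere inside the UNet extractor) instead of only seeing it in the final features.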

sarlinpe commented 2 years ago
  1. What dataset are you training with?
  2. Could you try to enable anomaly detection by uncommenting this line? Please then report the entire stack traceback. https://github.com/cvg/pixloc/blob/90f7e968398252e8557b284803ee774cb8d80cd0/pixloc/pixlib/train.py#L207
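
For reference, the linked line toggles PyTorch's built-in anomaly detection; a minimal stand-alone sketch of the same mechanism (model, loss_fn and data below are placeholders, not pixloc names):

```python
import torch

# With anomaly detection enabled, autograd records the forward stack of every
# operation and raises during backward() as soon as some function outputs NaN,
# printing the traceback of the forward call that created the bad value.
torch.autograd.set_detect_anomaly(True)

# pred = model(data)
# loss = loss_fn(pred, data)
# loss.backward()   # -> RuntimeError naming the offending backward function
```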
sarlinpe commented 2 years ago
  1. NaNs can appear in the solver step if the optimization is too difficult, but this should already be handled by the code.
  2. Did you try training with a different random seed? Does the NaN always appear at the same training iteration?
loocy3 commented 2 years ago
  1. NaNs can appear in the solver step if the optimization is too difficult, but this should already be handled by the code. -> May I know how you handle this case? With a "too few match points" check? (See the sketch after this list.)
  2. Did you try training with a different random seed? Does the NaN always appear at the same training iteration? -> I loaded the pretrained CMU model and fine-tuned it on KITTI data. I did not change the random seed. The NaN does not always appear at the same training iteration; when it recurs, it is around iterations 29000~34000.
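
For readers wondering what such a guard can look like in practice, here is a sketch of a damped Gauss-Newton / Levenberg-Marquardt step that zeroes the update for samples whose solve produced non-finite values or that have too few valid points. It mimics the call signature visible in the traceback further down (optimizer_step(g, H, lambda_, mask=...)) but is not the actual pixloc implementation; min_points and num_points are illustrative parameters.

```python
import torch

def guarded_optimizer_step(g, H, lambda_=1e-2, mask=None,
                           min_points=None, num_points=None):
    """Sketch of a damped solver step that masks out degenerate samples."""
    # Levenberg-Marquardt-style damping proportional to the Hessian diagonal.
    diag = H.diagonal(dim1=-2, dim2=-1) * lambda_
    H_damped = H + diag.clamp(min=1e-6).diag_embed()
    delta = torch.linalg.solve(H_damped, -g.unsqueeze(-1)).squeeze(-1)

    # Zero the update for samples whose solution is NaN/Inf, that were already
    # flagged as invalid (mask is True for valid samples), or that rely on too
    # few valid points.
    failed = ~torch.isfinite(delta).all(-1)
    if mask is not None:
        failed = failed | ~mask
    if min_points is not None and num_points is not None:
        failed = failed | (num_points < min_points)
    return torch.where(failed.unsqueeze(-1), torch.zeros_like(delta), delta)
```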
jmorlana commented 2 years ago

I'm also having this kind of issue. I'm training on the same MegaDepth dataset with different U-Net configurations (encoder pretrained on other data, frozen encoder, deleted decoder, etc.). All of them lead to NaN at some point during the optimization. I haven't yet determined whether the NaNs come from the optimization or directly from the features.

Edit: I did not change the random seed either, and the error does not recur at the same iteration. It seems to appear randomly in the middle of training.

sarlinpe commented 2 years ago

This is concerning; let me dig into it (this will likely take me a few days).

jmorlana commented 2 years ago

[11/02/2021 07:16:05 pixloc INFO] [E 7 | it 2450] loss {total 3.257E+00, reprojection_error/0 9.695E+00, reprojection_error/1 8.376E+00, reprojection_error/2 8.366E+00, reprojection_error 8.366E+00, reprojection_error/init 3.127E+01}
[11/02/2021 07:16:06 pixloc.pixlib.models.two_view_refiner WARNING] NaN detected ['error', tensor([ nan, 5.0000e+01, 1.4714e-01, 2.6252e-03, 2.5921e-02, 3.2593e-02], device='cuda:0', grad_fn=), 'loss', tensor([ nan, 0.0000, 0.0490, 0.0009, 0.0086, 0.0109], device='cuda:0', grad_fn=)]
[W python_anomaly_mode.cpp:104] Warning: Error detected in PowBackward1. Traceback of forward call that caused the error:
  File "/home/jmorlana/anaconda3/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/jmorlana/anaconda3/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/jmorlana/pixloc/pixloc/pixlib/train.py", line 391, in <module>
    main_worker(0, conf, output_dir, args)
  File "/home/jmorlana/pixloc/pixloc/pixlib/train.py", line 358, in main_worker
    training(rank, conf, output_dir, args)
  File "/home/jmorlana/pixloc/pixloc/pixlib/train.py", line 259, in training
    losses = loss_fn(pred, data)
  File "/home/jmorlana/pixloc/pixloc/pixlib/models/two_view_refiner.py", line 151, in loss
    err = reprojection_error(T_opt).clamp(max=self.conf.clamp_error)
  File "/home/jmorlana/pixloc/pixloc/pixlib/models/two_view_refiner.py", line 133, in reprojection_error
    err = scaled_barron(1., 2.)(err)[0]/4
  File "/home/jmorlana/pixloc/pixloc/pixlib/geometry/losses.py", line 81, in <lambda>
    return lambda x: scaled_loss(
  File "/home/jmorlana/pixloc/pixloc/pixlib/geometry/losses.py", line 18, in scaled_loss
    loss, loss_d1, loss_d2 = fn(x/a2)
  File "/home/jmorlana/pixloc/pixloc/pixlib/geometry/losses.py", line 82, in <lambda>
    x, lambda y: barron_loss(y, y.new_tensor(a)), c)
  File "/home/jmorlana/pixloc/pixloc/pixlib/geometry/losses.py", line 59, in barron_loss
    torch.pow(x / beta_safe + 1., 0.5 * alpha) - 1.)
 (function _print_stack)

Thank you!
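
For anyone else tracing a PowBackward NaN: one common way torch.pow produces NaN gradients is a fractional exponent with a base that reaches zero; the derivative blows up to Inf and becomes NaN as soon as it is multiplied by a zero in the chain rule. A minimal reproduction, independent of the pixloc code:

```python
import torch

# d/dx x**0.5 = 0.5 * x**-0.5, which is +inf at x = 0:
x = torch.tensor([0.0], requires_grad=True)
torch.pow(x, 0.5).sum().backward()
print(x.grad)   # tensor([inf])

# Multiplied by a zero anywhere downstream, that inf becomes NaN in backward:
y = torch.tensor([0.0], requires_grad=True)
(torch.pow(y, 0.5) * 0.0).sum().backward()
print(y.grad)   # tensor([nan])
```

A common remedy is to keep the base of the pow strictly away from zero, e.g. by clamping the residual with a small epsilon before applying the fractional power.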

loocy3 commented 2 years ago

Thank you for the analysis. I have reproduced the issue:


[W python_anomaly_mode.cpp:104] Warning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "pixloc/pixlib/train.py", line 417, in <module>
    main_worker(0, conf, output_dir, args)
  File "pixloc/pixlib/train.py", line 383, in main_worker
    training(rank, conf, output_dir, args)
  File "pixloc/pixlib/train.py", line 281, in training
    pred = model(data)
  File ".local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "pixloc/pixloc/pixlib/models/base_model.py", line 106, in forward
    return self._forward(data)
  File "pixloc/pixloc/pixlib/models/two_view_refiner.py", line 117, in _forward
    mask=mask, W_ref_q=W_ref_q))
  File ".local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/pixloc/pixloc/pixlib/models/base_model.py", line 106, in forward
    return self._forward(data)
  File "pixloc/pixloc/pixlib/models/base_optimizer.py", line 97, in _forward
    data['cam_q'], data['mask'], data.get('W_ref_q'))
  File "pixloc/pixloc/pixlib/models/learned_optimizer.py", line 78, in _run
    delta = optimizer_step(g, H, lambda_, mask=~failed)
  File "pixloc/pixloc/pixlib/geometry/optimization.py", line 18, in optimizer_step
    diag = H.diagonal(dim1=-2, dim2=-1) * lambda_
 (function _print_stack)
Traceback (most recent call last):
  File "pixloc/pixlib/train.py", line 417, in <module>
    main_worker(0, conf, output_dir, args)
  File "pixloc/pixlib/train.py", line 383, in main_worker
    training(rank, conf, output_dir, args)
  File "pixloc/pixlib/train.py", line 292, in training
    loss.backward()
  File ".local/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File ".local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.

angiend commented 2 years ago

RuntimeError: Function 'PowBackward1' returned nan values in its 0th output #16

sarlinpe commented 2 years ago

I believe that the issue has been addressed by https://github.com/cvg/pixloc/commit/8937e29baa49e62326e9b9a98766e48420a563fb and https://github.com/cvg/pixloc/commit/0ab0e795a443c67ccb948b6fa375393a5b98c093. Can you please confirm that this helps? I will continue to investigate other sources of instabilities.

angiend commented 2 years ago

I tested the changed code, but I get the same error.

sarlinpe commented 2 years ago

@angiend What dataset are you training with? At which iteration does it crash? And with which version of PyTorch?

angiend commented 2 years ago

@Skydes I retrained on the CMU dataset; it crashes at "E 65 | it 800" (3000 iterations per epoch), and my PyTorch version is 1.9.1.

sarlinpe commented 2 years ago

The training has usually fully converged by epoch 20, so this should not prevent you from reproducing the results. Could you give PyTorch 1.7.1 a try? I have tried both 1.7.1 and 1.10.0 and both work fine.

loocy3 commented 2 years ago

Thanks, I have tested for 3 epochs and I think this issue has been fixed.