Open monniert opened 1 year ago
I'm able to replicate this error. We would also see this issue when we trained nerfacto for a large number of iterations (>50k). Seems like it appears early for this scene. The issue is likely due to precision / numerical stability, however we have yet to trace it down.
So I've been debugging this issue on my side and have been able to reproduce and remove the error through the changes here, though I'm not really sure why scripting would cause an issue: https://github.com/nerfstudio-project/nerfstudio/pull/910
@tancik I have the same issue with my scene when training the nerfacto model for a large number of iterations, so it is not scene-specific. Will try fp32.
@tancik Looks like the cause is not precision. When training with mixed_precision = False I got NaN much sooner, and worse quality during training, for some reason.
@snarb We are using TinyCudaNN for our network and encoding, which is implemented in half precision. Setting mixed_precision=True enables gradient scaling (https://pytorch.org/docs/stable/amp.html#gradient-scaling), which we find necessary when using this library. Are you using the latest git branch or the pip package? The pip package doesn't have this fix yet.
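As a rough illustration of why gradient scaling matters with half-precision kernels, the sketch below rounds a small gradient to fp16 with and without a loss scale. The gradient value and the scale factor are made-up numbers, and the standard-library `struct` module stands in for real half-precision tensors; this is not what PyTorch's scaler does internally, only the arithmetic idea behind it.

```python
import struct


def to_half(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Hypothetical numbers, chosen only for illustration.
true_grad = 1e-8       # smaller than the smallest fp16 subnormal (~6e-8)
loss_scale = 65536.0   # 2**16, an assumed scale factor

# Without scaling, the gradient underflows to zero when stored in fp16.
unscaled = to_half(true_grad)

# With scaling, the stored value is representable; dividing by the scale
# afterwards (in full precision) recovers the gradient approximately.
recovered = to_half(true_grad * loss_scale) / loss_scale

print(unscaled, recovered)
```

The scaled path keeps the gradient to within fp16 rounding error, while the unscaled path loses it entirely.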
@tancik I am using the repository main branch, but not recent. Latest commit Date: 11/2/2022 10:17:45 PM 121771d126a1c90cf39013d8016e16effbfe8efc.
I have an update. My scenario: I have trained for 1h+ on A100. Then got this error. I have tried 2 times to resume training from the latest checkpoint but got the same error in 10 mins. Then I enabled anomaly detection using
with torch.autograd.set_detect_anomaly(True):
before the entry point (it detects nan values in the graph). And it magically trained for 20 h without errors. It seemed like anomaly detection has fixed the problem :) , but finally, it crashed. Looks like the NaN is from the camera pose optimizer. Here is the stack trace:
/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
File "/home/ubuntu/.pycharm_helpers/pydev/pydevd.py", line 2173, in <module>
main()
File "/home/ubuntu/.pycharm_helpers/pydev/pydevd.py", line 2164, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "/home/ubuntu/.pycharm_helpers/pydev/pydevd.py", line 1476, in run
return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
File "/home/ubuntu/.pycharm_helpers/pydev/pydevd.py", line 1483, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/home/ubuntu/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 258, in <module>
entrypoint()
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 248, in entrypoint
main(
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 234, in main
launch(
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 173, in launch
main_func(local_rank=0, world_size=world_size, config=config)
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 88, in train_loop
trainer.train()
File "/home/ubuntu/repos/nerfstudio/nerfstudio/engine/trainer.py", line 146, in train
loss, loss_dict, metrics_dict = self.train_iteration(step)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/utils/profiler.py", line 43, in wrapper
ret = func(*args, **kwargs)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/engine/trainer.py", line 303, in train_iteration
_, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/utils/profiler.py", line 43, in wrapper
ret = func(*args, **kwargs)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/pipelines/base_pipeline.py", line 248, in get_train_loss_dict
ray_bundle, batch = self.datamanager.next_train(step)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/data/datamanagers.py", line 367, in next_train
ray_bundle = self.train_ray_generator(ray_indices)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/model_components/ray_generators.py", line 53, in forward
camera_opt_to_camera = self.pose_optimizer(c)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/cameras/camera_optimizers.py", line 116, in forward
outputs.append(exp_map_SO3xR3(self.pose_adjustment[indices, :]))
File "/home/ubuntu/repos/nerfstudio/nerfstudio/cameras/lie_groups.py", line 51, in exp_map_SO3xR3
fac1[:, None, None] * skews
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:102.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Printing profiling stats, from longest to shortest duration in seconds
VanillaPipeline.get_eval_image_metrics_and_images: 2.0984
VanillaPipeline.get_eval_loss_dict: 0.2807
Trainer.train_iteration: 0.2053
VanillaPipeline.get_train_loss_dict: 0.1815
Trainer.eval_iteration: 0.0047
Traceback (most recent call last):
python-BaseException
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 258, in <module>
entrypoint()
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 248, in entrypoint
main(
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 234, in main
launch(
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 173, in launch
main_func(local_rank=0, world_size=world_size, config=config)
File "/home/ubuntu/repos/nerfstudio/scripts/train.py", line 88, in train_loop
trainer.train()
File "/home/ubuntu/repos/nerfstudio/nerfstudio/engine/trainer.py", line 146, in train
loss, loss_dict, metrics_dict = self.train_iteration(step)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/utils/profiler.py", line 43, in wrapper
ret = func(*args, **kwargs)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/engine/trainer.py", line 305, in train_iteration
self.grad_scaler.scale(loss).backward() # type: ignore
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/functorch/_src/monkey_patching.py", line 77, in _backward
return _old_backward(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.
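The trace points at the `fac1[:, None, None] * skews` product inside `exp_map_SO3xR3`. The thread does not confirm the root cause, but one common source of non-finite values in exponential-map code is the near-zero rotation angle, where `sin(theta) / theta` becomes 0/0. The sketch below is a generic Rodrigues formula with a small-angle guard, written in plain Python; the function name `so3_exp` and the `eps` threshold are assumptions for illustration, and this is not the nerfstudio implementation.

```python
import math


def so3_exp(w, eps=1e-4):
    """Map an axis-angle 3-vector to a 3x3 rotation matrix (Rodrigues formula)."""
    wx, wy, wz = w
    theta2 = wx * wx + wy * wy + wz * wz
    theta = math.sqrt(theta2)
    if theta < eps:
        # Taylor expansions avoid dividing by a vanishing angle.
        fac1 = 1.0 - theta2 / 6.0
        fac2 = 0.5 - theta2 / 24.0
    else:
        fac1 = math.sin(theta) / theta
        fac2 = (1.0 - math.cos(theta)) / theta2
    # Skew-symmetric matrix of w and its square.
    K = [[0.0, -wz, wy], [wz, 0.0, -wx], [-wy, wx, 0.0]]
    K2 = [[sum(K[i][k] * K[k][j] for k in range(3)) for j in range(3)]
          for i in range(3)]
    return [[(1.0 if i == j else 0.0) + fac1 * K[i][j] + fac2 * K2[i][j]
             for j in range(3)] for i in range(3)]


identity = so3_exp([0.0, 0.0, 0.0])
quarter_turn = so3_exp([0.0, 0.0, math.pi / 2])
```

With the guard in place, a zero pose adjustment yields the identity rotation instead of a division by zero.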
I think this is fixed in the latest version. If you don't want to update, try making the change in this PR https://github.com/nerfstudio-project/nerfstudio/pull/910
@tancik With the latest changes I still got NaN, this time during training on the Lego scene with the default number of steps and with the collider, camera optimizer, and average appearance embeddings disabled.
This time the NaN was in a different place:
File "/home/ubuntu/repos/nerfstudio/nerfstudio/pipelines/base_pipeline.py", line 298, in get_eval_image_metrics_and_images
metrics_dict, images_dict = self.model.get_image_metrics_and_images(outputs, batch)
File "/home/ubuntu/repos/nerfstudio/nerfstudio/models/nerfacto.py", line 283, in get_image_metrics_and_images
lpips = self.lpips(image, rgb)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torchmetrics/metric.py", line 245, in forward
self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torchmetrics/metric.py", line 309, in _forward_reduce_state_update
self.update(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torchmetrics/metric.py", line 395, in wrapped_func
update(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/nerfstudio/lib/python3.8/site-packages/torchmetrics/image/lpip.py", line 137, in update
raise ValueError(
ValueError: Expected both input arguments to be normalized tensors with shape [N, 3, H, W]. Got input with shape torch.Size([1, 3, 800, 800]) and torch.Size([1, 3, 800, 800]) and values in range [tensor(0., device='cuda:0'), tensor(1., device='cuda:0')] and [tensor(nan, device='cuda:0'), tensor(nan, device='cuda:0')] when all values areexpected to be in the [-1, 1] range.
Which model are you running? Are you seeing this on all scenes, or just the blender scenes?
Before, I saw this on my own real scene; now on Lego. There seems to be some randomness involved: on most trials I do not see errors. Will try anomaly detection again. Maybe it is a different source of NaN this time. I am running Nerfacto.
For lego, try also adding the following arguments --pipeline.model.near-plane 2. --pipeline.model.far-plane 6.
Thanks. I have removed the
self.collider = NearFarCollider(near_plane=self.config.near_plane, far_plane=self.config.far_plane)
line in the nerfacto model to disable the collider for testing. Results look similar without it and with --pipeline.model.near-plane 2. --pipeline.model.far-plane 6.
yes it is fixed on redwoods2, closing the issue, thanks!
@snarb I hit the same error during training. Do you have a solution for it?
@Madaoer @tancik
density_before_activation.max()
Out[26]: tensor(36.6562, device='cuda:0', dtype=torch.float16)
density.max()
Out[27]: tensor(8.3101e+15, device='cuda:0')
The 8.3101e+15 value is just truncated from NaN by trunc_exp
https://github.com/nerfstudio-project/nerfstudio/blob/d767cd50118a1abe0167a6316fc565685eca1926/nerfstudio/field_components/activations.py#L37
and next, mlp_head returns NaN among the returned rgb values https://github.com/nerfstudio-project/nerfstudio/blob/d767cd50118a1abe0167a6316fc565685eca1926/nerfstudio/fields/instant_ngp_field.py#L174
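The two numbers reported above are consistent with a plain exponential: exp(36.6562) is about 8.31e15. A quick check with the standard library also shows that such a value cannot be stored as a finite half-precision number, which may be relevant for the half-precision layers that consume it. This is only an arithmetic sanity check under that assumption; `fits_in_half` is a hypothetical helper, not part of nerfstudio or tiny-cuda-nn.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def fits_in_half(x: float) -> bool:
    """Return True if x can be stored as a finite half-precision value."""
    try:
        struct.pack("e", x)
        return True
    except OverflowError:
        return False


pre_activation = 36.6562             # max value reported in the comment above
density = math.exp(pre_activation)   # roughly 8.31e15

# Largest exponent whose exp() still fits in fp16: ln(65504), about 11.09.
max_safe_exponent = math.log(FP16_MAX)

print(f"{density:.4e}", fits_in_half(density), round(max_safe_exponent, 2))
```

The pre-activation itself is comfortably within fp16 range; it is the exponentiated density that overflows.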
@snarb Do you have some data that you can share that reliably produces this error?
Can you try changing the x.clamp(-15,15)
to smaller values like x.clamp(-12,12)
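To make the effect of the clamp bound concrete, here is a minimal sketch of an exponential whose exponent is clamped to a symmetric range. It is a generic illustration in plain Python: the function name `clamped_exp` is an assumption, and whether the clamp in `activations.py` applies to the forward value, the gradient, or both is not shown in this thread.

```python
import math

FP16_MAX = 65504.0  # largest finite half-precision value


def clamped_exp(x: float, bound: float = 15.0) -> float:
    """exp(x) with the exponent clamped to [-bound, bound]."""
    return math.exp(max(-bound, min(bound, x)))


# Upper limit of the output for a few candidate bounds.
for bound in (15.0, 12.0, 11.0):
    limit = clamped_exp(100.0, bound)
    print(bound, f"{limit:.4e}", limit <= FP16_MAX)
```

Under this sketch, bounds of 15 and 12 still allow outputs above the fp16 maximum, while a bound of 11 keeps them below it.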
In my case, this error can be solved by turning off the camera pose optimizer, but the images generated by nerfacto then become very bad. I have made the change in this PR https://github.com/nerfstudio-project/nerfstudio/pull/910
@tancik as for reproduction, I think it is enough to train nerfacto or instant-ngp on Lego with a larger number of steps; the error appears after ~1 hour of training. I am now training on a dynamic scene with modified code, so I can't easily share it.
I have tried changing the clamp to x.clamp(-11,11)
and other things like clipping the gradient by norm. That did not work. But I found a change that has fixed this issue for me in all cases.
I have observed that, before going to NaN, the max value of the density embedding shows outliers with large values from time to time https://wandb.ai/brans/nerfstudio-project/reports/density_embedding_max-22-11-24-17-17-00---VmlldzozMDMwNTAx?accessToken=anapz2ezdn0wo7z7tfuq5xsj8hfb12zadv7k90arhles66on4ae93cmn25kmas7f and it looks like these outliers eventually lead to this problem. So the solution that worked was to limit the range of the density embedding values by applying sigmoid to base_mlp_out.
So the https://github.com/nerfstudio-project/nerfstudio/blob/b8f85fb603e426309697f7590db3e2c34b9a0d66/nerfstudio/fields/instant_ngp_field.py#L149
becomes
return density, torch.sigmoid(base_mlp_out)
this does not affect the quality for my scenes and removes completely the problem with nan values for me.
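The bounding effect of the sigmoid can be seen on scalar values. The sketch below uses a numerically stable formulation in plain Python; the sample inputs are invented outliers, not values taken from the linked report.

```python
import math


def stable_sigmoid(x: float) -> float:
    """Sigmoid that avoids overflow in exp() for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Hypothetical embedding values, including large outliers.
samples = [-1e4, -3.0, 0.0, 3.0, 1e4]
bounded = [stable_sigmoid(v) for v in samples]
print(bounded)
```

Every output lies in [0, 1] regardless of how large the input outlier is, which is the property this fix relies on. The same squashing also compresses ordinary in-range values, so the effect on quality may depend on the scene.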
@snarb Thank you snarb. In my setting, adding torch.sigmoid in field function does make training process more stable. However, it seems to harm the image quality in eval set, tending to generate more artifacts and wrong geometry shape in image.
@Madaoer as an experiment you can also try clipping the values of base_mlp_out, for example with base_mlp_out.clamp(-1000, 1000). On my synthetic scenes I actually see faster training and better metrics with sigmoid so far.
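For comparison with the sigmoid approach, a clamp leaves in-range values untouched and only cuts off outliers. The scalar sketch below mirrors the suggested (-1000, 1000) range; it explicitly passes NaN through, on the assumption that a tensor clamp would not repair a value that is already NaN.

```python
import math


def clamp(x: float, lo: float = -1000.0, hi: float = 1000.0) -> float:
    """Limit x to [lo, hi]; NaN is returned unchanged (assumed behavior)."""
    if math.isnan(x):
        return x
    return max(lo, min(hi, x))


print(clamp(5.0), clamp(1e6), clamp(float("-inf")), clamp(float("nan")))
```

Under that assumption, the clamp has to sit before the point where values first become non-finite to have any effect.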
When playing with tiny-cuda-nn-based models I've found that applying a small weight decay term (1e-6) sometimes helped with training stability, but at least in my case this was somewhat scene specific
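As a reminder of what a small weight decay term does, here is a sketch of one plain SGD step with L2 weight decay. Plain SGD is used only to keep the arithmetic visible; the optimizer actually used in these experiments is not stated here, and `sgd_step` is a hypothetical helper.

```python
def sgd_step(weights, grads, lr=1e-2, weight_decay=1e-6):
    """One SGD step with L2 weight decay: w <- w - lr * (g + weight_decay * w)."""
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


# With zero gradient, the decay term alone pulls large weights toward zero.
weights = [100.0, -100.0, 0.0]
updated = sgd_step(weights, grads=[0.0, 0.0, 0.0])
print(updated)
```

The pull is proportional to the weight, so large outlier parameters are shrunk the most, while the per-step change stays tiny for a decay of 1e-6.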
I still find nans when using custom data, particularly when adding additional MLP heads (an example would be the semantic nerfw approach).
Wondering if a gradient clipping solution has been considered/tested.
@krrish94 I have tried. Didn't help
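For reference, clipping by global norm rescales all gradients by a common factor when their combined norm exceeds a threshold. The sketch below is a plain-Python version with an assumed `max_norm`; it also shows one reason clipping alone may not help here: if the norm is already NaN or infinite, there is no meaningful scale factor to apply, so this version leaves the gradients unchanged.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0):
    """Return (clipped_grads, total_norm) using a single global scale factor."""
    total = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(total) or total <= max_norm:
        # Nothing to do, or nothing that rescaling could repair.
        return list(grads), total
    scale = max_norm / total
    return [g * scale for g in grads], total


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
broken, broken_norm = clip_by_global_norm([float("nan"), 1.0])
print(clipped, norm, broken, broken_norm)
```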
I ran into a similar NaN problem and applied the sigmoid solution when working with a semantic head; it seems to drastically lower the quality and hinder convergence.
P.S. I tried x.clamp(-12,12) too; it didn't work either.
Not sure, maybe I am wrong about this issue being the cause of my problem.
When nerfacto or instant-ngp (maybe others) runs for more iterations, the whole scene becomes either "black" or full of "noise".
Here is the screen when I train with instant-ngp; for nerfacto it was just "black". If it is the same issue, then I can't increase the epochs, which makes the issue severe.
My current solution is tuning the regularization via
--optimizers.fields.optimizer.weight-decay 1e-6 or less
--optimizers.proposal-networks.optimizer.weight-decay 1e-6 or less
Training can then continue to some degree, but I haven't tried more than 200_000 iterations.
My guess is that exploding values during learning caused the issue.
Good find @XinyueZ. Do you find that these weight decays affect the training speed or quality in the first 30K iters? If not, we should consider making these the default.
https://github.com/nerfstudio-project/nerfstudio/issues/873#issuecomment-1407712372 So far I have not found a difference in speed. Of course, the quality of the model itself could in theory be affected by tuning the regularization, but that is quite normal in the DL field. At least it was fine for me. I have trained my model for 500_000 iterations, BUT I trained 100_000 at a time, resuming from the previous checkpoint. I think there won't be a problem when training 500_000 in a row. The early "black" cases happened even when I did 100_000 per run.
Hopefully, my experience could help you guys. @tancik
For those still running into this issue - it might be worth logging gradients and see if they vanish or explode now that https://github.com/nerfstudio-project/nerfstudio/pull/1334 is merged in
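The interface of the gradient logging added in that PR is not shown here, so the following is a standalone sketch of the kind of check one might run on logged gradients. The function name, the thresholds, and the parameter names in the example are all hypothetical.

```python
import math


def summarize_gradients(named_grads, vanish_tol=1e-12, explode_tol=1e6):
    """Map each parameter name to (gradient norm, status string)."""
    report = {}
    for name, grads in named_grads.items():
        norm = math.sqrt(sum(g * g for g in grads))
        if not math.isfinite(norm):
            status = "non-finite"
        elif norm < vanish_tol:
            status = "vanishing"
        elif norm > explode_tol:
            status = "exploding"
        else:
            status = "ok"
        report[name] = (norm, status)
    return report


# Made-up gradients for four hypothetical parameter groups.
report = summarize_gradients({
    "field.mlp_base": [3.0, 4.0],
    "field.mlp_head": [0.0, 0.0],
    "proposal_network": [1e9],
    "camera_opt": [float("nan"), 0.1],
})
for name, (norm, status) in report.items():
    print(name, norm, status)
```

Tracking the status per parameter group over time can help narrow down which component first goes non-finite.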
- w the installation guidelines in the readme.md (with CUDA 11.3, Ubuntu 18.04
It works for me, thanks
https://github.com/nerfstudio-project/nerfstudio/pull/1656, maybe this can help
When we set up training for 100k iterations, a similar phenomenon occurred at close to 90k iterations. We tried adding weight_decay = 1e-6, which seriously affected the training quality: the PSNR dropped significantly. Is there currently any other solution that would let us train for more iterations?
There is a current PR that may fix this issue: https://github.com/nerfstudio-project/nerfstudio/pull/1662
Describe the bug
AssertionError: the min value is -9223372036854775808 in outputs["accumulation"] after 490 iterations using nerfacto on nerfstudio/redwoods2. I tried different seeds and still get the error at 490 iterations. This does not happen on the other nerfstudio scenes.
To Reproduce
ns-download-data --dataset nerfstudio --capture redwoods2
ns-train nerfacto --vis tensorboard --data data/nerfstudio/redwoods2
Expected behavior
It should train smoothly
Screenshots
Extra context
It seems related to #117, which looks resolved by a previous PR. Let me know if you need more details on the machine/package versions