mli0603 / stereo-transformer

Revisiting Stereo Depth Estimation From a Sequence-to-Sequence Perspective with Transformers. (ICCV 2021 Oral)
Apache License 2.0

Train crashes inside the Attention module #5

Closed rachillesf closed 3 years ago

rachillesf commented 3 years ago

Hi, Thanks for sharing this implementation.

I'm trying to reproduce the paper results by training the model on sceneflow but training is constantly crashing inside the attention model when scaling the query tensor.

Did you have this problem before?

Here is the Traceback:

Traceback (most recent call last):
  File "main.py", line 250, in <module>
    main(args_)
  File "main.py", line 236, in main
    eval_stats = evaluate(model, criterion, data_loader_val, device, epoch, summary_writer, False)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/media/user/Data/user/stereo-transformer/utilities/eval.py", line 34, in evaluate
    outputs, losses, sampled_disp = forward_pass(model, data, device, criterion, eval_stats, idx, logger)
  File "/media/user/Data/user/stereo-transformer/utilities/foward_pass.py", line 55, in forward_pass
    outputs = model(inputs)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/sttr.py", line 101, in forward
    attn_weight = self.transformer(feat_left, feat_right, pos_enc)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 112, in forward
    attn_weight = self._alternating_attn(feat, pos_enc, pos_indexes, hn)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 60, in _alternating_attn
    feat = checkpoint(create_custom_self_attn(self_attn), feat, pos_enc, pos_indexes)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 74, in forward
    outputs = run_function(*args)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 56, in custom_self_attn
    return module(*inputs)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 143, in forward
    pos_indexes=pos_indexes)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/attention.py", line 86, in forward
    q = q * scaling
UnboundLocalError: local variable 'q' referenced before assignment

mli0603 commented 3 years ago

Hi @rachillesf , I cannot reproduce the error you showed here.

The traceback is saying q is never defined but still used in line 86. But line 86 of the code does not match what your traceback is reporting.

If you still have problems, can you tell me which Python and PyTorch versions you are using? I am inheriting MultiheadAttention from the official implementation, so its behavior may vary from version to version. The last time I checked, the versions were consistent, so the version should not be a problem.

rachillesf commented 3 years ago

Hi, @mli0603 thanks for answering. I thought this issue was related to my dataset but I reviewed it several times and I cannot find anything wrong.

I'm using the sceneflow dataloader, and the only modification I made was to pass a maximum crop size to the function 'random_crop' so that the model fits on my GPU. The call looks like this after the modification:

result = random_crop(360, 450, result, self.split, max_crop_height=512, max_crop_width=512)

I'm running it in an anaconda env with Python 3.6.12 and Torch 1.7.0

This error appears very frequently, but not always at the same point. Sometimes it crashes after 2 epochs, and other times it crashes at epoch 0.

mli0603 commented 3 years ago

Hi @rachillesf , thanks for the information. I will build an environment identical to yours and try training again. I may have broken the parameter initialization during my code clean-up. I will update you.

mli0603 commented 3 years ago

Hi @rachillesf, sorry for the long wait. I am able to reproduce your error now. You are not using apex, are you? Here is what happened:

Why is q not computed?

q is not computed because neither of the two equality conditions was met (i.e. self-attn or cross-attn). This happened because there were nan values in the query/key/value (nan != nan, so all the previous "if"s failed). The nan values in turn came from exploding gradients in the iteration before the crash, which led to nan in the parameters.

Why did math.isfinite never report this?

It turns out that a finite loss does not mean the gradient is well-conditioned, so the check was not sufficient.
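The failure mode described above can be reproduced in a few lines. This is a minimal sketch in plain Python, not the repository's attention code: elementwise_equal is a hypothetical stand-in for a tensor equality check, and project_query is a hypothetical simplification of a forward pass that assigns q only inside the self-attention and cross-attention branches.

```python
import math


def elementwise_equal(a, b):
    """Hypothetical stand-in for a tensor equality check.

    Returns True only if every pair of elements compares equal.
    Because nan == nan is False, any nan makes this return False.
    """
    return len(a) == len(b) and all(x == y for x, y in zip(a, b))


def project_query(query, key, value, scaling=0.5):
    """Hypothetical simplification: q is assigned only in two branches."""
    if elementwise_equal(query, key) and elementwise_equal(key, value):
        q = list(query)  # self-attention: query, key and value are the same
    elif elementwise_equal(key, value):
        q = list(query)  # cross-attention: key and value are the same
    # There is no else branch, so q may be unassigned at this point.
    return [x * scaling for x in q]


clean = [1.0, 2.0]
scaled = project_query(clean, clean, clean)  # self-attention branch is taken

poisoned = [1.0, math.nan]
try:
    project_query(poisoned, poisoned, poisoned)
    crashed = False
except UnboundLocalError:
    # Both equality checks fail because nan != nan, so q was never assigned.
    crashed = True
```

Even though the same list is passed as query, key and value, the single nan makes both comparisons fail, which matches the UnboundLocalError in the traceback.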

Why was this bug never seen before?

Because of apex. It turns out that apex does more than convert numbers from 32-bit to 16-bit. It also checks whether the gradient is finite. If it isn't, apex reduces the "loss scaling factor" and skips the current iteration instead of applying the update (this is the key). So it saved me from crashing.
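The skip-on-non-finite-gradient behavior described above can be sketched as follows. This is an illustrative toy in plain Python and is not apex's actual implementation: ToyLossScaler, train_step, the initial scale and the backoff factor are all assumptions made for the example, and the scaling and unscaling of the gradients themselves is omitted.

```python
import math


class ToyLossScaler:
    """Illustrative dynamic loss scaler (hypothetical, not apex's code)."""

    def __init__(self, scale=2.0 ** 16, backoff=0.5):
        self.scale = scale
        self.backoff = backoff

    def should_step(self, grads):
        """Return True if all gradients are finite.

        Otherwise shrink the loss scale and tell the caller to skip.
        """
        if all(math.isfinite(g) for g in grads):
            return True
        self.scale *= self.backoff
        return False


def train_step(params, grads, scaler, lr=0.5):
    """Apply a plain SGD update only when the gradients are finite."""
    if not scaler.should_step(grads):
        # Skip this iteration: the parameters stay untouched,
        # so no inf or nan is written into them.
        return params
    return [p - lr * g for p, g in zip(params, grads)]


scaler = ToyLossScaler()
params = [1.0, 2.0]

# An exploding gradient: the update is skipped and the scale is halved.
params_after_bad = train_step(params, [0.5, math.inf], scaler)

# A well-behaved gradient: the update is applied as usual.
params_after_good = train_step(params_after_bad, [1.0, 2.0], scaler)
```

The point of the sketch is that the check is made on the gradients rather than on the loss, so a step with a finite loss but non-finite gradients is still skipped.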

Changes to fix this

I have created a new branch called bug-attention-no-query, with the following changes:

How do the changes affect things?

First of all, since apex scales the gradient, I cannot guarantee that training without apex will lead to exactly the same result. However, it should not affect the result much, because only early iterations will have problems like this (the network parameters are still close to their initialization and the output can be wild). Secondly, I do recommend training with mixed precision because it is faster and uses less memory. I know there may be a drop in performance (people have done extensive comparisons), but in my opinion the savings are worth it.

Can you pull the code again and try the new branch? Once you can verify it works, I will start a pull request and merge into master.

rachillesf commented 3 years ago

Hi @mli0603. Thanks for your very detailed reply.

I tested the new branch and it seems the problem is fixed. I'll run with apex and compare the results.