Closed rachillesf closed 3 years ago
Hi @rachillesf , I cannot reproduce the error you showed here.
The traceback says q is never defined but is still used in line 86. But line 86 of the code does not match what your traceback is reporting.
If you still have problems, can you tell me which Python and PyTorch versions you are using? I am inheriting the MultiheadAttention from the official implementation, so it may vary from version to version. But the last time I checked they were consistent, so the version should not be a problem.
Hi, @mli0603 thanks for answering. I thought this issue was related to my dataset but I reviewed it several times and I cannot find anything wrong.
I'm using the sceneflow dataloader, and the only modification I made was adding a maximum crop size to the 'random_crop' function so the model can fit on my GPU. The call looks like this after the modification:
result = random_crop(360, 450, result, self.split, max_crop_height=512, max_crop_width=512)
I'm running it in an anaconda env with Python 3.6.12 and Torch 1.7.0
This error appears very frequently, but not always at the same point: sometimes it crashes after 2 epochs, and other times it crashes at epoch 0.
Hi @rachillesf , thanks for the information. I will build an identical environment to yours and try training again. I may have broken the initialization of parameters during my code clean-up. I will update you.
Hi @rachillesf, sorry for the long wait. I am able to reproduce your error now. You are not using apex, are you? Here is what happened:
Why is q not computed? q is not computed because neither of the two equality conditions was met (i.e. self-attn or cross-attn). This happened because there were nan values in the query/key/value (nan != nan, so all the previous "if"s failed). And that was because of exploding gradients in the iteration before the crash, which led to nan in the parameters.

Why did math.isfinite never report this? It turns out that a finite loss does not mean the gradient is well-conditioned, so the check was not sufficient.
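The core of the failure can be reproduced in a few lines: nan never compares equal to anything, including itself, so an equality-guarded branch that would normally bind a variable is silently skipped. This is a minimal hypothetical sketch that only mirrors the branch structure of the attention code, not its real logic:

```python
import math

def scaled_query(query, key, scaling):
    # q is only bound inside equality-guarded branches, mirroring
    # the structure that failed in attention.py (names are illustrative).
    if query == key:        # "self-attention" check: False when query is nan
        q = query
    elif query == -key:     # another guarded path: also False for nan
        q = query
    return q * scaling      # UnboundLocalError if no branch ran

nan = float('nan')
print(nan == nan)           # False: nan != nan per IEEE 754
print(math.isnan(nan))      # True: this is how nan must be detected

try:
    scaled_query(nan, nan, 0.125)
except UnboundLocalError as err:
    print("crash:", err)    # same error class as in the traceback
```

Once nan reaches the parameters, every equality comparison involving the tensors fails the same way, so the real code falls through all its "if"s and hits the unbound q.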
Why was this bug never seen before? Because of apex. It turns out that apex does more than turning numbers from 32-bit to 16-bit. It also checks whether the gradient is finite; if it isn't, it reduces the "loss scaling factor" and continues to the next iteration (this is the key: it skips the current iteration and continues). So it saved me from crashing.
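The iteration-skipping behaviour can be sketched in a framework-agnostic way. This is not apex's actual code, just an illustration of dynamic loss scaling: after backward, check the gradients, and if any value is non-finite, skip the update and shrink the scale (function names and the backoff factor are assumptions):

```python
import math

def grads_are_finite(grads):
    """True only if every gradient value is finite (no nan/inf)."""
    return all(math.isfinite(g) for g in grads)

def maybe_step(grads, loss_scale, backoff=0.5):
    """Mimic dynamic loss scaling: skip the update and shrink the
    scale on overflow; otherwise it is safe to apply optimizer.step()."""
    if not grads_are_finite(grads):
        return False, loss_scale * backoff   # skip this iteration
    return True, loss_scale                  # proceed with the update

print(maybe_step([0.1, -2.0], 1024.0))           # (True, 1024.0)
print(maybe_step([0.1, float('inf')], 1024.0))   # (False, 512.0)
```

In real code the check would run over `p.grad` tensors (e.g. `torch.isfinite(p.grad).all()`), but the control flow is the same: the bad iteration is dropped instead of poisoning the parameters with nan.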
Changes to fix this: I have created a new branch called bug-attention-no-query with the following changes:
How do the changes affect things? First, since apex scales the gradient, I cannot guarantee that training without apex will lead to exactly the same result. However, it should not affect the result much, because only early iterations have problems like this (the network parameters are still close to initialization, so the output can be wild). Secondly, I do recommend training with mixed precision because it is faster and consumes less memory. I know there may be a drop in performance (people have done extensive comparisons), but in my opinion the savings are worth it.
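For reference, PyTorch 1.6+ ships native mixed precision (torch.cuda.amp) with the same safety net as apex: GradScaler.step skips the optimizer update when the scaled gradients contain inf/nan, and update() adjusts the loss scale. A minimal sketch with a toy model (on a CPU-only machine the scaler and autocast are simply disabled):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

for step in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_cuda):
        loss = model(torch.randn(8, 4)).pow(2).mean()
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)         # skipped if grads contain inf/nan
    scaler.update()                # grows/shrinks the loss scale
```

This gives the same "skip the bad iteration" behaviour without an extra dependency.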
Can you pull the code again and try the new branch? Once you can verify it works, I will start a pull request and merge into master.
Hi @mli0603. Thanks for your very detailed reply.
I tested the new branch and it seems the problem is fixed. I'll run with apex and compare the results.
Hi, Thanks for sharing this implementation.
I'm trying to reproduce the paper results by training the model on sceneflow, but training constantly crashes inside the attention module when scaling the query tensor.
Did you have this problem before?
Here is the Traceback:
Traceback (most recent call last):
  File "main.py", line 250, in <module>
    main(args_)
  File "main.py", line 236, in main
    eval_stats = evaluate(model, criterion, data_loader_val, device, epoch, summary_writer, False)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/media/user/Data/user/stereo-transformer/utilities/eval.py", line 34, in evaluate
    outputs, losses, sampled_disp = forward_pass(model, data, device, criterion, eval_stats, idx, logger)
  File "/media/user/Data/user/stereo-transformer/utilities/foward_pass.py", line 55, in forward_pass
    outputs = model(inputs)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/sttr.py", line 101, in forward
    attn_weight = self.transformer(feat_left, feat_right, pos_enc)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 112, in forward
    attn_weight = self._alternating_attn(feat, pos_enc, pos_indexes, hn)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 60, in _alternating_attn
    feat = checkpoint(create_custom_self_attn(self_attn), feat, pos_enc, pos_indexes)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 74, in forward
    outputs = run_function(*args)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 56, in custom_self_attn
    return module(*inputs)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/transformer.py", line 143, in forward
    pos_indexes=pos_indexes)
  File "/home/user/anaconda3/envs/sttr/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/user/Data/user/stereo-transformer/module/attention.py", line 86, in forward
    q = q * scaling
UnboundLocalError: local variable 'q' referenced before assignment