jchenghu / ExpansionNet_v2

Implementation code of the work "Exploiting Multiple Sequence Lengths in Fast End to End Training for Image Captioning"
https://arxiv.org/abs/2208.06551
MIT License

Errors when running demo.py #3

Closed: ashaychangwani closed this issue 1 year ago

ashaychangwani commented 1 year ago

Here is my environment: python==3.7.9 torch==1.9.0 torchvision==0.10.0

When executing the demo with the following command: python3.7 demo.py --load_path /tmp/ac8832/rf_model.pth

I am getting this error:

```
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  z = torch.matmul(torch.tensor(query_exp), x_key.transpose(-1, -2)) / torch.sqrt(self.d_model)
Traceback (most recent call last):
  File "demo.py", line 99, in <module>
    mode='beam_search', **beam_search_kwargs)
  File "/tmp/ac8832/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/captioning_model.py", line 48, in forward
    sample_or_max=sample_or_max)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/captioning_model.py", line 117, in beam_search
    cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/End_ExpansionNet_v2.py", line 94, in forward_enc
    x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask)
  File "/tmp/ac8832/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/layers.py", line 106, in forward
    x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask))
  File "/tmp/ac8832/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/layers.py", line 52, in forward
    z = torch.matmul(torch.tensor(query_exp), x_key.transpose(-1, -2)) / torch.sqrt(self.d_model)
```
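For context on the first traceback: the failing line divides the query-key scores by `torch.sqrt(self.d_model)`. `torch.sqrt` accepts only a Tensor, so if `d_model` is a plain Python int (which the reporter's `torch.tensor` workaround suggests), the call fails. A scalar square root such as `math.sqrt(self.d_model)` or `self.d_model ** 0.5` returns a Python float, which PyTorch can use directly in the division. Separately, the UserWarning is triggered by `torch.tensor(query_exp)` copy-constructing an existing tensor; `query_exp.clone().detach()` is the form the warning recommends. The sketch below is a plain-Python illustration of the scaling step only, not the repository's code, and `scaled_scores` is a hypothetical name.

```python
import math


def scaled_scores(queries, keys, d_model):
    """Return q.k / sqrt(d_model) for every (query, key) pair.

    Hypothetical helper for illustration. d_model is a plain int here, so
    math.sqrt is used; in PyTorch, torch.sqrt would require a Tensor instead.
    """
    scale = math.sqrt(d_model)  # a Python float, no tensor needed
    scores = []
    for q in queries:
        row = []
        for k in keys:
            dot = sum(qi * ki for qi, ki in zip(q, k))  # dot product q.k
            row.append(dot / scale)
        scores.append(row)
    return scores


# d_model = 4, so every dot product is divided by exactly 2.0.
q = [[1.0, 0.0, 1.0, 0.0]]
k = [[2.0, 2.0, 2.0, 2.0], [0.0, 1.0, 0.0, 1.0]]
print(scaled_scores(q, k, d_model=4))  # [[2.0, 0.0]]
```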

When I fix it by wrapping self.d_model with torch.tensor, I get the following error:

```
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  z = torch.matmul(torch.tensor(query_exp), x_key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.d_model))
Traceback (most recent call last):
  File "demo.py", line 99, in <module>
    mode='beam_search', **beam_search_kwargs)
  File "/tmp/ac8832/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/captioning_model.py", line 48, in forward
    sample_or_max=sample_or_max)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/captioning_model.py", line 124, in beam_search
    apply_log_softmax=True)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/End_ExpansionNet_v2.py", line 125, in forward_dec
    cross_attention_mask=pad_mask)
  File "/tmp/ac8832/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/layers.py", line 204, in forward
    x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask))
  File "/tmp/ac8832/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ac8832/WhiskerHacks/ExpansionNet_v2/models/layers.py", line 151, in forward
    mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \
AttributeError: 'NoneType' object has no attribute 'unsqueeze'
```
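The second traceback fails because `mask` is `None` when the decoder's `dyn_exp` layer reaches `mask.unsqueeze(2)`. The sketch below mirrors that reshaping, `mask.unsqueeze(2).expand(bs, dec_len, num_exp, dec_len)`, on nested Python lists and adds an explicit `None` check before it. Falling back to a causal (lower-triangular) mask, with 1 meaning "may attend", is an assumption made purely for illustration; it is not necessarily what the repository's fix does, and `expand_mask` and `causal_mask` are hypothetical names.

```python
def causal_mask(dec_len):
    """Lower-triangular mask: position i may attend to positions 0..i (1 = attend)."""
    return [[1 if j <= i else 0 for j in range(dec_len)] for i in range(dec_len)]


def expand_mask(mask, bs, dec_len, num_exp):
    """Nested-list equivalent of mask.unsqueeze(2).expand(bs, dec_len, num_exp, dec_len).

    mask has shape (bs, dec_len, dec_len) or is None. Hypothetical default: when
    mask is None, build a causal mask per batch item instead of failing on
    .unsqueeze. The result satisfies out[b][i][e][j] == mask[b][i][j].
    """
    if mask is None:
        mask = [causal_mask(dec_len) for _ in range(bs)]
    # New axis at position 2: repeat each row mask[b][i] num_exp times.
    return [[[list(mask[b][i]) for _ in range(num_exp)] for i in range(dec_len)]
            for b in range(bs)]


out = expand_mask(None, bs=1, dec_len=3, num_exp=2)
print(out[0][1])  # [[1, 1, 0], [1, 1, 0]]
```

In the PyTorch module itself, the equivalent guard would be an `if mask is None:` branch placed before the `unsqueeze` call.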

jchenghu commented 1 year ago

Hi @ashaychangwani

Thank you for the report. I introduced coding errors in the previous version.

The new commit should have fixed these problems. Thanks again for pointing them out; let me know if everything is OK now.

Best regards, Jia