berniwal / swin-transformer-pytorch

Implementation of the Swin Transformer in PyTorch.
https://arxiv.org/pdf/2103.14030.pdf
MIT License
794 stars 129 forks

relative pos embedding errs out with "IndexError: tensors used as indices must be long, byte or bool tensors" #2

Closed lessw2020 closed 3 years ago

lessw2020 commented 3 years ago

Very big thanks for making this implementation! I just upgraded to the relative pos embedding update from an hour ago, and when trying to train I get the error below.

---> 32         y_pred = model(images)
     33         #print(f" y_pred = {y_pred}")
     34         #print(f" y_pred shape = {y_pred.shape}")

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, img)
    229 
    230     def forward(self, img):
--> 231         x = self.stage1(img)
    232         x = self.stage2(x)
    233         x = self.stage3(x)

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
    189         x = self.patch_partition(x)
    190         for regular_block, shifted_block in self.layers:
--> 191             x = regular_block(x)
    192             x = shifted_block(x)
    193         return x.permute(0, 3, 1, 2)

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
    148 
    149     def forward(self, x):
--> 150         x = self.attention_block(x)
    151         x = self.mlp_block(x)
    152         return x

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x, **kwargs)
     21 
     22     def forward(self, x, **kwargs):
---> 23         return self.fn(x, **kwargs) + x
     24 
     25 

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x, **kwargs)
     31 
     32     def forward(self, x, **kwargs):
---> 33         return self.fn(self.norm(x), **kwargs)
     34 
     35 

~\anaconda3\envs\fastai2\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~\cdetr\cdetr_utils\transformer\swin_transformer.py in forward(self, x)
    116 
    117         if self.relative_pos_embedding:
--> 118             dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
    119         else:
    120             dots += self.pos_embedding

IndexError: tensors used as indices must be long, byte or bool tensors
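
For context, the error itself is a generic PyTorch indexing restriction: advanced indexing only accepts long, byte, or bool index tensors, so if `relative_indices` ends up with any other dtype (for example after a numpy round-trip), the embedding lookup fails. A minimal standalone sketch of the behavior (`pos_embedding` and `relative_indices` here are just stand-ins for the module's buffers, with illustrative shapes):

```python
import torch

# Stand-ins for the module's buffers (shapes chosen only for illustration).
pos_embedding = torch.randn(13, 13)                    # (2*window-1, 2*window-1) table
relative_indices = torch.randint(0, 13, (49, 49, 2))   # integer offsets, dtype int64

# Indexing with a long tensor works:
ok = pos_embedding[relative_indices[:, :, 0], relative_indices[:, :, 1]]

# If the indices end up as a float tensor, the same lookup raises:
# IndexError: tensors used as indices must be long, byte or bool tensors
bad_indices = relative_indices.float()
try:
    pos_embedding[bad_indices[:, :, 0], bad_indices[:, :, 1]]
except IndexError as e:
    print(e)

# Casting back to long restores exactly the same lookup:
fixed = pos_embedding[bad_indices[:, :, 0].long(), bad_indices[:, :, 1].long()]
assert torch.equal(ok, fixed)
```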
berniwal commented 3 years ago

Thank you very much for reporting the error! What torch and numpy versions are you using, and does running the example also fail?

lessw2020 commented 3 years ago

Thanks for the fast reply.

1. The example code works fine a) with the regular embedding and b) with my added init code.
2. The example code fails with the same error a) using the relative pos embedding and b) with my added init code.

I wanted to confirm my init wasn't somehow messing up the embedding. Here's what I ran (with a quad sample of T/F for using relative pos embedding and init): swin_test_run

lessw2020 commented 3 years ago

Versions: numpy = 1.18.3, torch = 1.7.1

These might be older, so let me upgrade as a possible fix.

berniwal commented 3 years ago

Alright, strange; for me it also works with those numpy and torch versions. What Python version do you use?

lessw2020 commented 3 years ago

Hi - it now works for me after upgrading numpy to 1.19: swin_test_success

lessw2020 commented 3 years ago

For the Python version I have: python_version

Note that I am on Win10; not sure if that would affect things, but since most people are on Linux it could be an issue.

Anyway, it is fixed by upgrading to numpy 1.19, so if people hit this, upgrading numpy is probably the quickest resolution.

Thanks again both for making the Swin implementation and for the fast help on this issue. I'll test out training with the relative pos embedding next :)

berniwal commented 3 years ago

Perfect, thank you again for reporting this issue. Have fun! :)

yan9qu commented 3 years ago

I also got this error with torch 1.8.1 and numpy 1.19.2. I tried changing line 119 of swin_transformer.py to `dots += self.pos_embedding[self.relative_indices[:, :, 0].type(torch.long), self.relative_indices[:, :, 1].type(torch.long)]` and then the test code ran. I want to know whether this fix will make the results from the network decay? Forgive my poor expression. Thank you.
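
The cast should not change the results: `relative_indices` holds integer offsets into the positional-embedding table, so converting them to `torch.long` only changes the dtype used for indexing, not which entries are gathered. A minimal sketch of doing the cast once when the indices are built, rather than on every forward pass; the helper name and shapes below are illustrative and not the repository's exact code:

```python
import torch

def build_relative_indices(window_size: int) -> torch.Tensor:
    """Pairwise relative (row, col) offsets within a window, shifted to be >= 0."""
    ar = torch.arange(window_size)
    # All (row, col) coordinates of the window, flattened to (window*window, 2).
    coords = torch.stack([ar.repeat_interleave(window_size), ar.repeat(window_size)], dim=-1)
    rel = coords[:, None, :] - coords[None, :, :]   # (N, N, 2), values in [-(w-1), w-1]
    rel += window_size - 1                          # shift into [0, 2w-2]
    return rel.long()                               # cast once, up front

window_size = 7
relative_indices = build_relative_indices(window_size)
pos_embedding = torch.nn.Parameter(
    torch.randn(2 * window_size - 1, 2 * window_size - 1)
)

# Indexing with long indices works regardless of numpy/torch version quirks,
# and gathers exactly the same entries as any other integer dtype would.
bias = pos_embedding[relative_indices[:, :, 0], relative_indices[:, :, 1]]
print(bias.shape)  # torch.Size([49, 49])
```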