qianyu-dlut / MVANet

MIT License

Error when using batch size > 1 #7


aravindhv10 commented 1 month ago

I am getting errors from the line `loc_e5, glb_e5 = e5.split([4, 1], dim=0)` (https://github.com/qianyu-dlut/MVANet/blob/main/model/MVANet.py#L418) when training with batch size > 1.

Here, `e5` has a leading dimension of `5 * batch_size`, so `split([4, 1])` (which is only valid when that dimension is exactly 5) fails for any batch size > 1.

When I dug deeper, I found that the batch indices get mixed up (for instance, in https://github.com/qianyu-dlut/MVANet/blob/main/model/MVANet.py#L38).

The exact error I got was:

```
Traceback (most recent call last):
  File "./train.py", line 1117, in
    sideout5, sideout4, sideout3, sideout2, sideout1, final, glb5, glb4, glb3, glb2, glb1, tokenattmap4, tokenattmap3, tokenattmap2, tokenattmap1 = generator.forward(
  File "./train.py", line 882, in forward
    loc_e5, glb_e5 = e5.split([4, 1], dim=0)
  File "/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 10 (input tensor's size at dimension 0), but got split_sizes=[4, 1]
```
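The failure can be reproduced minimally. This sketch assumes (as described above) that each sample contributes 4 local crops plus 1 global view stacked along dim 0; the scaled split sizes shown at the end are only a hypothetical fix, and would still be wrong if the local/global entries of different samples are interleaved rather than contiguous (the batch-index mixing noted above):

```python
import torch

B = 2  # batch size > 1
# e5 stands in for the encoder feature map: leading dim is 5 * B
e5 = torch.zeros(5 * B, 3)

# The hard-coded split assumes B == 1 and raises for B > 1:
try:
    loc_e5, glb_e5 = e5.split([4, 1], dim=0)
except RuntimeError as err:
    print(err)  # split_with_sizes expects split_sizes to sum exactly to 10 ...

# A batch-aware variant scales the split sizes by B. This only yields the
# intended grouping if all 4*B local entries precede the B global entries:
loc_e5, glb_e5 = e5.split([4 * B, 1 * B], dim=0)
print(loc_e5.shape, glb_e5.shape)  # torch.Size([8, 3]) torch.Size([2, 3])
```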

Is it possible to fix this while still retaining the exact architecture of the model? I want to finetune on personal datasets starting from the pretrained 80th-epoch checkpoint.

Xelawk commented 1 month ago

same issue