Open CarlosDominguezBecerril opened 3 years ago
The line "combined_decoder_input = torch.cat((xy_decoder_input, wh_decoder_input, l_decoder_input), dim=2)" tries to concatenate three tensors whose shapes are (1, bsz, aug_size), (1, bsz, aug_size) and (bsz, 1, embedding_dim) respectively. Concatenating along dim=2 requires the sizes of dimensions 0 and 1 to match across all three tensors, so the code only works when bsz=1.
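To make the shape rule concrete, here is a minimal sketch in plain Python that mimics the check torch.cat performs: every dimension except the one being concatenated must have the same size. The helpers `cat_shape` and `decoder_input_shapes` are hypothetical and not part of the repository, and the sizes 8 and 16 for aug_size and embedding_dim are placeholders. The last step assumes the intended layout is (1, bsz, features); if so, swapping the first two axes of l_decoder_input (for example with `l_decoder_input.transpose(0, 1)` in PyTorch) would make the shapes compatible, but whether that is the right fix for this model has not been verified.

```python
def cat_shape(shapes, dim):
    """Return the shape a concatenation along `dim` would produce.

    Mimics the torch.cat rule: all dimensions other than `dim` must match.
    Raises ValueError on a mismatch.
    """
    first = shapes[0]
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError(f"rank mismatch: {first} vs {shape}")
        for axis, (a, b) in enumerate(zip(first, shape)):
            if axis != dim and a != b:
                raise ValueError(f"size mismatch on dim {axis}: {a} vs {b}")
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


def decoder_input_shapes(bsz, aug_size=8, embedding_dim=16):
    """Shapes of xy_decoder_input, wh_decoder_input, l_decoder_input as reported above.

    aug_size and embedding_dim are placeholder values chosen for illustration.
    """
    return [(1, bsz, aug_size), (1, bsz, aug_size), (bsz, 1, embedding_dim)]


# bsz = 1: dimensions 0 and 1 are all of size 1, so the concatenation succeeds.
print(cat_shape(decoder_input_shapes(1), dim=2))

# bsz = 4: dimension 0 is 1 for the first two tensors but 4 for the third.
try:
    cat_shape(decoder_input_shapes(4), dim=2)
except ValueError as err:
    print("bsz=4 fails:", err)

# Assumed fix: swap the first two axes of the third shape, giving (1, bsz, embedding_dim).
xy, wh, l = decoder_input_shapes(4)
l_swapped = (l[1], l[0], l[2])
print(cat_shape([xy, wh, l_swapped], dim=2))
```

If the layer that produces l_decoder_input is meant to be batch-first, the cleaner change may be in that layer rather than at the concatenation; the sketch only shows why the current shapes cannot be combined.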
When training the box generation with only the batch size changed to a number greater than 1, I get the following error. Does anyone know how to fix it?
Regards,
Carlos