jamesli1618 / Obj-GAN

Obj-GAN - Official PyTorch Implementation

box generation batch size #28

Open CarlosDominguezBecerril opened 3 years ago

CarlosDominguezBecerril commented 3 years ago

When training the box generation module, if I change only the batch size to a value greater than 1, I get the following error. Does anyone know how to fix it?

```
Traceback (most recent call last):
  File "sample.py", line 160, in <module>
    optimizer=optimizer, resume=opt.resume, is_training=opt.is_training)
  File "/home/carlos/Desktop/Trabajo/Projects/obj_gan/Obj-GAN-master/box_generation/seq2seq/trainer/supervised_trainer.py", line 223, in train
    dev_data=dev_data, is_training=is_training)
  File "/home/carlos/Desktop/Trabajo/Projects/obj_gan/Obj-GAN-master/box_generation/seq2seq/trainer/supervised_trainer.py", line 154, in _train_epoches
    target_w_variables, target_h_variables, encoder, decoder, is_training, step)
  File "/home/carlos/Desktop/Trabajo/Projects/obj_gan/Obj-GAN-master/box_generation/seq2seq/trainer/supervised_trainer.py", line 70, in _train_batch
    target_h_variables, is_training=is_training)
  File "/home/carlos/anaconda3/envs/objgan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/carlos/Desktop/Trabajo/Projects/obj_gan/Obj-GAN-master/box_generation/seq2seq/models/DecoderRNN.py", line 260, in forward
    next_y_decoder_input, is_training=is_training)
  File "/home/carlos/Desktop/Trabajo/Projects/obj_gan/Obj-GAN-master/box_generation/seq2seq/models/DecoderRNN.py", line 116, in forward_step
    l_decoder_input, context), dim=2)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 2. Got 1 and 2 in dimension 0 at /pytorch/aten/src/TH/generic/THTensorMath.cpp:3616
```

Regards,

Carlos

fyw1999 commented 3 years ago

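The error comes from the line `combined_decoder_input = torch.cat((xy_decoder_input, wh_decoder_input, l_decoder_input), dim=2)`, which concatenates three tensors with shapes (1, bsz, aug_size), (1, bsz, aug_size) and (bsz, 1, embedding_dim), respectively. `torch.cat` requires every dimension except the concatenation dimension (here dimension 2) to match, so dimension 0 is 1 for the first two tensors but bsz for the third. That is why the code only works when bsz = 1, and why the error reports "Got 1 and 2 in dimension 0" with a batch size of 2.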
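A minimal sketch that reproduces the mismatch and shows one possible fix, assuming `l_decoder_input` is batch-first with a single time step while the other two tensors are time-first. The sizes below are hypothetical placeholders, not values from the Obj-GAN config. Transposing the first two dimensions of `l_decoder_input` makes all three tensors `(1, bsz, features)` so they can be concatenated along dimension 2.

```python
import torch

# Hypothetical sizes for illustration only; the real values come from the Obj-GAN config.
bsz, aug_size, embedding_dim = 2, 4, 8

# Shapes as described above: two time-first tensors and one batch-first tensor.
xy_decoder_input = torch.zeros(1, bsz, aug_size)
wh_decoder_input = torch.zeros(1, bsz, aug_size)
l_decoder_input = torch.zeros(bsz, 1, embedding_dim)

# With bsz > 1, dimension 0 disagrees (1 vs bsz), so torch.cat raises a RuntimeError.
try:
    torch.cat((xy_decoder_input, wh_decoder_input, l_decoder_input), dim=2)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Possible fix (assumption): swap dims 0 and 1 of l_decoder_input so that all three
# tensors are laid out as (1, bsz, features) before concatenating along dim 2.
combined_decoder_input = torch.cat(
    (xy_decoder_input, wh_decoder_input, l_decoder_input.transpose(0, 1)), dim=2
)
print(mismatch_raised)               # True
print(combined_decoder_input.shape)  # torch.Size([1, 2, 16])
```

This only addresses the shape mismatch in this one `torch.cat` call. The traceback shows the actual line 116 also concatenates `context`, and other parts of `DecoderRNN` may assume bsz = 1 as well, so batched training may need further changes.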