djiajunustc / TransVG


query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) #3

OliverHuang1220 closed this issue 3 years ago.

OliverHuang1220 commented 3 years ago

Hi @djiajunustc, I found that 'self.query_embed' is never assigned in your detr.py, whereas detr/models/detr.py sets self.query_embed = nn.Embedding(num_queries, hidden_dim). I tried adding that line, but it caused a dimension-mismatch error. Looking forward to your reply! With self.query_embed = nn.Embedding(num_queries, hidden_dim) added, the error is as follows:


Traceback (most recent call last):
  File "train.py", line 282, in <module>
    main(args)
  File "train.py", line 239, in main
    args, model, data_loader_train, optimizer, device, epoch, args.clip_max_norm
  File "/home1/huangqiangHD/TransVG/engine.py", line 38, in train_one_epoch
    output = model(img_data, text_data)
  File "/home/huangqiang/.conda/envs/trans_vg/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home1/huangqiangHD/TransVG/models/trans_vg.py", line 39, in forward
    visu_src = self.visu_proj(visu_src) # (NB)xC
  File "/home/huangqiang/.conda/envs/trans_vg/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/huangqiang/.conda/envs/trans_vg/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/huangqiang/.conda/envs/trans_vg/lib/python3.6/site-packages/torch/nn/functional.py", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [5120 x 20], m2: [256 x 256] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:290
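To make the final RuntimeError concrete, here is a minimal sketch that reproduces the same kind of failure. It assumes that visu_proj is an nn.Linear(256, 256), which is what the m2: [256 x 256] weight shape suggests; the input sizes are copied from the error message, and bad_input/good_input are illustrative names, not TransVG code.

```python
import torch
import torch.nn as nn

# Assumption: a projection expecting 256 input features, matching the
# m2: [256 x 256] weight reported in the traceback.
visu_proj = nn.Linear(256, 256)

# Hypothetical input whose last dimension is 20, matching m1: [5120 x 20].
bad_input = torch.randn(5120, 20)

try:
    visu_proj(bad_input)
except RuntimeError as err:
    # Newer PyTorch versions word this as "mat1 and mat2 shapes cannot be multiplied".
    print("RuntimeError:", err)

# nn.Linear only requires the LAST dimension to equal in_features (256 here).
good_input = torch.randn(5120, 256)
print(tuple(visu_proj(good_input).shape))  # (5120, 256)
```

In other words, the tensor reaching visu_proj no longer has 256 features in its last dimension, so the shape change happens upstream of this layer.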


djiajunustc commented 3 years ago

There is no self.query_embed in detr.py.

djiajunustc commented 3 years ago

In DETR, query_embed is used to perform detection in the transformer decoder layers. There are no transformer decoder layers in my method, so why would you need query_embed?
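For context, here is a minimal PyTorch sketch of what query_embed is in DETR and what the line in this issue's title computes. The sizes are assumptions (DETR's defaults of 100 queries and a 256-dimensional hidden size, plus an arbitrary batch of 2), and the variable names follow DETR's detr.py and transformer.py only loosely.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration only.
num_queries, hidden_dim, bs = 100, 256, 2

# In DETR, the learned object queries are stored in an nn.Embedding table.
query_embed = nn.Embedding(num_queries, hidden_dim)

# The line from the issue title: insert a batch axis and tile it, giving the
# sequence-first layout (num_queries, bs, hidden_dim).
queries = query_embed.weight.unsqueeze(1).repeat(1, bs, 1)

# DETR's decoder starts from an all-zero target; the query embeddings are
# added as positional information inside each decoder layer.
tgt = torch.zeros_like(queries)

print(tuple(queries.shape))  # (100, 2, 256)
```

Only the decoder reads these tensors, so a model with zero decoder layers has no use for them, which is why the missing assignment is not a bug.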

OliverHuang1220 commented 3 years ago

I got it. I had overlooked 'dec_layer=0' in train.sh!