sail-sg / mvp

NeurIPS-2021: Direct Multi-view Multi-person 3D Human Pose Estimation
Apache License 2.0

About the MvP-Dense Attention module #15

Open liqikai9 opened 2 years ago

liqikai9 commented 2 years ago

In your paper, you mention that you replaced the projective attention with a dense attention module; here are the results:

[image: mvp_dense (dense attention results from the paper)]

I wonder how you ran this experiment. How can I modify your code to reproduce it? Which module should I modify?

twangnh commented 2 years ago

Hi, you can modify the attention layer; the current code does not include that.

liqikai9 commented 2 years ago

> Hi, you can modify the attention layer; the current code does not include that.

Thanks for your quick reply! Could you please offer some guidance on modifying the attention layer?

  1. Currently I am using nn.MultiheadAttention as the cross-attention layer instead of your projective attention module, but the result I got after ~50 epochs is still not valid, with AP@50 near 0.05 and MPJPE near 270 mm. This is far from the results you report in the paper. Which attention layer should I choose?

  2. Once I choose nn.MultiheadAttention, what should the query, key, and value be? Should I keep your ray positional embedding and rayconv modules as well?

twangnh commented 2 years ago

Yes, you can use nn.MultiheadAttention. It is similar to the usage of cross attention in DETR, with the query being the joint query and the key and value being the multi-view image feature maps. The ray embedding should be kept.
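
For illustration, a minimal sketch of such a dense cross-attention layer in the DETR style might look like the following (module and variable names are illustrative, not from the released code; here the ray embedding is simply added to the key, which is only one possible way to keep it):

import torch
import torch.nn as nn

# Minimal sketch of DETR-style dense cross attention over multi-view features.
# All names and shapes are illustrative, not taken from the released code.
class DenseCrossAttention(nn.Module):
    def __init__(self, d_model=256, n_heads=8, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)

    def forward(self, joint_query, query_pos, feat_views, ray_embed):
        # joint_query, query_pos: [batch, num_queries, d_model]
        # feat_views: multi-view CNN features flattened to [batch, V*H*W, d_model]
        # ray_embed:  ray positional embedding projected to d_model, same shape as feat_views
        q = (joint_query + query_pos).transpose(0, 1)  # [num_queries, batch, d_model]
        k = (feat_views + ray_embed).transpose(0, 1)   # [V*H*W, batch, d_model]
        v = feat_views.transpose(0, 1)
        out, _ = self.cross_attn(q, k, v)
        return out.transpose(0, 1)                     # [batch, num_queries, d_model]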

liqikai9 commented 2 years ago

> Yes, you can use nn.MultiheadAttention. It is similar to the usage of cross attention in DETR, with the query being the joint query and the key and value being the multi-view image feature maps. The ray embedding should be kept.

Now my experiment setting is as follows.

  1. Here is my cross-attention layer:

    self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
  2. In the forward pass, the query is the output of the self-attention plus the query positional encoding. The value is just the CNN features, and the key is the CNN features concatenated with the ray embedding and then passed through the rayconv layer, which is an nn.Linear layer.
# the calculation of value: flatten and concatenate the multi-view CNN features
src_views_cat = torch.cat([src.flatten(2) for src in src_views], dim=-1).permute(0, 2, 1)
v = src_views_cat.flatten(0, 1).unsqueeze(0).expand(batch_size, -1, -1)

# the calculation of key: concatenate the CNN features with the ray embedding,
# then project through rayconv (an nn.Linear layer)
cam_cat = torch.cat([cam.flatten(1, 2) for cam in src_views_with_rayembed], dim=1)
input_flatten = torch.cat([src_views_cat, cam_cat], dim=-1)
memory = self.rayconv(input_flatten)
k = memory.flatten(0, 1).unsqueeze(0).expand(batch_size, -1, -1)

# the cross_attn I add
tgt2 = self.cross_attn(
    (tgt + query_pos).transpose(0, 1),
    k.transpose(0, 1),
    v.transpose(0, 1),
)[0].transpose(0, 1)

Do you think this setting is reasonable and able to reproduce the results you report in the paper?
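
For reference, here is a quick shape sanity check of the call above (all sizes are illustrative; nn.MultiheadAttention defaults to batch_first=False and therefore expects [seq_len, batch, embed_dim] inputs, which is why the transposes are needed):

import torch
import torch.nn as nn

# Shape sanity check only; all sizes below are illustrative.
batch_size, num_queries, d_model, n_heads = 2, 150, 256, 8
num_tokens = 5 * 24 * 32   # V views of an H x W feature map (illustrative)

cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=0.1)

tgt = torch.randn(batch_size, num_queries, d_model)        # output of self-attention
query_pos = torch.randn(batch_size, num_queries, d_model)  # query positional encoding
k = torch.randn(batch_size, num_tokens, d_model)           # rayconv(concat(features, ray embedding))
v = torch.randn(batch_size, num_tokens, d_model)           # flattened CNN features

# nn.MultiheadAttention expects [seq_len, batch, embed_dim], hence the transposes.
tgt2 = cross_attn((tgt + query_pos).transpose(0, 1),
                  k.transpose(0, 1),
                  v.transpose(0, 1))[0].transpose(0, 1)
print(tgt2.shape)  # torch.Size([2, 150, 256])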

liqikai9 commented 2 years ago

I have tried the above setting and didn't get results anywhere near an MPJPE of 114.5 mm. So I was wondering whether there are any fields in the config I should also modify to improve the training results. Thanks again; I am looking forward to your reply!

twangnh commented 2 years ago

I think your modification is correct. What performance did you get?

liqikai9 commented 2 years ago

Here is the result I got after 74 epochs:

2022-04-14 21:17:41,738 Test: [0/323]   Time: 239.402s (239.402s)   Speed: 0.0 samples/s    Data: 238.601s (238.601s)   Memory 455115776.0
2022-04-14 21:18:14,152 Test: [100/323] Time: 0.327s (2.691s)   Speed: 15.3 samples/s   Data: 0.000s (2.363s)   Memory 455115776.0
2022-04-14 21:18:46,653 Test: [200/323] Time: 0.320s (1.514s)   Speed: 15.6 samples/s   Data: 0.000s (1.187s)   Memory 455115776.0
2022-04-14 21:19:19,709 Test: [300/323] Time: 0.322s (1.121s)   Speed: 15.5 samples/s   Data: 0.000s (0.793s)   Memory 455115776.0
2022-04-14 21:19:26,819 Test: [322/323] Time: 0.320s (1.067s)   Speed: 15.6 samples/s   Data: 0.000s (0.739s)   Memory 455115776.0
2022-04-14 21:19:35,339 +--------------+-------+-------+-------+-------+-------+-------+
| Threshold/mm |   25  |   50  |   75  |  100  |  125  |  150  |
+--------------+-------+-------+-------+-------+-------+-------+
|      AP      |  0.00 |  0.00 |  0.03 |  0.60 |  2.21 |  4.09 |
|    Recall    |  0.00 |  0.02 |  0.86 |  5.17 | 11.30 | 16.82 |
| recall@500mm | 91.33 | 91.33 | 91.33 | 91.33 | 91.33 | 91.33 |
+--------------+-------+-------+-------+-------+-------+-------+
2022-04-14 21:19:35,340 MPJPE: 263.41mm

twangnh commented 2 years ago

An MPJPE of 260 mm means the performance is too low. The full-attention result was not obtained from the current setting; you could try tuning the learning rate, and also check the best results during the training process.
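
As a rough sketch of those two suggestions (the model, learning rates, schedule, and epoch count below are all placeholders, not from this repo), one could lower or schedule the learning rate and keep the best checkpoint seen during training rather than only the final one:

import torch
import torch.nn as nn

# Hedged sketch only: stand-in model, illustrative lr/schedule, dummy validation metric.
model = nn.Linear(256, 256)                                  # placeholder for the MvP model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)   # illustrative learning rate
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

best_mpjpe = float("inf")
for epoch in range(2):                                       # illustrative epoch count
    # ... run one training epoch here ...
    scheduler.step()
    mpjpe = 263.41 - epoch                                   # placeholder for the validation MPJPE
    if mpjpe < best_mpjpe:                                   # keep the best epoch, not just the last
        best_mpjpe = mpjpe
        torch.save(model.state_dict(), "model_best.pth.tar")
print(best_mpjpe)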

liqikai9 commented 2 years ago

The above is the best result I got. I tried setting the lr to 1e-5, but it didn't get better.

> The full-attention result was not obtained from the current setting

I also have the add, norm, and FFN after the cross_attn:

# scale the cross-attention output by the bounding term,
# then apply the residual connection and layer norm
tgt2 = (bounding[:, 0].unsqueeze(-1) *
        tgt2.view(batch_size, nbins, -1))
tgt3 = tgt + self.dropout1(tgt2)
tgt3 = self.norm1(tgt3)

# ffn
tgt3 = self.forward_ffn(tgt3)
# tgt3: [b, num_queries, embed_dims]
return tgt3

Actually, I am confused about the function of bounding in the code. Could you explain what it does?

liqikai9 commented 2 years ago

How did you get the result in Table 5 of your paper? Do you have any suggestions on what I should change to make my results approach yours?

[image: Table 5 from the paper]