Hi! Thanks a lot for developing this work! I've been playing around with it, but I still can't find the code for Equation 4 (where the messages are summed over for each node). I'd really appreciate it if you could help me find it, and apologies if I'm missing something very obvious.
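For reference, the aggregation I mean is the per-node message of Equation 4, which I would write (from memory, so the notation may differ slightly from the paper) as $m_{\mathcal{E} \to i} = \sum_{j:(i,j)\in\mathcal{E}} \alpha_{ij}\,\mathbf{v}_j$, i.e. a weighted sum of the value vectors for each node.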
I can see that when you call `attention(query, key, value)`, the resulting tensor has shape `(1, 64, 4, #keypoints)` (with a feature dimension of 256, which gives 4 heads of 64 values each). From my understanding, this computes the usual transformations needed for the attention weights, and indeed you seem to be multiplying all attention coefficients by all the values in each head.
However, after this step, I don't see the summation mentioned in Equation 4. Instead, you seem to flatten everything out for each keypoint when you reshape the result to `(batch_dim, self.dim*self.num_heads, -1)` (line 107), and then run a final MLP on top of it. So it looks to me like, rather than summing over all weighted values to get the final message per keypoint, you are running an MLP over the concatenated attention outputs to get the message per keypoint (in practice merging all heads, but not summing over all the weighted values).
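To make the shapes concrete, here is a small standalone sketch of the flow I have in mind. This is my own reconstruction for illustration, not code copied from the repo; the 256-dim / 4-head configuration, the keypoint count, and the 1x1-conv "MLP" are my assumptions about what happens after `attention` returns:

```python
import torch
import torch.nn as nn

batch_dim, num_heads, head_dim, num_kpts = 1, 4, 64, 500  # 500 keypoints is just an example

# What I understand attention(query, key, value) to return:
x = torch.randn(batch_dim, head_dim, num_heads, num_kpts)      # (1, 64, 4, #keypoints)

# The step I'm asking about: heads are concatenated per keypoint ...
x = x.contiguous().view(batch_dim, head_dim * num_heads, -1)   # (1, 256, #keypoints)

# ... and a final 1x1-conv "MLP" mixes the concatenated head outputs into the message.
merge = nn.Conv1d(head_dim * num_heads, head_dim * num_heads, kernel_size=1)
message = merge(x)
print(message.shape)                                           # torch.Size([1, 256, 500])
```

So as far as I can tell, the per-keypoint message comes out of this concatenation plus projection, and I can't pinpoint where the explicit sum over weighted values from Equation 4 happens.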
Any help/tip would really be appreciated, thanks! :)