wouterkool / attention-learn-to-route

Attention based model for learning to solve different routing problems
MIT License

Rewrite last part of MultiHeadAttention.forward #24

Closed ahmad-PH closed 3 years ago

ahmad-PH commented 3 years ago

It should be easier to understand now. I struggled to understand this part specifically because the matrix multiplication and summation from the article are compressed into a single matrix multiplication. Now they are two separate steps. I also tested the new MultiHeadAttention class with the following snippet to make sure nothing functional changed:

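import torch

# Assumes MultiHeadAttention (original) and MultiHeadAttentionRefactored
# (modified) are already defined or imported in the current scope.

batch_size = 3
n_query = 2
graph_size = 4
input_dim = 5

# Random node embeddings (h) and queries (q)
h = torch.normal(0, 0.5, [batch_size, graph_size, input_dim])
q = torch.normal(10., 1., [batch_size, n_query, input_dim])

# Seed before each constructor so both modules get identical initial weights
torch.manual_seed(0)
attn = MultiHeadAttention(n_heads=6, input_dim=input_dim, key_dim=5, embed_dim=18)

torch.manual_seed(0)
attn_ref = MultiHeadAttentionRefactored(n_heads=6, input_dim=input_dim, key_dim=5, embed_dim=18)

out = attn(q, h)
out_ref = attn_ref(q, h)

# Compare the two outputs up to floating-point tolerance
print(torch.allclose(out, out_ref))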

This prints True every time. (MultiHeadAttentionRefactored is the name I gave to the modified class.)
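
For readers without the repository at hand, here is a minimal standalone sketch of the two formulations being compared: one fused matrix multiplication over the concatenated heads versus a per-head projection followed by a sum over heads. The tensor names (heads, W_out), the shapes, and the random inputs are assumptions made for illustration; this is not the code from the pull request or from the repository.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, chosen to mirror the test snippet above)
n_heads, batch_size, n_query, val_dim, embed_dim = 6, 3, 2, 3, 18

# Random stand-ins for the per-head attention outputs and output projections
heads = torch.randn(n_heads, batch_size, n_query, val_dim)
W_out = torch.randn(n_heads, val_dim, embed_dim)

# Fused form: concatenate the heads along the feature axis, then one matmul
fused = torch.mm(
    heads.permute(1, 2, 0, 3).reshape(-1, n_heads * val_dim),
    W_out.view(-1, embed_dim),
).view(batch_size, n_query, embed_dim)

# Two-step form: project each head separately, then sum over the head axis
per_head = torch.matmul(heads, W_out[:, None])  # (n_heads, batch, n_query, embed_dim)
two_step = per_head.sum(dim=0)

# The two results should agree up to floating-point tolerance
outputs_match = torch.allclose(fused, two_step, atol=1e-4)
print(outputs_match)
```

The two forms compute the same sum because multiplying the concatenated heads by the stacked projection matrices is a block-wise product, which is exactly the sum of the per-head products.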

wouterkool commented 3 years ago

Thanks for this suggestion; this is indeed a bit clearer. However, from my quick experiments it seems that this implementation requires a bit more memory and is slightly slower. Therefore, I did not merge your request, but I have added a comment with your implementation for clarity. Thanks!
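
The experiments mentioned above are not shown in the thread. As a rough sketch of how such a comparison might be run on CPU, the following times both formulations and counts the elements of the intermediate tensor that the two-step form materializes. The helper names (fused_projection, two_step_projection, mean_seconds) and all sizes are hypothetical, and the timings will vary by machine.

```python
import time
import torch

def fused_projection(heads, W_out):
    # One matmul over the concatenated heads
    n_heads, batch_size, n_query, val_dim = heads.shape
    embed_dim = W_out.size(-1)
    return torch.mm(
        heads.permute(1, 2, 0, 3).reshape(-1, n_heads * val_dim),
        W_out.view(-1, embed_dim),
    ).view(batch_size, n_query, embed_dim)

def two_step_projection(heads, W_out):
    # Per-head projection, then a sum over the head axis
    return torch.matmul(heads, W_out[:, None]).sum(dim=0)

def mean_seconds(fn, *args, n_repeats=20):
    fn(*args)  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn(*args)
    return (time.perf_counter() - start) / n_repeats

# Illustrative sizes (assumptions, not taken from the repository)
n_heads, batch_size, n_query, val_dim, embed_dim = 8, 64, 20, 16, 128
heads = torch.randn(n_heads, batch_size, n_query, val_dim)
W_out = torch.randn(n_heads, val_dim, embed_dim)

t_fused = mean_seconds(fused_projection, heads, W_out)
t_two_step = mean_seconds(two_step_projection, heads, W_out)

# The two-step form builds an intermediate n_heads times larger than the output
out_numel = batch_size * n_query * embed_dim
intermediate_numel = torch.matmul(heads, W_out[:, None]).numel()

print(f"fused:    {t_fused * 1e6:.1f} us per call")
print(f"two-step: {t_two_step * 1e6:.1f} us per call")
print(f"intermediate / output elements: {intermediate_numel // out_numel}")
```

The element count gives one plausible reason for the extra memory: before the sum, the two-step form holds a tensor of shape (n_heads, batch, n_query, embed_dim). On a GPU, the timing loop would also need torch.cuda.synchronize() calls around the measured region to give meaningful numbers.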

ahmad-PH commented 3 years ago

No problem :) I'm glad it was useful. I hadn't thought of checking speed and memory myself (sorry!), so it's a good thing you did :). I will take that into account in future suggestions.