wouterkool / attention-learn-to-route

Attention based model for learning to solve different routing problems
MIT License

Manually specifying key_dim in MultiHeadAttention class causing size mismatch #23

Closed: ahmad-PH closed this issue 3 years ago

ahmad-PH commented 3 years ago

@wouterkool I think the title is self-explanatory. Here's a small example that reproduces the error (I was just testing with some random data to understand the forward method):

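import torch
from nets.graph_encoder import MultiHeadAttention  # class defined in nets/graph_encoder.py

batch_size = 3
n_query = 2
graph_size = 4
input_dim = 5

h = torch.normal(0, 0.5, [batch_size, graph_size, input_dim])
q = torch.normal(0, 0.5, [batch_size, n_query, input_dim])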

attn = MultiHeadAttention(n_heads=6, input_dim=input_dim, embed_dim=12, key_dim=10)

out = attn.forward(q, h)

This will cause the following error:

Traceback (most recent call last):
  File "attention.py", line 123, in <module>
    out = attn.forward(q, h)
  File "attention.py", line 104, in forward
    self.W_out.view(-1, self.embed_dim)
RuntimeError: size mismatch, m1: [6 x 12], m2: [30 x 12] at /opt/conda/conda-bld/pytorch_1591914855613/work/aten/src/TH/generic/THTensorMath.cpp:41

In other words, torch.mm in the output projection is receiving operands with incompatible shapes:

        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

Changing the constructor call to pass key_dim=None makes the error go away.
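The shape arithmetic behind the error can be sketched in plain Python. This assumes the constructor's defaults are val_dim = embed_dim // n_heads and key_dim = val_dim (they are not shown in this thread), and the helper names out_projection_shapes and mm_compatible are hypothetical. The first operand always has n_heads * val_dim columns, while W_out.view(-1, self.embed_dim) has n_heads * key_dim rows, so torch.mm only accepts them when key_dim equals val_dim.

```python
def out_projection_shapes(batch_size, n_query, n_heads, embed_dim,
                          key_dim=None, val_dim=None, fixed=False):
    """Return the shapes of the two torch.mm operands in the output projection.

    Assumed defaults (mirroring the MultiHeadAttention constructor):
    val_dim = embed_dim // n_heads, key_dim = val_dim.
    fixed=False sizes W_out by key_dim (the reported bug);
    fixed=True sizes it by val_dim (the proposed fix).
    """
    if val_dim is None:
        val_dim = embed_dim // n_heads
    if key_dim is None:
        key_dim = val_dim
    # heads: (n_heads, batch_size, n_query, val_dim), permuted and flattened
    m1 = (batch_size * n_query, n_heads * val_dim)
    # W_out: (n_heads, inner, embed_dim) viewed as (-1, embed_dim)
    inner = val_dim if fixed else key_dim
    m2 = (n_heads * inner, embed_dim)
    return m1, m2


def mm_compatible(m1, m2):
    # torch.mm(a, b) requires a.shape[1] == b.shape[0]
    return m1[1] == m2[0]


# Arguments from the reproduction above
for kwargs in ({"key_dim": 10}, {}, {"key_dim": 10, "fixed": True}):
    m1, m2 = out_projection_shapes(3, 2, n_heads=6, embed_dim=12, **kwargs)
    print(kwargs, m1, m2, mm_compatible(m1, m2))
```

With the reproduction's arguments the first operand comes out as 6 x 12, matching m1 in the traceback. Leaving key_dim=None, or sizing W_out by val_dim, makes the inner dimensions agree.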

ahmad-PH commented 3 years ago

I think changing lines 46-47 of graph_encoder.py should fix the issue: key_dim should be replaced by val_dim:

if embed_dim is not None:
    self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))

->

if embed_dim is not None:
    self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))

This is because W_out is applied to the concatenated outputs of the heads, which have val_dim entries per head, so it converts val_dim to embed_dim, not key_dim to embed_dim. Specifying key_dim manually makes key_dim and val_dim differ, which causes the size mismatch.
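A minimal NumPy sketch of the forward pass illustrates the fix. It is not the repository's code: the function name multi_head_attention is hypothetical, it assumes standard scaled dot-product attention without masking, and it assumes the parameter layout W_query/W_key of shape (n_heads, input_dim, key_dim) and W_val of shape (n_heads, input_dim, val_dim). With W_out sized by val_dim, the projection works even when key_dim (10) differs from val_dim (2).

```python
import numpy as np


def multi_head_attention(q, h, W_query, W_key, W_val, W_out):
    """Scaled dot-product multi-head attention (illustrative sketch).

    q: (batch, n_query, input_dim), h: (batch, graph_size, input_dim)
    W_query, W_key: (n_heads, input_dim, key_dim)
    W_val: (n_heads, input_dim, val_dim)
    W_out: (n_heads, val_dim, embed_dim)  <- val_dim, not key_dim
    """
    n_heads, _, key_dim = W_query.shape
    val_dim = W_val.shape[2]
    embed_dim = W_out.shape[2]
    batch_size, n_query, _ = q.shape

    # Per-head projections; broadcasting gives (n_heads, batch, n, dim)
    Q = np.matmul(q[None], W_query[:, None])
    K = np.matmul(h[None], W_key[:, None])
    V = np.matmul(h[None], W_val[:, None])

    # Softmax over graph nodes: (n_heads, batch, n_query, graph_size)
    compat = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(key_dim)
    compat -= compat.max(axis=-1, keepdims=True)
    attn = np.exp(compat)
    attn /= attn.sum(axis=-1, keepdims=True)

    heads = np.matmul(attn, V)  # (n_heads, batch, n_query, val_dim)

    # Same flattening as the torch.mm call: (batch * n_query, n_heads * val_dim)
    flat = heads.transpose(1, 2, 0, 3).reshape(-1, n_heads * val_dim)
    out = flat @ W_out.reshape(-1, embed_dim)  # raises if W_out is sized by key_dim
    return out.reshape(batch_size, n_query, embed_dim)


rng = np.random.default_rng(0)
batch_size, n_query, graph_size, input_dim = 3, 2, 4, 5
n_heads, embed_dim, key_dim = 6, 12, 10
val_dim = embed_dim // n_heads  # 2, differs from key_dim

q = rng.normal(0, 0.5, (batch_size, n_query, input_dim))
h = rng.normal(0, 0.5, (batch_size, graph_size, input_dim))
W_query = rng.normal(size=(n_heads, input_dim, key_dim))
W_key = rng.normal(size=(n_heads, input_dim, key_dim))
W_val = rng.normal(size=(n_heads, input_dim, val_dim))
W_out = rng.normal(size=(n_heads, val_dim, embed_dim))

out = multi_head_attention(q, h, W_query, W_key, W_val, W_out)
print(out.shape)  # (3, 2, 12)
```

Passing a W_out of shape (n_heads, key_dim, embed_dim) instead makes the final matrix product fail with a shape error, the NumPy counterpart of the torch.mm size mismatch above.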

wouterkool commented 3 years ago

Thanks, you are right, I have fixed this!