Closed ahmad-PH closed 3 years ago
I feel like changing lines 46-47 of `graph_encoder.py` should help fix the issue: `key_dim` should be changed to `val_dim`:
```python
if embed_dim is not None:
    self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
```

->

```python
if embed_dim is not None:
    self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))
```
because by applying `W_out` we are trying to convert `val_dim` to `embed_dim`, not `key_dim` to `embed_dim`. (Specifying `key_dim` manually makes `key_dim` and `val_dim` different, thus causing the issue.)
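The reasoning above can be sketched in a few lines. This is a minimal single-head attention sketch, using numpy in place of torch and hypothetical sizes, showing that `key_dim` is contracted away in the attention scores and the per-head output lives in `val_dim`, so `W_out` must map `val_dim` to `embed_dim`:

```python
import numpy as np

batch, key_dim, val_dim, embed_dim = 4, 32, 16, 128  # hypothetical sizes

Q = np.random.randn(batch, key_dim)
K = np.random.randn(batch, key_dim)
V = np.random.randn(batch, val_dim)

# Attention scores: (batch, batch). key_dim is the contracted dimension,
# so it never appears in the output shape.
scores = Q @ K.T / np.sqrt(key_dim)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

# Per-head output is a weighted sum of the VALUE vectors: (batch, val_dim).
head_out = weights @ V
assert head_out.shape == (batch, val_dim)

# Hence the output projection must have shape (val_dim, embed_dim),
# not (key_dim, embed_dim).
W_out = np.random.randn(val_dim, embed_dim)
out = head_out @ W_out
print(out.shape)  # (4, 128)
```

As long as `key_dim == val_dim` (the default, `embed_dim // n_heads`), the bug is invisible; it only surfaces when `key_dim` is passed explicitly.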
Thanks, you are right, I have fixed this!
@wouterkool I think the topic is self-explanatory. Here's a small example that reproduces the error (I was just testing with some random data to understand the `forward` function):

This will cause the following error:

which is basically saying that the `torch.mm` function is receiving arguments with bad sizes. Changing the constructor arguments and passing `key_dim=None` makes the error go away.
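The original snippet and error output did not survive above. As a rough, self-contained stand-in (numpy in place of torch, hypothetical sizes) for the kind of size mismatch the error reports: the per-head attention output has last dimension `val_dim`, so multiplying it by a `W_out` allocated with `key_dim` fails whenever the two differ.

```python
import numpy as np

n_heads, embed_dim = 8, 128
key_dim = 32                     # manually specified, differs from below
val_dim = embed_dim // n_heads   # 16, the constructor's default

# Per-head attention output: (n_heads, batch, val_dim).
heads = np.random.randn(n_heads, 4, val_dim)

# Buggy allocation: W_out built with key_dim instead of val_dim.
W_out_buggy = np.random.randn(n_heads, key_dim, embed_dim)
try:
    np.matmul(heads, W_out_buggy)  # analogous to the failing torch.mm
except ValueError as e:
    print("shape mismatch:", e)

# Fixed allocation: W_out built with val_dim.
W_out_fixed = np.random.randn(n_heads, val_dim, embed_dim)
out = np.matmul(heads, W_out_fixed)
print(out.shape)  # (8, 4, 128)
```

With `key_dim=None` the constructor presumably falls back to `key_dim = val_dim`, which is why the error disappears without the fix.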