ganesh3 closed this issue 3 years ago.
Thank you for your interest!
The problem is in this line:
weight = self.weight_shared(self.hidden_dim, 1, 1)
You forgot to call the repeat method. The correct version is:
weight = self.weight_shared.repeat(self.hidden_dim, 1, 1)
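To illustrate what the corrected line produces, here is a small sketch. The sizes k=2, kernel_size=3 and hidden_dim=4 are arbitrary illustration values, not taken from the video. repeat tiles the shared (k, 1, kernel_size) weight hidden_dim times along the first dimension, so every channel gets its own copy of the same k filters.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed, not from the video)
k, kernel_size, hidden_dim = 2, 3, 4

weight_shared = nn.Parameter(torch.rand(k, 1, kernel_size))

# Tile the k shared filters once per channel along dim 0
weight = weight_shared.repeat(hidden_dim, 1, 1)
print(weight.shape)  # torch.Size([8, 1, 3])

# Row c * k + j of the tiled weight is shared filter j
for c in range(hidden_dim):
    for j in range(k):
        assert torch.equal(weight[c * k + j], weight_shared[j])

# repeat is differentiable: each shared entry appears hidden_dim times,
# so the gradient of the sum is hidden_dim for every entry
weight.sum().backward()
print(weight_shared.grad)  # every entry is 4.0 (= hidden_dim)
```

Because repeat stays in the autograd graph, gradients from all channels accumulate into the single shared parameter, which is what makes the weights genuinely shared during training.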
For reference, here is a working snippet similar to the one in the video:
import torch
import torch.nn as nn


class Conv1dDepthwiseShared(nn.Module):
    def __init__(self, hidden_dim, kernel_size, k=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # k filters shared by all channels: shape (k, 1, kernel_size)
        self.weight_shared = nn.Parameter(
            torch.rand(
                k,
                1,
                kernel_size,
            )
        )
        self.bias_shared = nn.Parameter(torch.rand(k))

    def forward(self, x):
        # Tile the shared filters once per channel: (hidden_dim * k, 1, kernel_size)
        weight = self.weight_shared.repeat(self.hidden_dim, 1, 1)
        bias = self.bias_shared.repeat(self.hidden_dim)
        # groups=hidden_dim: each channel is convolved only with its own k filters
        res = torch.nn.functional.conv1d(
            x,
            weight=weight,
            bias=bias,
            groups=self.hidden_dim,
        )
        return res


if __name__ == "__main__":
    n_samples, hidden_dim, n_patches = 2, 16, 25
    k = 7
    x = torch.randn(n_samples, hidden_dim, n_patches)
    module_conv = Conv1dDepthwiseShared(hidden_dim, n_patches, k)
    module_linear = nn.Linear(n_patches, k)

    # Both modules have the same number of trainable parameters
    print(sum(p.numel() for p in module_conv.parameters() if p.requires_grad))
    print(sum(p.numel() for p in module_linear.parameters() if p.requires_grad))

    # Outputs differ here because the two modules are initialized independently
    out_conv = module_conv(x).reshape(n_samples, hidden_dim, k)
    out_linear = module_linear(x)

    # Manually set the same weights
    module_conv.weight_shared.data[:, 0, :] = module_linear.weight.data
    module_conv.bias_shared.data[:] = module_linear.bias.data

    out_conv = module_conv(x).reshape(n_samples, hidden_dim, k)
    out_linear = module_linear(x)
    print(torch.allclose(out_conv, out_linear, atol=1e-6, rtol=0))
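Why the two modules agree: with groups=hidden_dim, every channel c is convolved with the same k filters, and because the kernel spans all n_patches positions the output length is 1. Output [n, c, j] is therefore sum_p W[j, p] * x[n, c, p] + b[j], which is exactly a linear layer applied over the last dimension. The following is a minimal sketch making that explicit with torch.einsum; the names W and b are illustrative stand-ins for the shared weight and bias, filled with random values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples, hidden_dim, n_patches, k = 2, 16, 25, 7
x = torch.randn(n_samples, hidden_dim, n_patches)
W = torch.randn(k, n_patches)  # illustrative shared filters, one row per output feature
b = torch.randn(k)             # illustrative shared bias

# Grouped conv with the k filters tiled once per channel
out_conv = F.conv1d(
    x,
    weight=W.unsqueeze(1).repeat(hidden_dim, 1, 1),  # (hidden_dim * k, 1, n_patches)
    bias=b.repeat(hidden_dim),
    groups=hidden_dim,
).reshape(n_samples, hidden_dim, k)

# Same computation written as an explicit sum over patch positions p
out_einsum = torch.einsum("ncp,jp->ncj", x, W) + b

# And as a plain linear map over the last dimension
out_linear = F.linear(x, W, b)

print(torch.allclose(out_conv, out_einsum, atol=1e-5))
print(torch.allclose(out_conv, out_linear, atol=1e-5))
```

The reshape to (n_samples, hidden_dim, k) works because grouped convolution places the k outputs of channel c at output channels c * k to c * k + k - 1, matching the order produced by repeat.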
Thanks, it's resolved. I would have expected a tensor size mismatch error, though. Why does it say the object is not callable?
Great!
Well, you are taking self.weight_shared, which is of type torch.nn.Parameter, and calling it (self.weight_shared(...)). However, Parameter does not implement the __call__ method, and that is why you get the error:
TypeError: 'Parameter' object is not callable
Python fails at the call itself, before any tensor operation runs, so a size mismatch never gets a chance to be reported.
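A short sketch reproducing the error, assuming a parameter of the same shape as in the snippet above (k=7, kernel_size=25). A Parameter is a Tensor subclass and defines no __call__, whereas nn.Module defines __call__ (which dispatches to forward), so modules can be called and parameters cannot.

```python
import torch
import torch.nn as nn

weight_shared = nn.Parameter(torch.rand(7, 1, 25))

# Parameter (a Tensor subclass) defines no __call__, so it is not callable
print(callable(weight_shared))     # False
# nn.Module defines __call__, which dispatches to forward()
print(callable(nn.Linear(25, 7)))  # True

try:
    weight_shared(16, 1, 1)  # the buggy line: calls the Parameter itself
except TypeError as err:
    message = str(err)
    print(message)  # 'Parameter' object is not callable
```

The arguments (16, 1, 1) are never interpreted as shapes, because Python raises the TypeError as soon as it tries to invoke the object.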
Ok. Thanks.
On 2021-10-16 10:48 GMT-05:00, ganesh3 wrote:
Closed #2.
View it on GitHub: https://github.com/jankrepl/mildlyoverfitted/issues/2#event-5473212158
Hi,
Thanks for the video on MLP-Mixer using PyTorch and Flax. I am getting an error while running the code. Here is the code I wrote and the error I am getting: