szymonmaszke / torchlayers

Shape and dimension inference (Keras-like) for PyTorch layers and neural networks
MIT License

Runtime Error size mismatch #8

m-rph closed this issue 4 years ago

m-rph commented 4 years ago

Hi

I am trying to build the following:

import torch
import torch.nn as nn
import torchlayers as tl

c = nn.Sequential(
    tl.Conv2d(32, kernel_size=8, stride=4),
    nn.ReLU(),
    tl.Conv2d(64, kernel_size=4, stride=2),
    nn.ReLU(),
    tl.Conv2d(64, kernel_size=3, stride=1),
    nn.ReLU(),
    tl.Linear(512),
    nn.ReLU(),
)

tl.build(c, torch.randn(1, 3, 84, 84))

but I am getting

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torchlayers/__init__.py", line 67, in build
    module(*args, **kwargs)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torchlayers/_dev_utils/infer.py", line 214, in forward
    return infered_module(inputs, *args, **kwargs)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/stelios/anaconda3/envs/HM/lib/python3.7/site-packages/torch/nn/functional.py", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [448 x 7], m2: [64 x 512] at /tmp/pip-req-build-7mav6f4d/aten/src/TH/generic/THTensorMath.cpp:197

It works just fine with a flatten layer right before the linear one.
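The sizes in the error message can be reproduced by hand. The sketch below, in plain Python, applies the output-size formula from the torch.nn.Conv2d documentation to the three convolutions. It assumes padding=0, chosen because it yields the 7x7 spatial size implied by the traceback; torchlayers' own padding default is not assumed. It then shows how F.linear, which treats every dimension except the last as batch, ends up with a [448 x 7] matrix, while a weight whose in_features was taken from the channel dimension appears transposed as [64 x 512]. The helper `conv_out` is hypothetical, written only for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0, dilation: int = 1) -> int:
    # Hypothetical helper: output length formula from the torch.nn.Conv2d docs
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Spatial size after each convolution from the issue (padding=0 assumed)
size = 84
sizes = []
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel, stride)
    sizes.append(size)
print(sizes)  # [20, 9, 7]

# Output of the last conv: [batch, channels, height, width]
shape = (1, 64, size, size)

# F.linear treats all but the last dimension as batch: rows = 1 * 64 * 7
rows = shape[0] * shape[1] * shape[2]
cols = shape[3]
print(f"m1: [{rows} x {cols}]")  # m1: [448 x 7]

# A weight built with in_features = shape[1] (channels) shows up as weight.t()
in_features_inferred = shape[1]
print(f"m2: [{in_features_inferred} x 512]")  # m2: [64 x 512]
```

The mismatch is between the 7 columns of the input and the 64 rows of the transposed weight.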

szymonmaszke commented 4 years ago

tl.Conv2d returns a 4D tensor of shape [batch, channels, height, width]; after flattening it becomes [batch, channels * height * width], which is standard practice.
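A minimal sketch of those shapes, using plain torch.nn.Conv2d layers as stand-ins for the tl.Conv2d layers from the issue. The in_channels values are written by hand and padding is left at PyTorch's default of 0, which is an assumption that matches the 7x7 size in the traceback.

```python
import torch
import torch.nn as nn

# Plain-PyTorch stand-in for the three tl.Conv2d layers (padding=0 assumed)
convs = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
)

x = torch.randn(1, 3, 84, 84)
features = convs(x)
print(features.shape)  # torch.Size([1, 64, 7, 7])

# Flatten everything except the batch dimension: 64 * 7 * 7 = 3136
flat = torch.flatten(features, start_dim=1)
print(flat.shape)  # torch.Size([1, 3136])
```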

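This is what goes into torch.nn.Linear. If you flatten, everything is inferred correctly (in_features will be channels * height * width). Please beware of images of a different size: this will not work, because the spatial dimensions after the convolutions (shrunk here by strides different than 1), and hence the flattened size, change with the input size and no longer match the in_features inferred during build.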
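The working variant can be sketched in plain PyTorch with in_features written out as the value inference should produce for 84x84 inputs (64 * 7 * 7, again assuming padding=0). Feeding a 100x100 image afterwards gives a 9x9 spatial output, so the flattened size becomes 64 * 9 * 9 = 5184 and the Linear layer rejects it.

```python
import torch
import torch.nn as nn

# Hand-written equivalent of the working model: Flatten right before Linear.
# in_features = channels * height * width = 64 * 7 * 7 for 84x84 inputs.
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
)

out = model(torch.randn(1, 3, 84, 84))
print(out.shape)  # torch.Size([1, 512])

# A differently sized image changes the flattened size, so Linear fails
try:
    model(torch.randn(1, 3, 100, 100))
    mismatch = False
except RuntimeError:
    mismatch = True
print(mismatch)  # True
```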

torch.nn.Linear accepts inputs with any number of dimensions, though (it only transforms the last one), so torchlayers should infer in_features from the last dimension instead of the first one after batch (the channels), as it currently does. This should be resolved soon; I will close this issue when the fix lands in nightly.
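A small sketch of that behaviour. torch.nn.Linear expects input of shape (*, in_features), so a [1, 64, 7, 7] tensor works with in_features=7. The helper `infer_linear` is hypothetical and is not torchlayers API; it only illustrates inferring in_features from the last dimension. Note that with such inference the original model (no flatten) would build, but would output [1, 64, 7, 512] rather than [1, 512], so a flatten is still needed for the usual fully connected head.

```python
import torch
import torch.nn as nn

# torch.nn.Linear treats every leading dimension as batch and
# transforms only the last one.
x = torch.randn(1, 64, 7, 7)
layer = nn.Linear(7, 512)
print(layer(x).shape)  # torch.Size([1, 64, 7, 512])

# Hypothetical helper (not torchlayers API): in_features from the last dim
def infer_linear(out_features: int, sample: torch.Tensor) -> nn.Linear:
    return nn.Linear(sample.shape[-1], out_features)

inferred = infer_linear(512, x)
print(inferred.in_features)  # 7

# Taking in_features from shape[1] (channels) reproduces the reported error
wrong = nn.Linear(x.shape[1], 512)
try:
    wrong(x)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```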