f-dangel / unfoldNd

(N=1,2,3)-dimensional unfold (im2col) and fold (col2im) in PyTorch
MIT License

outputs are different from conv3d #24

charlielam0615 closed this issue 3 years ago

charlielam0615 commented 3 years ago

Thanks for sharing this library! I tried to extend the unfold-based version of conv2d to conv3d, but the difference between the outputs is small yet not negligible. Any clue why this happens?

import torch

import unfoldNd

torch.manual_seed(0)

inputs = torch.randn(20, 16, 10, 50, 100)
w = torch.randn(33, 16, 3, 5, 2)
stride = (1, 1, 1)
padding = (0, 0, 0)

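# Unfold the input into patches: shape (N, C_in * prod(kernel_size), L)
inp_unf = unfoldNd.unfoldNd(inputs, w.shape[2:], stride=stride, padding=padding)
# Multiply the patches with the flattened kernel: shape (N, C_out, L)
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
# Reshape L back into the spatial output dimensions (8, 46, 99)
lib_outputs = out_unf.view(20, 33, 8, 46, 99)
# Reference result from PyTorch's built-in conv3d
torch_outputs = torch.nn.functional.conv3d(inputs, w)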

print((lib_outputs-torch_outputs).abs().max())

if torch.allclose(torch_outputs, lib_outputs):
    print("✔ Outputs of conv3d and unfoldNd.UnfoldNd match.")
else:
    raise AssertionError("❌ Outputs don't match")

Outputs:

tensor(0.0001)
Traceback (most recent call last):
  File "test_unfoldnd.py", line 54, in <module>
    raise AssertionError("❌ Outputs don't match")
AssertionError: ❌ Outputs don't match

f-dangel commented 3 years ago

Hey,

thanks for your report. My first guess for those discrepancies would be floating-point precision. Could you report the relative mismatch between the differing values by adding something like

for lib_out, torch_out in zip(lib_outputs.flatten(), torch_outputs.flatten()):
    if not torch.allclose(lib_out, torch_out):
        print(lib_out, torch_out)

to your example?
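The loop above visits every element in Python, which can be slow for large outputs. A vectorized variant might look like the sketch below. The helper name `mismatch_report` and the stand-in tensors `a` and `b` are hypothetical and only for illustration; in the example from this thread, `lib_outputs` and `torch_outputs` would take their place.

```python
import torch


def mismatch_report(a, b, rtol=1e-5, atol=1e-8):
    """Summarize how `a` deviates from `b` under the torch.allclose criterion."""
    abs_err = (a - b).abs()
    # torch.allclose checks |a - b| <= atol + rtol * |b| elementwise
    failing = abs_err > atol + rtol * b.abs()
    # Guard against division by zero when computing the relative error
    rel_err = abs_err / b.abs().clamp_min(1e-12)
    return {
        "max_abs_err": abs_err.max().item(),
        "max_rel_err": rel_err.max().item(),
        "num_failing": int(failing.sum().item()),
    }


# Stand-in data (assumption): `a` is `b` with a relative perturbation of about 1e-4
torch.manual_seed(0)
b = torch.randn(1000)
a = b * (1 + 1e-4)

report = mismatch_report(a, b)
print(report)
```

The maximum relative error is usually the more informative number here, because it shows how far `rtol` would have to be raised for the comparison to pass.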

charlielam0615 commented 3 years ago

It looks like your guess is right. Here's part of the output.

...
tensor(-0.7282) tensor(-0.7282)
tensor(0.5448) tensor(0.5448)
tensor(-0.2942) tensor(-0.2942)
tensor(0.0671) tensor(0.0671)
tensor(0.0840) tensor(0.0840)
tensor(-1.2255) tensor(-1.2254)
tensor(0.5453) tensor(0.5453)
tensor(0.1272) tensor(0.1272)
tensor(0.0616) tensor(0.0616)
tensor(0.0478) tensor(0.0478)
tensor(-0.5128) tensor(-0.5128)
tensor(0.0550) tensor(0.0550)
tensor(-0.8851) tensor(-0.8851)
tensor(0.5813) tensor(0.5813)
tensor(-0.1910) tensor(-0.1910)
tensor(-0.4936) tensor(-0.4936)
tensor(-0.0498) tensor(-0.0498)
...

Most pairs look identical because of the print precision setting. It is strange, though, that a floating-point precision discrepancy could lead to this large a difference (max difference 0.0001).
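To make such differences visible, the print precision can be raised with `torch.set_printoptions`, or the values can be converted to Python floats with `.item()`. The following is a minimal sketch; the two values are made up for illustration and are not taken from the output above.

```python
import torch

# Hypothetical values that differ only beyond the fourth decimal place
x = torch.tensor(0.54481)
y = torch.tensor(0.54483)

# The default print precision is 4 decimals, so both render the same
default_view = (str(x), str(y))
print(*default_view)

# Raise the precision to expose the difference
torch.set_printoptions(precision=8)
precise_view = (str(x), str(y))
print(*precise_view)

# .item() returns a Python float, which prints with full precision
print(x.item(), y.item())

# Restore the default print settings
torch.set_printoptions(profile="default")
```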

f-dangel commented 3 years ago

I have run into this behavior several times. One fix is to increase the tolerances in torch.allclose, e.g. using rtol=5e-05, atol=1e-07 instead of the defaults (rtol=1e-05, atol=1e-08).
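As a rough illustration of why two float32 computations of the same quantity can drift apart, the sketch below sums 480 products per output (16 * 3 * 5 * 2, the number of terms behind each conv3d output in the example from this thread) in two different accumulation orders. It does not use unfoldNd or conv3d; the data, the chunk size, and the tolerance values are assumptions chosen for illustration. Repeating the computation in float64 is one way to check whether a mismatch comes from rounding rather than from a logic error.

```python
import torch

torch.manual_seed(0)
x = torch.randn(10000, 480)
w = torch.randn(480)


def two_orders(x, w, chunk=16):
    """Compute the same dot products with two different accumulation orders."""
    out_matmul = x @ w
    # Sum the products chunk by chunk, then add up the partial sums
    out_chunked = sum(
        (x[:, i:i + chunk] * w[i:i + chunk]).sum(dim=1)
        for i in range(0, x.shape[1], chunk)
    )
    return out_matmul, out_chunked


# float32: rounding errors depend on the order of accumulation
a32, b32 = two_orders(x, w)
err32 = (a32 - b32).abs().max().item()

# float64: the same computation with far smaller rounding errors
a64, b64 = two_orders(x.double(), w.double())
err64 = (a64 - b64).abs().max().item()

print(f"max abs difference in float32: {err32:.3e}")
print(f"max abs difference in float64: {err64:.3e}")

# Default tolerances versus looser, illustrative ones
print(torch.allclose(a32, b32))
print(torch.allclose(a32, b32, rtol=1e-4, atol=1e-4))
```

If the float64 difference is many orders of magnitude smaller than the float32 one, the mismatch is most likely due to floating-point rounding, and loosening the tolerances is a reasonable response.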