pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Pytorch-Caffe2 export: "Arrays are not almost equal to 3 decimals" #13836

Open zerotosingularity opened 6 years ago

zerotosingularity commented 6 years ago

📚 Documentation

Following the "Transfering a Model from PyTorch to Caffe2 and Mobile using ONNX" tutorial, I get the following exception:

Arrays are not almost equal to 3 decimals

when running

np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)

(The documentation specifically asks readers to contact the team when this happens. :))
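To make a report like this easier to act on, it can help to say how far apart the two outputs are, not only that the assertion failed. The sketch below is a minimal NumPy-only helper. The name `report_mismatch` and the sample values are hypothetical, and the tolerance formula is the criterion given in the NumPy documentation for `assert_almost_equal`, namely `abs(desired - actual) < 1.5 * 10**(-decimal)`.

```python
import numpy as np


def report_mismatch(expected, actual, decimal=3):
    """Summarise how far two arrays are from passing assert_almost_equal.

    Hypothetical helper for illustration; not part of the tutorial.
    """
    expected = np.asarray(expected, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    abs_diff = np.abs(expected - actual)
    # Documented criterion: abs(desired - actual) < 1.5 * 10**(-decimal)
    tolerance = 1.5 * 10.0 ** (-decimal)
    failing = abs_diff >= tolerance
    return {
        "tolerance": tolerance,
        "max_abs_diff": float(abs_diff.max()),
        "num_failing": int(failing.sum()),
        "fraction_failing": float(failing.mean()),
    }


# Made-up sample values; in the tutorial these would be
# torch_out.data.cpu().numpy() and c2_out.
summary = report_mismatch([1.0, 2.0, 3.0], [1.0, 2.0005, 3.5])
print(summary)
```

A maximum difference just above the tolerance would point to numerical noise, while a large difference across most elements would suggest that an operator is being translated differently.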

I have been able to reproduce the error on both Ubuntu 16.04 and Google Colab. Note that I did not install onnx-caffe2; instead, I updated the source code following PR-348.

This is the environment on Ubuntu 16.04:

Conda env:

# packages in environment at ~/.conda/envs/torch-onnx:
#
# Name                    Version                   Build  Channel
backcall                  0.1.0                    py36_0
blas                      1.0                         mkl
bleach                    3.0.2                    py36_0
bzip2                     1.0.6                h14c3975_5
ca-certificates           2018.03.07                    0
certifi                   2018.10.15               py36_0
cffi                      1.11.5           py36he75722e_1
cmake                     3.12.2               h52cb24c_0
cycler                    0.10.0                   py36_0
dbus                      1.13.2               h714fa37_1
decorator                 4.3.0                    py36_0
entrypoints               0.2.3                    py36_2
expat                     2.2.6                he6710b0_0
fontconfig                2.13.0               h9420a91_0
freetype                  2.9.1                h8a8886c_1
future                    0.17.1                    <pip>
glib                      2.56.2               hd408876_0
gmp                       6.1.2                h6c8ec71_1
gst-plugins-base          1.14.0               hbbd80ab_1
gstreamer                 1.14.0               hb453b48_1
icu                       58.2                 h9c2bf20_1
intel-openmp              2019.0                      118
ipykernel                 5.1.0            py36h39e3cac_0
ipython                   7.1.1            py36h39e3cac_0
ipython_genutils          0.2.0                    py36_0
ipywidgets                7.4.2                    py36_0
jedi                      0.13.1                   py36_0
jinja2                    2.10                     py36_0
jpeg                      9b                   h024ee3a_2
jsonschema                2.6.0                    py36_0
jupyter                   1.0.0                    py36_7
jupyter_client            5.2.3                    py36_0
jupyter_console           6.0.0                    py36_0
jupyter_core              4.4.0                    py36_0
kiwisolver                1.0.1            py36hf484d3e_0
libcurl                   7.61.1               heec0ca6_0
libedit                   3.1.20170329         h6b74fdf_2
libffi                    3.2.1                hd88cf55_4
libgcc-ng                 8.2.0                hdf63c60_1
libgfortran-ng            7.3.0                hdf63c60_0
libpng                    1.6.35               hbc83047_0
libprotobuf               3.6.1                hd408876_0
libsodium                 1.0.16               h1bed415_0
libssh2                   1.8.0                h9cfc8f7_4
libstdcxx-ng              8.2.0                hdf63c60_1
libuuid                   1.0.3                h1bed415_2
libxcb                    1.13                 h1bed415_1
libxml2                   2.9.8                h26e45fe_1
magma-cuda90              2.3.0                         1    pytorch
markupsafe                1.0              py36h14c3975_1
matplotlib                3.0.1            py36h5429711_0
mistune                   0.8.4            py36h7b6447c_0
mkl                       2019.0                      118
mkl-include               2019.0                      118
mkl_fft                   1.0.6            py36h7dd41cf_0
mkl_random                1.0.1            py36h4414c95_1
mkldnn                    0.16.1                        0    mingfeima
nbconvert                 5.3.1                    py36_0
nbformat                  4.4.0                    py36_0
ncurses                   6.1                  hf484d3e_0
ninja                     1.8.2            py36h6bb024c_1
notebook                  5.7.0                    py36_0
numpy                     1.15.4           py36h1d66e8a_0
numpy-base                1.15.4           py36h81de0dd_0
onnx                      1.3.0                     <pip>
openssl                   1.0.2p               h14c3975_0
pandoc                    2.2.3.2                       0
pandocfilters             1.4.2                    py36_1
parso                     0.3.1                    py36_0
pcre                      8.42                 h439df22_0
pexpect                   4.6.0                    py36_0
pickleshare               0.7.5                    py36_0
pip                       18.1                     py36_0
prometheus_client         0.4.2                    py36_0
prompt_toolkit            2.0.7                    py36_0
protobuf                  3.6.1            py36he6710b0_0
ptyprocess                0.6.0                    py36_0
pycparser                 2.19                     py36_0
pygments                  2.2.0                    py36_0
pyparsing                 2.3.0                    py36_0
pyqt                      5.9.2            py36h05f1152_2
python                    3.6.5                hc3d631a_2
python-dateutil           2.7.5                    py36_0
pytorch-nightly           1.0.0.dev20181109 py3.6_cuda9.0.176_cudnn7.1.2_0    pytorch
pytz                      2018.7                   py36_0
pyyaml                    3.13             py36h14c3975_0
pyzmq                     17.1.2           py36h14c3975_0
qt                        5.9.6                h8703b6f_2
qtconsole                 4.4.2                    py36_0
readline                  7.0                  h7b6447c_5
rhash                     1.3.6                hb7f436b_0
send2trash                1.5.0                    py36_0
setuptools                40.5.0                   py36_0
sip                       4.19.8           py36hf484d3e_0
six                       1.11.0                   py36_1
sqlite                    3.25.2               h7b6447c_0
terminado                 0.8.1                    py36_1
testpath                  0.4.2                    py36_0
tk                        8.6.8                hbc83047_0
tornado                   5.1.1            py36h7b6447c_0
traitlets                 4.3.2                    py36_0
typing                    3.6.4                    py36_0
typing-extensions         3.6.6                     <pip>
wcwidth                   0.1.7                    py36_0
webencodings              0.5.1                    py36_1
wheel                     0.32.2                   py36_0
widgetsnbextension        3.4.2                    py36_0
xz                        5.2.4                h14c3975_4
yaml                      0.1.7                had09818_2
zeromq                    4.2.5                hf484d3e_1
zlib                      1.2.11               ha838bed_2

Pip list:

Package            Version          
------------------ -----------------
backcall           0.1.0            
bleach             3.0.2            
certifi            2018.10.15       
cffi               1.11.5           
cycler             0.10.0           
decorator          4.3.0            
entrypoints        0.2.3            
future             0.17.1           
ipykernel          5.1.0            
ipython            7.1.1            
ipython-genutils   0.2.0            
ipywidgets         7.4.2            
jedi               0.13.1           
Jinja2             2.10             
jsonschema         2.6.0            
jupyter            1.0.0            
jupyter-client     5.2.3            
jupyter-console    6.0.0            
jupyter-core       4.4.0            
kiwisolver         1.0.1            
MarkupSafe         1.0              
matplotlib         3.0.1            
mistune            0.8.4            
mkl-fft            1.0.6            
mkl-random         1.0.1            
nbconvert          5.3.1            
nbformat           4.4.0            
notebook           5.7.0            
numpy              1.15.4           
onnx               1.3.0            
pandocfilters      1.4.2            
parso              0.3.1            
pexpect            4.6.0            
pickleshare        0.7.5            
pip                18.1             
prometheus-client  0.4.2            
prompt-toolkit     2.0.7            
protobuf           3.6.1            
ptyprocess         0.6.0            
pycparser          2.19             
Pygments           2.2.0            
pyparsing          2.3.0            
python-dateutil    2.7.5            
pytz               2018.7           
PyYAML             3.13             
pyzmq              17.1.2           
qtconsole          4.4.2            
Send2Trash         1.5.0            
setuptools         40.5.0           
six                1.11.0           
terminado          0.8.1            
testpath           0.4.2            
torch              1.0.0.dev20181109
tornado            5.1.1            
traitlets          4.3.2            
typing             3.6.4            
typing-extensions  3.6.6            
wcwidth            0.1.7            
webencodings       0.5.1            
wheel              0.32.2           
widgetsnbextension 3.4.2

zerotosingularity commented 6 years ago

Tried '1.0.0.dev20181112' pytorch-nightly, but the issue remains.

tlc commented 5 years ago

11/30/18 FYI: the tutorial and the source links have been updated, so you no longer need to apply the source modifications mentioned above. But the problem of the two runs not generating matching data still remains.

zerotosingularity commented 5 years ago

@tlc thanks for the update. Is there a way to get informed on whether that issue (export to caffe2) is being addressed? Or is there something I can do myself?

ulisesbussi commented 5 years ago

Hi, I came here for the same bug. I'm on Manjaro and took the example from https://pytorch.org/tutorials/advanced/super_resolution_with_caffe2.html?highlight=mobile

I don't know what to do, but if I can help with anything, let me know.

tlc commented 5 years ago

Can anyone point to any working example of a real world model that transfers well from PyTorch to Caffe2?

KeremTurgutlu commented 5 years ago

Yes, you can follow exactly the same steps for the AlexNet conversion from PyTorch to Caffe2, which gives the same outputs in both the PyTorch and Caffe2 backends, as expected. Here is the link: https://pytorch.org/docs/stable/onnx.html. The issue with the super-resolution model might be due to an op such as pixel_shuffle, because its other ops (conv2d, relu) are shared with AlexNet.

It works fine if you change the network:

import torch.nn as nn
import torch.nn.init as init


# Dummy super-resolution model: nn.PixelShuffle is swapped for nn.Upsample,
# and conv4 outputs a single channel instead of upscale_factor ** 2 channels.
class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, 1, (3, 3), (1, 1), (1, 1))
        # The attribute keeps its original name but now holds an Upsample layer.
        self.pixel_shuffle = nn.Upsample(scale_factor=upscale_factor, mode='nearest')

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        # Orthogonal init, scaled by the ReLU gain for layers followed by ReLU.
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

ChiefGodMan commented 5 years ago

I replaced nn.PixelShuffle with nn.Upsample and the test passed, which is great. But I'm curious: if tensorrt does not support the PixelShuffle op, why doesn't the export raise an error? Silently producing a wrong output is confusing.
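Note that swapping nn.PixelShuffle for nn.Upsample changes what the network computes, so a passing test confirms the remaining ops export consistently, not that the original model does. The NumPy sketch below illustrates the difference. The helper names `pixel_shuffle_np` and `upsample_nearest_np` are hypothetical stand-ins written for this illustration, assuming the documented PixelShuffle behaviour of rearranging a `(N, C*r*r, H, W)` tensor into `(N, C, H*r, W*r)`.

```python
import numpy as np


def pixel_shuffle_np(x, r):
    """Rearrange (N, C*r*r, H, W) into (N, C, H*r, W*r), as PixelShuffle does."""
    n, c, h, w = x.shape
    out_c = c // (r * r)
    x = x.reshape(n, out_c, r, r, h, w)
    # Move the two r-sized axes next to the spatial axes they expand.
    x = x.transpose(0, 1, 4, 2, 5, 3)  # (n, out_c, h, r, w, r)
    return x.reshape(n, out_c, h * r, w * r)


def upsample_nearest_np(x, r):
    """Nearest-neighbour upsampling: repeat every pixel r times per spatial axis."""
    return np.repeat(np.repeat(x, r, axis=2), r, axis=3)


# Four input channels become one 2x2 output patch with distinct values.
shuffled = pixel_shuffle_np(np.arange(4.0).reshape(1, 4, 1, 1), 2)
print(shuffled[0, 0])

# One input channel becomes a 2x2 patch with the same value repeated.
upsampled = upsample_nearest_np(np.full((1, 1, 1, 1), 5.0), 2)
print(upsampled[0, 0])
```

PixelShuffle consumes `r * r` channels to produce each output channel, while nearest-neighbour upsampling only copies existing pixels. That is why the modified network also has to reduce conv4 to a single output channel.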