microsoft / CNTK

Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit
https://docs.microsoft.com/cognitive-toolkit/

Incorrect BatchNorm output (ONNX export) #2990

Closed artbataev closed 6 years ago

artbataev commented 6 years ago

There is some strange behavior of Batch Normalization when importing a model from Pytorch (0.3.1) into CNTK (2.4) using ONNX: most of the outputs are 0.

Very simple code to reproduce the problem:

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import cntk as C

# Single BatchNorm2d layer in inference mode (uses running statistics).
model = nn.Sequential(nn.BatchNorm2d(1))
model.eval()
dummy_input = np.random.randn(1, 1, 1, 2).astype(np.float32)
dummy_input_torch = Variable(torch.FloatTensor(dummy_input))

print("Some input: ", dummy_input.tolist())
print("Batch norm parameters: running_mean {} | running_var {} | weight {} | bias {}".format(
    model[0].running_mean.numpy(), model[0].running_var.numpy(),
    model[0].weight.data.numpy(), model[0].bias.data.numpy()))
print("Pytorch output: ", model(dummy_input_torch).data.numpy().tolist())

# Export to ONNX and load the same model back in CNTK.
torch.onnx.export(model, dummy_input_torch, "mymodel.proto", verbose=False, training=False)
cntk_model = C.Function.load("mymodel.proto", format=C.ModelFormat.ONNX, device=C.device.cpu())
print("Cntk output: ", cntk_model.eval({cntk_model.arguments[0]: dummy_input}).tolist())

Output:

Some input:  [[[[-1.277238130569458, 0.14103645086288452]]]]
Batch norm parameters: running_mean [0.] | running_var [1.] | weight [0.795401] | bias [0.]
Pytorch output:  [[[[-1.015911340713501, 0.11217997223138809]]]]
Cntk output:  [[[[-1.015911340713501, 0.0]]]]

The Pytorch output is correct, but in the CNTK output one element is 0. The same behavior occurs with larger inputs: only the first element of each tensor is correct.
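
For reference, plugging the reported parameters into the standard inference-mode batch norm formula, y = weight * (x - running_mean) / sqrt(running_var + eps) + bias, reproduces the Pytorch numbers for both elements, which confirms it is the CNTK side that zeroes the second value. A minimal NumPy check (not part of the original report; eps=1e-5 is the Pytorch BatchNorm2d default):

import numpy as np

# Inference-mode batch norm, applied by hand to the values printed above.
x = np.array([-1.277238130569458, 0.14103645086288452], dtype=np.float32)
running_mean, running_var = np.float32(0.0), np.float32(1.0)
weight, bias, eps = np.float32(0.795401), np.float32(0.0), 1e-5

y = weight * (x - running_mean) / np.sqrt(running_var + eps) + bias
print(y.tolist())  # ~[-1.0159113, 0.11217997] -- matches the Pytorch output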

MatthiasRock commented 6 years ago

That might explain why I'm not able to reproduce the same behavior with my Python network as with my BrainScript network... at least if Python and BrainScript do not use the same implementation of the batch normalization layer.

Edit: I solved my problem... it has nothing to do with this issue.

spandantiwari commented 6 years ago

Thanks for reporting this. The implementation of the ONNX BatchNormalization op in CNTK was recently updated to match the latest opset 6 spec, as part of this commit: https://github.com/Microsoft/CNTK/commit/cf839dcdb5f821b854353b1246ef27b1003ea944.
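
For anyone who wants to double-check what their exported model actually declares, the proto can be inspected with the onnx Python package. A small sketch (assumes the onnx package is installed; "mymodel.proto" is the file from the repro above):

import onnx

# Load the exported proto, print the declared opset(s), and dump the
# BatchNormalization node with its inputs and attributes (e.g. epsilon).
model = onnx.load("mymodel.proto")
print("Declared opsets:", [(imp.domain, imp.version) for imp in model.opset_import])
for node in model.graph.node:
    if node.op_type == "BatchNormalization":
        print(node)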

The numbers are now consistent between CNTK and Pytorch.

Some input:  [[[[-0.9206454753875732, -2.3230578899383545]]]]
Batch norm parameters: running_mean [0.] | running_var [1.] | weight [0.31872553] | bias [0.]
Pytorch output:  [[[[-0.29343175888061523, -0.7404141426086426]]]]
Cntk output:  [[[[-0.29343172907829285, -0.7404140830039978]]]]

This change should be in the next release (CNTK 2.6). You can try it from the latest master today.
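
To verify on a master build, a tolerance-based comparison of the two outputs is more reliable than eyeballing the printed lists. A sketch that reuses model, dummy_input_torch, dummy_input, and cntk_model from the repro script above (not from the original thread):

import numpy as np

# Compare Pytorch and CNTK outputs element-wise within a small tolerance.
pytorch_out = model(dummy_input_torch).data.numpy()
cntk_out = np.asarray(cntk_model.eval({cntk_model.arguments[0]: dummy_input}))
assert np.allclose(pytorch_out, cntk_out, atol=1e-6), "outputs still diverge"
print("Pytorch and CNTK outputs match within tolerance.")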

If you are still seeing some discrepancy, please reopen this issue.