Samsung / ONE

On-device Neural Engine

[one-import-onnx] Bug in batchnorm exported with training mode #12087

Open jinevening opened 11 months ago

jinevening commented 11 months ago

What

If an onnx model includes a batchnorm Op and is exported in training mode, the converted tflite model behaves differently from the original onnx model.

import torch
from torch import nn

class NeuralNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.Network = nn.Sequential(
        nn.BatchNorm2d(512, track_running_stats=True),
    )

  def forward(self, x):
    output = self.Network(x)
    return output

model = NeuralNetwork()
input = torch.randn(1, 512, 127, 1)
output = model(input)

from torch.onnx import TrainingMode

torch.onnx.export(model, input, "bn.onnx",
                  verbose=True,
                  export_params=True,
                  opset_version=12,
                  do_constant_folding=True,
                  input_names = ['input'],
                  output_names = ['output'],
                  training=TrainingMode.TRAINING)

# The output of bn.onnx differs from the output of the tflite model generated by our toolchain

If training=TrainingMode.TRAINING is removed, everything works fine.
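
For context: in training mode, BatchNorm normalizes with the statistics of the current batch instead of the stored running_mean/running_var, so the onnx output depends on the input itself, while a converter that folds the stored statistics into constants cannot reproduce that. A minimal NumPy sketch of what a training-mode BatchNorm computes (bn_training_mode is an illustrative helper, not part of the toolchain; the parameter values assume a freshly initialized nn.BatchNorm2d, i.e. weight 1 and bias 0):

import numpy as np

def bn_training_mode(x, gamma, beta, eps=1e-5):
  # Training mode: normalize with per-channel statistics of this batch
  # (axes N, H, W in NCHW layout), not the stored running mean/variance.
  mean = x.mean(axis=(0, 2, 3), keepdims=True)
  var = x.var(axis=(0, 2, 3), keepdims=True)
  return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.rand(1, 512, 127, 1).astype(np.float32)
gamma = np.ones((1, 512, 1, 1), dtype=np.float32)  # fresh BatchNorm2d weight
beta = np.zeros((1, 512, 1, 1), dtype=np.float32)  # fresh BatchNorm2d bias
print(bn_training_mode(x, gamma, beta))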

jinevening commented 11 months ago

I used the script below to compare the results of onnx and tflite.

$ onecc import onnx -i bn.onnx -o bn.circle --save_intermediate

(--save_intermediate keeps the intermediate bn.tflite used by the script below.)

import tensorflow as tf
import numpy as np
import onnxruntime

onnx_path = 'bn.onnx'
tflite_path = 'bn.tflite'

sess = onnxruntime.InferenceSession(onnx_path, None)
input_name = sess.get_inputs()[0].name

i = np.random.rand(1, 512, 127, 1).astype(np.float32)
onnx_o = sess.run(None, {input_name: i})  # returns a list of output arrays

print(onnx_o)

tf_interp = tf.lite.Interpreter(tflite_path)
tf_interp.allocate_tensors()
output = tf_interp.get_output_details()[0]  # Model has single output.
input = tf_interp.get_input_details()[0]  # Model has single input.
tf_interp.set_tensor(input['index'], i)
tf_interp.invoke()
tflite_o = tf_interp.get_tensor(output['index'])

print(tflite_o)

print(np.allclose(onnx_o[0], tflite_o, rtol=1.e-5, atol=1.e-5))
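
As reported above, this final allclose check prints False for the training-mode export, while a model exported without training=TrainingMode.TRAINING passes it.
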
seanshpark commented 11 months ago

Q) Are the constant values of BN the same for onnx vs tflite? I'm curious whether onnx-runtime does something additional if onnx has training mode...

jinevening commented 11 months ago

Constant values change when the onnx model is converted to tflite. In onnx they are the mean/variance; in tflite they are channelwise mul/add constants.

> I'm curious whether onnx-runtime does something additional if onnx has training mode...

I guess so.
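
For reference, a minimal sketch of the folding described above, assuming the toolchain turns the BN parameters (gamma, beta, mean, variance, epsilon) into channelwise mul/add constants (fold_bn is an illustrative helper, not the converter's actual code):

import numpy as np

def fold_bn(gamma, beta, mean, var, eps=1e-5):
  # Channelwise constants such that
  #   mul * x + add == gamma * (x - mean) / sqrt(var + eps) + beta
  mul = gamma / np.sqrt(var + eps)
  add = beta - mean * mul
  return mul, add

This folding is only equivalent to the BN when the stored mean/variance are actually used, i.e. in inference mode. If onnx-runtime recomputes the statistics from the batch because the model was exported in training mode, the folded constants can no longer match its output.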