AxisCommunications / onnx-to-keras

Convert ONNX models exported from PyTorch to TensorFlow Keras models, with a focus on performance and high-level compatibility.
MIT License

Batch Norm: Allow 2-dim and 3-dim BN #23

Closed: xsacha closed this issue 3 years ago

xsacha commented 3 years ago

I hit a batch norm layer with shape (1, 512) that triggered the NotImplemented exception. After removing the exception, the model converted fine. Is there a reason 2- and 3-dimensional batch normalisation is not supported?


hakanardo commented 3 years ago

It's because we haven't verified that PyTorch and TensorFlow behave the same way in those cases. There are subtle differences between the two frameworks that can bite you if you're not careful. But adding a few test cases should resolve that.
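
The comment does not say which differences are meant. As a hedged illustration, two that plausibly matter here are the channel axis and the default epsilon: PyTorch's BatchNorm1d normalises over dim 1 of an (N, C) or (N, C, L) tensor, while Keras' BatchNormalization defaults to axis=-1 (channels last), and the default eps is 1e-5 in PyTorch versus 1e-3 in Keras. The NumPy sketch below checks both with inference-mode statistics; batch_norm_inference is a hypothetical reference helper, not part of onnx-to-keras.

```python
import numpy as np


def batch_norm_inference(x, mean, var, gamma, beta, eps, axis):
    """Inference-mode batch norm along `axis` using stored running statistics.

    Hypothetical NumPy reference helper, not part of onnx-to-keras.
    """
    # Reshape the per-channel vectors so they broadcast along `axis` only.
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    m, v, g, b = (np.reshape(p, shape) for p in (mean, var, gamma, beta))
    return (x - m) / np.sqrt(v + eps) * g + b


rng = np.random.default_rng(0)
C = 4
mean, gamma, beta = rng.random(C), rng.random(C), rng.random(C)
var = rng.random(C) * 1e-3  # small variances make the eps choice visible

# 2-dim input (N, C): channel axis 1 is also axis -1, so no transpose is needed.
x2 = rng.random((2, C))
out2_torch_layout = batch_norm_inference(x2, mean, var, gamma, beta, 1e-5, axis=1)
out2_keras_layout = batch_norm_inference(x2, mean, var, gamma, beta, 1e-5, axis=-1)

# 3-dim input (N, C, L): Keras expects channels last, i.e. (N, L, C).
x3 = rng.random((2, C, 5))
out3_torch_layout = batch_norm_inference(x3, mean, var, gamma, beta, 1e-5, axis=1)
out3_keras_layout = batch_norm_inference(
    x3.transpose(0, 2, 1), mean, var, gamma, beta, 1e-5, axis=-1
)

print(np.allclose(out2_torch_layout, out2_keras_layout))  # True
print(np.allclose(out3_torch_layout, out3_keras_layout.transpose(0, 2, 1)))  # True

# Using Keras' default epsilon instead of the model's 1e-5 changes the result.
out2_default_eps = batch_norm_inference(x2, mean, var, gamma, beta, 1e-3, axis=1)
print(np.allclose(out2_torch_layout, out2_default_eps))  # False
```

For a 2-dimensional input the axis question disappears, while a 3-dimensional input only matches once the layout is transposed. In either case a converter would still need to carry over the epsilon attribute of the ONNX BatchNormalization node rather than rely on the Keras default.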

xsacha commented 3 years ago

The BatchNorm1d test itself passes. However, there are checks that the test inputs must be 4-dimensional, and it fails here:

        File "~/onnx-to-keras/onnx2keras.py", line 105, in make_input
            assert len(shape) == 4

        if len(shape) == 4:
            # ONNX (N, C, H, W) -> Keras per-sample shape (H, W, C)
            tensor = tf.keras.layers.Input((shape[2], shape[3], shape[1]), batch_size=shape[0], name=name, dtype=dtype)
            tensor.data_format = InterleavedImageBatch
        else:
            # ONNX (N, C, L) -> Keras per-sample shape (L, C)
            tensor = tf.keras.layers.Input((shape[2], shape[1]), batch_size=shape[0], name=name, dtype=dtype)
            tensor.data_format = InterleavedImageBatch  # Other parts fail otherwise
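
The else branch above indexes shape[2], so a 2-dimensional input shape would raise an IndexError there. One rank-generic alternative is to move the channel dimension to the end for any rank of 2 or more. The sketch below assumes ONNX inputs are channels-first (N, C, ...) and that Keras wants channels-last per-sample shapes; keras_input_shape is a hypothetical helper, not a function in onnx2keras.py.

```python
def keras_input_shape(onnx_shape):
    """Split a channels-first ONNX shape (N, C, *spatial) into the Keras
    batch size and a channels-last per-sample shape (*spatial, C).

    Hypothetical helper for illustration; requires len(onnx_shape) >= 2.
    """
    if len(onnx_shape) < 2:
        raise ValueError(f"expected at least (N, C), got {onnx_shape!r}")
    batch, channels, *spatial = onnx_shape
    return batch, (*spatial, channels)


# 4-dim (N, C, H, W) -> (H, W, C), matching the len(shape) == 4 branch
print(keras_input_shape((1, 3, 224, 224)))  # (1, (224, 224, 3))
# 3-dim (N, C, L) -> (L, C), matching the else branch
print(keras_input_shape((1, 512, 7)))  # (1, (7, 512))
# 2-dim (N, C) -> (C,), which the else branch cannot handle
print(keras_input_shape((1, 512)))  # (1, (512,))
```

The two values could then be passed as tf.keras.layers.Input(per_sample_shape, batch_size=batch, ...). How the rest of the converter should tag a non-image tensor's data_format is not covered in the thread, so that part is left open.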

There were also other parts that expect it to be an InterleavedImageBatch even though it is 3-dimensional.

Also, the test code expected a 4-dimensional input.

I worked around these restrictions by starting from a 4-dimensional input, pooling it, and setting image_out=False, like this:

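        import numpy as np
        import torch

        # GlobalAvgPool and convert_and_compare_output are not defined in this
        # snippet; they are presumably helpers from the project's tests.
        bn1 = torch.nn.BatchNorm1d(3)
        bn1.running_mean.uniform_()  # random running stats so BN is not an identity
        bn1.running_var.uniform_()
        net1 = torch.nn.Sequential(GlobalAvgPool(), bn1, torch.nn.ReLU())
        net1.eval()  # eval mode: BN uses the running statistics
        y = np.random.rand(1, 3, 224, 224).astype(np.float32)
        convert_and_compare_output(net1, y, image_out=False)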
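
The workaround plausibly succeeds because global average pooling collapses the spatial axes: whether the tensor is NCHW (PyTorch) or NHWC (Keras), the pooled result is the same (N, C) array, so the following BatchNorm1d sees a 2-dimensional input whose channel axis is both 1 and -1. A small NumPy check, assuming GlobalAvgPool averages over the spatial dimensions and returns (N, C) (its definition is not shown in the thread):

```python
import numpy as np

rng = np.random.default_rng(0)
x_nchw = rng.random((1, 3, 224, 224))  # PyTorch layout (float64 to keep the check tight)
x_nhwc = x_nchw.transpose(0, 2, 3, 1)  # Keras layout

# Global average pooling over the spatial axes of each layout.
pooled_from_nchw = x_nchw.mean(axis=(2, 3))  # shape (1, 3)
pooled_from_nhwc = x_nhwc.mean(axis=(1, 2))  # shape (1, 3)

print(pooled_from_nchw.shape, pooled_from_nhwc.shape)  # (1, 3) (1, 3)
print(np.allclose(pooled_from_nchw, pooled_from_nhwc))  # True
```

Because the layouts agree after pooling, this test exercises the 2-dimensional batch norm path but not the 3-dimensional (N, C, L) case, where the channel axis still differs between the frameworks.
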
hakanardo commented 3 years ago

Yes, the test framework (among other things) is pretty focused on image processing. It could be improved...