zlqhem opened this issue 4 years ago
After following onnx-tensorflow, converting the BatchNormalization node as in the picture above seems reasonable (4 Reshape, 1 Transpose, and so on). The 4 Reshape ops correspond to the mean, variance, bias, and scale used for batch normalization.
Question #1: Why not use TensorFlow's BatchNormalization directly? Are there any issues with that? (Sorry, I'm new to both TensorFlow and onnx-tensorflow.)
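For reference, a minimal sketch of what those four Reshapes are doing (my own illustration, not onnx-tensorflow's actual code): each 1-D parameter has to be reshaped to [1, C, 1, 1] so it can broadcast over an NCHW tensor.

```python
import tensorflow as tf

# Per-channel batch norm on an NCHW tensor (sketch, not onnx-tensorflow's code):
# scale, bias, mean and variance are 1-D [C] tensors, so each needs one Reshape
# to [1, C, 1, 1] before it can broadcast over [N, C, H, W].
def batch_norm_nchw(x, scale, bias, mean, variance, eps=1e-5):
    shape = [1, x.shape[1], 1, 1]
    scale = tf.reshape(scale, shape)
    bias = tf.reshape(bias, shape)
    mean = tf.reshape(mean, shape)
    variance = tf.reshape(variance, shape)
    return scale * (x - mean) / tf.sqrt(variance + eps) + bias

x = tf.random.normal([1, 32, 56, 56])            # NCHW, as exported from PyTorch
scale, bias, mean, variance = (tf.random.normal([32]) for _ in range(4))
print(batch_norm_nchw(x, scale, bias, mean, variance).shape)   # (1, 32, 56, 56)
```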
But I wonder why the converted Conv layer is so complex.
Below is the conv layer in ONNX, converted from the MobileNet V2 PyTorch model.
Below is the conv layer converted using onnx-tensorflow.
If I've got something wrong, please let me know.
It seems this was already mentioned in #473.
from https://github.com/onnx/onnx-tensorflow/issues/473#issuecomment-533522917
... converted to Tensorflow and was horrified to see hundreds and hundreds of conv ops.
@chinhuang007 is there any info on this one? I am facing the same issue, where one conv2d (not depthwise) with 32 output filters turns into split -> 32 conv2d operators with 1 output filter each -> concatenate. I tested it on current master and it still reproduces, generating an enormous graph. Are there at least any ways to modify the PyTorch code to avoid the generation of all these additional conv2d ops?
I did some additional digging. First, I was wrong to say that my conv2d wasn't depthwise; it really was (a conv2d is depthwise in PyTorch and ONNX if its groups parameter is > 1).
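For reference, here is a minimal sketch (plain torch/onnx APIs; the file and layer names are my own) of how a depthwise conv in PyTorch ends up with a group attribute > 1 in the exported ONNX graph, which is the case onnx-tensorflow has to map back to a single depthwise op instead of many small convs:

```python
import torch
import torch.nn as nn
import onnx

# A conv is depthwise when groups == in_channels; torch.onnx.export records
# this as the `group` attribute on the ONNX Conv node.
layer = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32)
dummy = torch.randn(1, 32, 56, 56)
torch.onnx.export(layer, dummy, "depthwise.onnx", opset_version=11)

model = onnx.load("depthwise.onnx")
conv = next(n for n in model.graph.node if n.op_type == "Conv")
print([(a.name, a.i) for a in conv.attribute if a.name == "group"])  # [('group', 32)]
```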
In #473 there is a fix that joins the multiple convs in your screenshot. You can apply it to remove them.
In the master branch a contributor also tries to fix it, but there is an error on this line:
https://github.com/onnx/onnx-tensorflow/blob/6685f45f3fed37a7d8868c4ecabe8390e489a67c/onnx_tf/handlers/backend/conv_mixin.py#L96-L98
group == x_shape[1] is a tensor, not a bool, so the depthwise flag is never True and the code inside the if statement is never executed.
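To see the pitfall in isolation, here is a minimal sketch of my own (not the handler's actual code): during conversion the comparison runs against a symbolic shape, so with an unknown channel dimension it produces a Tensor rather than a Python bool.

```python
import tensorflow as tf

group = 32

@tf.function(input_signature=[tf.TensorSpec([1, None, None, None], tf.float32)])
def trace(x):
    x_shape = tf.shape(x)            # channel dim unknown -> dynamic shape Tensor
    flag = (group == x_shape[1])     # a symbolic bool Tensor, not a Python bool
    print(type(flag))                # printed while tracing: a Tensor class
    # Chaining this with `and` just propagates the Tensor, so a plain-Python
    # depthwise flag built this way can never simply be True, as described above.
    return x

trace.get_concrete_function()        # trigger tracing to run the prints
```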
To fix it you can just remove this comparison from the depthwise computation, i.e. change the line to
depthwise = (x_rank == 4 and len(weight_shape) == 4 and
group != 1 and not transpose and
not (None in weight_shape))
This allows the creation of depthwise conv on current master branch
I've created a fork with the fix above, so you can do pip install git+https://github.com/Vozf/onnx-tensorflow to apply it.
@chinhuang007 @seanshpark please note that your fix of conv_mixin in master contains a bug: group == x_shape[1] is a tensor, which forces depthwise to always be False.
For the fix in v1.7.0, go to #473.
@Vozf Thanks for pointing it out! I believe group == x_shape[1] could be a bool, not always a tensor, when the input shape is known, as seen in https://github.com/onnx/onnx-tensorflow/blob/master/onnx_tf/common/tf_helper.py#L17. I agree x_shape[1] would be a tensor when the input shape is unknown so the condition won't work. I will create a quick patch as you described.
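In other words, a quick sketch of that distinction (a hypothetical helper mirroring the behaviour described, not the actual tf_helper.py code):

```python
import tensorflow as tf

def shape_of(x):
    # Hypothetical helper mirroring the behaviour described above: return plain
    # ints when the static shape is fully defined, else fall back to tf.shape.
    if x.shape.is_fully_defined():
        return x.shape.as_list()     # Python ints -> `group == channels` is a bool
    return tf.shape(x)               # symbolic Tensor -> the comparison is a Tensor

x = tf.zeros([1, 32, 56, 56])        # fully known shape
print(type(32 == shape_of(x)[1]))    # <class 'bool'>
```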
@chinhuang007 Glad to help you.
Could you also clarify the current status of full ONNX conversion to TensorFlow channels_last, to avoid all the Transpose and Reshape operators seen in the screenshot? Is it possible? There was a pull request with that functionality which got reverted, but it seems really useful.
@Vozf I can't completely remove group == x_shape[1] or it will fail the ONNX standard backend tests with a known input shape. So I just created a PR, https://github.com/onnx/onnx-tensorflow/pull/885, to check whether the input shape is known or not. Frankly, I don't have a test case to verify whether it will work for an unknown shape. So... @seanshpark Would you please double check that the logic still works, since you originally implemented the depthwise piece? Thanks!
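A minimal sketch of the kind of guard being described (my own wording and names, not the actual diff in #885): keep the strict channel comparison when the static shape is known, and fall back to the relaxed check when it is not.

```python
import tensorflow as tf

def looks_depthwise(x, weight_shape, group, transpose):
    # Hypothetical helper: `group == channels` only makes sense when the static
    # channel dimension is available; otherwise rely on group != 1 alone.
    x_rank = len(x.shape)
    base = (x_rank == 4 and len(weight_shape) == 4 and
            group != 1 and not transpose and
            None not in weight_shape)
    channels = x.shape[1] if x_rank == 4 else None
    if channels is None:
        return base                      # unknown shape: skip the comparison
    return base and group == channels    # known shape: keep the strict check

x = tf.zeros([1, 32, 224, 224])          # NCHW input with 32 channels
print(looks_depthwise(x, [32, 1, 3, 3], group=32, transpose=False))  # True
```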
@chinhuang007, #885 fails to convert to DWConv for the models I've been working on... I'll try to find out how to fix it for this model and mine too...
--> It was my environment problem :) The patch is working fine! @Vozf, thank you!
@Vozf did you find a way to convert PyTorch ONNX models to TF with channels_last to avoid all the Transpose and Reshape operators?
Yeah kind of, https://github.com/gmalivenko/onnx2keras has the channels_last flag
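If it helps, usage looks roughly like this (a sketch assuming onnx2keras's onnx_to_keras entry point and its change_ordering flag; check the project README for the exact signature and input names):

```python
import onnx
from onnx2keras import onnx_to_keras   # assumed entry point, see the repo README

onnx_model = onnx.load("mobilenet_v2.onnx")
# change_ordering is the channels_last switch mentioned above: it rewrites the
# graph to NHWC so the extra Transpose/Reshape ops are not inserted.
k_model = onnx_to_keras(onnx_model, ["input"], change_ordering=True)
k_model.save("mobilenet_v2_keras")
```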
Reproduce details: https://github.com/zlqhem/example-pytorch2tf/blob/master/pytorch2tf.ipynb
pytorch -> ONNX
ONNX -> tensorflow
The resulting model seems unexpectedly huge.
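For anyone who does not want to open the notebook, the steps boil down to something like this (a sketch using standard torchvision and onnx-tf APIs; file names are my own):

```python
import torch
import torchvision
import onnx
from onnx_tf.backend import prepare

# pytorch -> ONNX
model = torchvision.models.mobilenet_v2(pretrained=True).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy, "mobilenet_v2.onnx", opset_version=11)

# ONNX -> tensorflow
onnx_model = onnx.load("mobilenet_v2.onnx")
tf_rep = prepare(onnx_model)
tf_rep.export_graph("mobilenet_v2_tf")   # inspect this graph to see the blow-up
```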
Python, ONNX, ONNX-TF, TensorFlow versions
Python: 3.7
ONNX: 1.7.0
ONNX-TF: 1.6.0 (source build from a few days ago)
TensorFlow: 2.3.0