onnx / onnx-tensorflow

Tensorflow Backend for ONNX

10 times slower when converting to Tensorflow from PyTorch #473

Open hristo-vrigazov opened 5 years ago

hristo-vrigazov commented 5 years ago

I have a PyTorch model that runs in about 0.007 seconds per forward pass on a 1080Ti - it's the PoseEstimationWithMobileNet from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/blob/master/models/with_mobilenet.py. When I converted the model to ONNX and ran it through this TensorFlow backend, performance was much worse: about 0.120 seconds in the exact same setup. I tried strict=False, but the results did not change. I have attached a trace file from the TensorFlow session run.
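
For context, the conversion path in question looks roughly like this (a minimal sketch; the input shape and file names are illustrative, and exact behavior depends on the torch and onnx-tf versions):

import onnx
import torch
from onnx_tf.backend import prepare

from models.with_mobilenet import PoseEstimationWithMobileNet

model = PoseEstimationWithMobileNet()
model.eval()

# Illustrative input shape; the network is fully convolutional.
dummy_input = torch.randn(1, 3, 256, 456)
torch.onnx.export(model, dummy_input, "pose.onnx")

# strict=False as mentioned above; then export a frozen TensorFlow graph.
tf_rep = prepare(onnx.load("pose.onnx"), strict=False)
tf_rep.export_graph("pose.pb")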

ainrichman commented 5 years ago

If you use Netron to visualize your exported pb file, you will find that the depthwise conv op is decomposed into a large set of regular conv ops. If your inference framework does not optimize that structure, you will get very low performance. This issue cannot be resolved until PyTorch itself supports native depthwise convs.
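
For illustration, the split-based emulation of a grouped/depthwise conv looks roughly like this (a sketch of the pattern, not onnx-tensorflow's exact handler code):

import tensorflow as tf

def grouped_conv_via_splits(x, weights, group, strides, padding):
    # x: NHWC input; weights: [kH, kW, C // group, M] filter.
    # For a depthwise conv, group == C, so this emits one conv op per channel.
    x_groups = tf.split(x, num_or_size_splits=group, axis=-1)
    weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)
    convolved = [
        tf.nn.convolution(xg, wg, strides=strides, padding=padding)
        for xg, wg in zip(x_groups, weight_groups)
    ]
    return tf.concat(convolved, axis=-1)

Each depthwise layer of a MobileNet then contributes as many conv nodes as it has channels, which is how the graph ends up with hundreds of tiny Conv2D, Split and Concat ops.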

RomainSabathe commented 4 years ago

> If you use Netron to visualize your exported pb file, you will find that the depthwise conv op is decomposed into a large set of regular conv ops. If your inference framework does not optimize that structure, you will get very low performance. This issue cannot be resolved until PyTorch itself supports native depthwise convs.

I'm facing similar issues with a MobileNetV2 coming from PyTorch. The thing is, when using ONNX for inference it's fast, same as with PyTorch. It's only when using the TF graph that inference is slow. I agree with you, I just don't think this decomposition is the root cause.
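
(For reference, here is one way to time the ONNX model directly, assuming onnxruntime is installed and a pose.onnx exported as above:)

import time

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("pose.onnx")
input_name = sess.get_inputs()[0].name
feed = {input_name: np.random.randn(1, 3, 256, 456).astype(np.float32)}

# Average latency over 100 runs.
start = time.perf_counter()
for _ in range(100):
    sess.run(None, feed)
print("mean latency (s):", (time.perf_counter() - start) / 100)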

RomainSabathe commented 4 years ago

So indeed I used Netron to look at the network once converted to TensorFlow and was horrified to see hundreds and hundreds of conv ops. You were right. If anyone is reading this, here's how I went about fixing it. It's a hacky solution, so I don't intend to open a PR or anything.

I modified the conv_mixin.py file.

from tensorflow.python.ops.nn_ops import _get_sequence

# rest of file

if transpose:
    ...  # same as before
else:
    if group != weights.shape[-1]:
        convolved = [tf.nn.convolution(
            ...  # as before
        )]
    else:
        convolved = [
            tf.nn.depthwise_conv2d(
                x,
                tf.transpose(weights, [0, 1, 3, 2]),  # [filter_height, filter_width, in_channels, multiplier (=1)]
                strides=_get_sequence(strides, 2, channel_index=3, name="strides"),  # requires a 4-d list
                padding="VALID",
                rate=None,
                data_format=compute_format,
                dilations=dilations,
            )
        ]

And to cut extra fat on the graph, you can even avoid creating many splits in the case where we'll use a depthwise conv for sure (currently line 64):

if group != weights.shape[-1]:
    weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)
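
A quick way to verify the patch took effect is to count op types in the exported graph (assuming a frozen pose.pb as above): before the fix you see hundreds of Conv2D/Split/ConcatV2 nodes, after it DepthwiseConv2dNative ops should appear instead.

from collections import Counter

import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with open("pose.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Tally the most common node types in the graph.
print(Counter(node.op for node in graph_def.node).most_common(10))
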
yokings commented 4 years ago

> So indeed I used Netron to look at the network once converted to TensorFlow and was horrified to see hundreds and hundreds of conv ops. You were right. If anyone is reading this, here's how I went about fixing it. [...]

What does the function _get_sequence() implement?

andreydung commented 4 years ago

@RomainSabathe Could you please elaborate on this?

RomainSabathe commented 4 years ago

I invite you to read the function definition directly: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_ops.py#L62

Basically we use this function to format the "strides" argument depending on the dimensionality of the data (here we're dealing with images, so n=2) and the channel index (1 or 3, depending on whether you use HxWxC or CxHxW formatting). It also adds a batch dimension.
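
In spirit, it does something like this (a simplified sketch; the real function also accepts scalars and validates lengths):

def _get_sequence_sketch(value, n, channel_index, name=None):
    # Broadcast a length-1 sequence across the n spatial dimensions.
    value = list(value)
    if len(value) == 1:
        value = value * n
    # Insert 1s for the batch and channel dimensions.
    if channel_index == 1:        # channels-first, e.g. NCHW
        return [1, 1] + value
    return [1] + value + [1]      # channels-last, e.g. NHWC (channel_index == n + 1)

So _get_sequence([2, 2], 2, channel_index=3, name="strides") returns [1, 2, 2, 1].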

fmmohammadi commented 4 years ago

That works, great! Thank you. I should add that if your network input is in (BATCH, CHANNEL, HEIGHT, WIDTH) form, channel_index should be 1. Also, I think it is better to use pad_mode instead of hard-coding "VALID".
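
Applied to the snippet above, that suggestion would read roughly like this (pad_mode being the padding string the handler has already computed):

convolved = [
    tf.nn.depthwise_conv2d(
        x,
        tf.transpose(weights, [0, 1, 3, 2]),
        strides=_get_sequence(strides, 2, channel_index=1, name="strides"),  # channels-first
        padding=pad_mode,
        data_format="NCHW",
        dilations=dilations,
    )
]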

BUT, it only made the performance better, not the same as PyTorch. For example, in PyTorch I got 164 ms, whereas in the converted TensorFlow model I got 252 ms (versus 480 ms before the depthwise correction). I hope there may be some solution to this.

tensorbuffer commented 4 years ago

@RomainSabathe I don't understand your weights.shape[-1]; it's supposed to be the channel count. And weights should be in HWCM format according to https://github.com/onnx/onnx-tensorflow/blob/master/onnx_tf/handlers/backend/conv_mixin.py#L44, unless line 44 is the same as line 42.

mgarbade commented 4 years ago

@RomainSabathe Where in conv_mixin.py should your code snippet be placed?

Just before the return statement (line 249)?

bmmlover commented 3 years ago

@fmmohammadi I have the same issue. With this fix, inference of my network became two times faster, but it is still slower than in PyTorch. Have you found any solution?

tailtq commented 3 years ago

Thanks for the solution. After editing the conv_mixin.py file, the performance improves significantly. In my case, the prediction time drops from ~0.06 s to ~0.02 s on CPU, but it's still about two times slower than PyTorch and ONNX Runtime.