google-deepmind / tapnet

Tracking Any Point (TAP)
https://deepmind-tapir.github.io/blogpost.html
Apache License 2.0

OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input. #79

Open Cyril9227 opened 7 months ago

Cyril9227 commented 7 months ago

Hi everyone,

Thanks for the awesome work. I've been trying to export the PyTorch model to ONNX for inference with torch.onnx.export, but it yields this error: OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input.

Unfortunately, it seems 5D grid_sample is still unsupported by ONNX / torch. Is there any alternative available? Or any advice on making the model work with ONNX?

Thanks

SergeySandler commented 7 months ago

@Cyril9227, torch.onnx.export() fails for me too. The cause seems to be the one described in https://github.com/pytorch/pytorch/issues/100790, which will be addressed through https://github.com/pytorch/pytorch/issues/114801 (ONNX opset 20 support).

In the meantime I have been trying to convert to ONNX via Haiku (JAX) -> TensorFlow -> ONNX, using https://dm-haiku.readthedocs.io/en/latest/notebooks/jax2tf.html as a tutorial for the Haiku -> TF step:

import functools
import haiku as hk
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tqdm import tqdm
import tree
from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils
import tensorflow as tf
import sonnet as snt

# Load the causal TAPIR checkpoint and wrap the parameters as TF variables.
checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']
params_vars = tf.nest.map_structure(tf.Variable, params)

def build_online_model_init(frames, query_points):
  """Initialize query features for the query points."""
  model = tapir_model.TAPIR(use_causal_conv=True, bilinear_interp_with_depthwise_conv=False)

  feature_grids = model.get_feature_grids(frames, is_training=False)
  query_features = model.get_query_features(
      frames,
      is_training=False,
      query_points=query_points,
      feature_grids=feature_grids,
  )
  return query_features

init_tf = hk.transform(build_online_model_init)

# Sonnet module that owns the TF variables and calls the jax2tf-converted apply function.
class JaxModule(snt.Module):
  def __init__(self, params, apply_fn, name=None):
    super().__init__(name=name)
    self._params = params
    # apply_fn is init_tf.apply, i.e. apply(params, rng, *args, **kwargs).
    self._apply = jax2tf.convert(lambda p, x: apply_fn(p, None, x), enable_xla=False)
    self._apply = tf.autograph.experimental.do_not_convert(self._apply)

  def __call__(self, inputs):
    return self._apply(self._params, inputs)

net = JaxModule(params_vars, init_tf.apply)

# frames: [num_frames, height, width, 3], query_points: [num_points, 3] where 3 is the tuple (t, y, x)
@tf.function(autograph=False, input_signature=[{"frames": tf.TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32),
                                                "query_points": tf.TensorSpec(shape=(20, 3), dtype=tf.float32)}])
def forward(x):
  return net(x)

# Export as a TF SavedModel.
to_save = tf.Module()
to_save.forward = forward
to_save.params = list(net.variables)
tf.saved_model.save(to_save, "TapirInit")

but it fails with TypeError: build_online_model_init() missing 1 required positional argument: 'query_points'. The same happens with build_online_model_predict(). Maybe the input_signature in tf.function() is incorrect, but I cannot figure out how to fix it. Have you tried the TF path?
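The TypeError is consistent with Haiku's calling convention: `hk.transform(f).apply(params, rng, *args, **kwargs)` forwards everything after `rng` to `f`, so `apply_fn(p, None, x)` with a dict `x` fills only the first parameter, `frames`. A minimal pure-Python sketch of the mismatch (no Haiku, JAX or TensorFlow needed); `fake_apply` is a hypothetical stand-in for `init_tf.apply`:

```python
def build_online_model_init(frames, query_points):
    # Stand-in for the real model function: just echo what it received.
    return {"frames": frames, "query_points": query_points}


def fake_apply(params, rng, *args, **kwargs):
    # Hypothetical stand-in mimicking hk.transform(f).apply(params, rng, *args, **kwargs),
    # which forwards all remaining arguments to f.
    return build_online_model_init(*args, **kwargs)


inputs = {"frames": "frames-tensor", "query_points": "points-tensor"}

# What the JaxModule lambda does: the whole dict becomes the single positional argument `frames`.
try:
    fake_apply(None, None, inputs)
    message = ""
except TypeError as err:
    message = str(err)

# Unpacking the dict maps its keys onto the function's parameter names instead.
result = fake_apply(None, None, **inputs)
```

In the snippet above, the analogous (untested) change would be `jax2tf.convert(lambda p, x: apply_fn(p, None, **x), enable_xla=False)`, since jax2tf accepts dict inputs. Separately, the checkpoint also carries a `state` entry; if the model uses Haiku state, `hk.transform_with_state` (whose `apply` takes `params, state, rng, ...` and returns `(output, state)`) may be needed, which is an assumption worth checking.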

Since tf2onnx only supports ONNX opsets up to 18, converting the TF SavedModel to ONNX is likely to hit the same problem as PyTorch :(

saikiran321 commented 7 months ago

@Cyril9227 I have posted a solution here https://github.com/pytorch/pytorch/issues/100790. See if that works for you

SergeySandler commented 7 months ago

@saikiran321, with the solution you have posted, the unsupported-operator error related to opset 20 support no longer appears. Instead, torch.onnx.export fails with ValueError: only one element tensors can be converted to Python scalars. A Dockerfile and Python code to reproduce the result are in the attached zip file, torch2onnx.zip. Do you know what could cause this error? Thank you.

cdoersch commented 7 months ago

I'm no expert on ONNX, but if the problem is a 5D gather operation, then I suspect the source of the problem is extracting query features. It's possible to rewrite the vmap using a 4D gather; it wastes computation, but the waste is probably small compared to the rest of the model. Try setting parallelize_query_extraction to True when constructing the TAPIR model; it should produce exactly the same result given the same checkpoint, but hopefully it will avoid the 5D gather.

As a bit of an explanation, when extracting the query feature, we get a [t,y,x] coordinate and use bilinear interpolation to extract a feature from that location. The parallelize_query_extraction version instead extracts the feature at [y,x] from every frame (using a vmapped 4D gather), and then multiplies the resulting tensor by a 1-hot t vector to discard every query feature except the one on frame t.
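A NumPy sketch of the two strategies described above (not the tapnet implementation): the "direct" version bilinearly samples each query from its own frame t, while the "parallel" version samples the same (y, x) on every frame with 2D lookups and keeps only frame t via a one-hot mask. The coordinate convention (pixel (i, j) sits at coordinate (i, j), with border clamping) and all function names are assumptions for illustration.

```python
import numpy as np


def bilinear_sample(grid, y, x):
    """Bilinearly sample a [H, W, C] grid at float coords y, x (arrays of shape [N]).

    Assumes pixel (i, j) is at coordinate (i, j); coordinates are clamped to the border.
    """
    h, w, _ = grid.shape
    y = np.clip(np.asarray(y, dtype=float), 0.0, h - 1.0)
    x = np.clip(np.asarray(x, dtype=float), 0.0, w - 1.0)
    y0, x0 = np.floor(y).astype(int), np.floor(x).astype(int)
    y1, x1 = np.minimum(y0 + 1, h - 1), np.minimum(x0 + 1, w - 1)
    wy, wx = (y - y0)[:, None], (x - x0)[:, None]
    top = grid[y0, x0] * (1 - wx) + grid[y0, x1] * wx
    bottom = grid[y1, x0] * (1 - wx) + grid[y1, x1] * wx
    return top * (1 - wy) + bottom * wy  # [N, C]


def query_features_direct(feature_grids, query_points):
    """Per query: pick frame t, then do one 2D lookup (a 3D (t, y, x) lookup overall)."""
    t = query_points[:, 0].astype(int)
    y, x = query_points[:, 1], query_points[:, 2]
    return np.stack([bilinear_sample(feature_grids[ti], y[i:i + 1], x[i:i + 1])[0]
                     for i, ti in enumerate(t)])


def query_features_parallel(feature_grids, query_points):
    """Sample (y, x) on every frame, then keep only frame t with a one-hot vector."""
    num_frames = feature_grids.shape[0]
    t = query_points[:, 0].astype(int)
    y, x = query_points[:, 1], query_points[:, 2]
    per_frame = np.stack([bilinear_sample(f, y, x) for f in feature_grids])  # [T, N, C]
    one_hot = np.eye(num_frames)[t]  # [N, T]
    # Frames other than t are multiplied by 0 and vanish in the sum over T.
    return np.einsum("nt,tnc->nc", one_hot, per_frame)  # [N, C]


rng = np.random.default_rng(0)
feature_grids = rng.normal(size=(8, 16, 16, 4))  # [T, H, W, C]
query_points = np.array([[0, 3.2, 7.9], [5, 10.5, 0.25], [7, 15.0, 15.0]])  # (t, y, x)
direct = query_features_direct(feature_grids, query_points)
parallel = query_features_parallel(feature_grids, query_points)
```

The parallel variant does T times as many lookups (the wasted computation mentioned above), but it only ever needs a per-frame 2D sampling primitive, which becomes a 4D tensor operation once batched.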

Of course, this is only implemented in the JAX version; you'd have to re-implement the same algorithm in the torch model to export from torch.
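On the PyTorch side, the same trick can be sketched with only 4D `torch.nn.functional.grid_sample`. This is not code from the tapnet torch model: the [T, C, H, W] layout, the `align_corners=True` pixel mapping and the function names are assumptions, and whether the resulting graph exports cleanly to a given ONNX opset is untested here.

```python
import torch
import torch.nn.functional as F


def to_grid_coords(y, x, height, width):
    """Map pixel coords to grid_sample's [-1, 1] range (align_corners=True convention)."""
    gx = 2.0 * x / (width - 1) - 1.0
    gy = 2.0 * y / (height - 1) - 1.0
    return torch.stack([gx, gy], dim=-1)  # grid_sample expects (x, y) order


def query_features_4d(feature_grids, query_points):
    """feature_grids: [T, C, H, W]; query_points: [N, 3] as (t, y, x). Returns [N, C]."""
    num_frames, _, height, width = feature_grids.shape
    t = query_points[:, 0].long()
    grid = to_grid_coords(query_points[:, 1], query_points[:, 2], height, width)  # [N, 2]
    # Sample the same N points on every frame: grid of shape [T, N, 1, 2].
    grid = grid[None, :, None, :].repeat(num_frames, 1, 1, 1)
    sampled = F.grid_sample(feature_grids, grid, mode="bilinear", align_corners=True)
    sampled = sampled[..., 0]  # [T, C, N]
    one_hot = F.one_hot(t, num_classes=num_frames).to(feature_grids.dtype)  # [N, T]
    return torch.einsum("nt,tcn->nc", one_hot, sampled)


def query_features_per_point(feature_grids, query_points):
    """Reference: one 4D grid_sample call per query on its own frame t."""
    _, _, height, width = feature_grids.shape
    out = []
    for t, y, x in query_points:
        g = to_grid_coords(y, x, height, width).view(1, 1, 1, 2)
        frame = feature_grids[int(t)][None]  # [1, C, H, W]
        out.append(F.grid_sample(frame, g, mode="bilinear", align_corners=True)[0, :, 0, 0])
    return torch.stack(out)


torch.manual_seed(0)
feature_grids = torch.randn(6, 4, 10, 12)  # [T, C, H, W]
query_points = torch.tensor([[0.0, 2.5, 3.25], [4.0, 9.0, 11.0], [5.0, 0.0, 6.7]])
fast = query_features_4d(feature_grids, query_points)
reference = query_features_per_point(feature_grids, query_points)
```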

zmtttt commented 4 months ago

Hi! I exported the model to ONNX with opset 16, used onnx_graphsurgeon to change the opset directly to 20, and then ran trtexec --onnx xx to build an engine. I hit the same problem: Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addGridSample::1474, condition: input.getDimensions().nbDims == 4) @saikiran321 @SergeySandler @Cyril9227 @yotam

larrygoyeau commented 3 months ago

Hi


Same error here. Did you manage to solve this?

ibaiGorordo commented 1 month ago

I modified the torch model for the case t=1 and reduced all the 5D operations to 4D, among other changes: https://github.com/ibaiGorordo/Tapir-Pytorch-Inference

I also added a script to export the model, but it is very slow in onnxruntime compared to PyTorch (RTX 4080): ~700 ms without refinement and ~20 s with 4 refinement iterations (1000 points, 640x640).

SergeySandler commented 1 month ago

@ibaiGorordo,

"it is very slow when running in onnxruntime compared to Pytorch (RTX4080)"

Do you have the code for inference with ONNX? Do you use CUDA Execution Provider or CPU Execution Provider with ONNX?

ibaiGorordo commented 1 month ago


I added the inference time calculation to the onnx_export.py script.

CPU is faster (screenshot: tapir_onnx_cpu) than CUDA (screenshot: tapir_onnx_cuda).

The slow part seems to be the convolutions in the PIPs mixer block.
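When comparing CPU and CUDA timings like these, excluding warm-up runs and taking a median over several iterations makes the numbers less noisy, since the first runs of a GPU session often include one-off setup. A minimal, generic timing helper; the name `benchmark` and its defaults are illustrative and not part of onnx_export.py:

```python
import statistics
import time


def benchmark(run, warmup=3, iters=10):
    """Time a zero-argument callable and return the median latency in seconds.

    The first `warmup` calls are not timed, so one-off setup costs
    (allocation, kernel selection) do not skew the result.
    """
    for _ in range(warmup):
        run()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Example with a trivial callable; with onnxruntime it would be something like
# benchmark(lambda: session.run(None, feeds)), where `session` and `feeds` are yours.
calls = []
median_s = benchmark(lambda: calls.append(1), warmup=2, iters=5)
```

For the PyTorch CUDA side, the timed callable should end with `torch.cuda.synchronize()`, because CUDA kernels launch asynchronously and the wall-clock time would otherwise be underestimated.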

SergeySandler commented 1 month ago

@ibaiGorordo, I have reproduced tapir.onnx, and it is three times slower than PyTorch on a CUDA device. My results on Windows: PyTorch inference takes around 0.1 s on CUDA and 3 s on CPU; ONNX takes 0.3 s with DmlExecutionProvider and 3 s with CPUExecutionProvider.

Here are a couple of hints for Windows that might be useful, especially if your ONNX results on the GPU are worse than on the CPU:

  1. Do not forget to set device_id to your card's ID (0 in my case) in predictor = onnxruntime.InferenceSession(f'{output_dir}/tapir.onnx', providers=['DmlExecutionProvider'], provider_options=[{'device_id': 0}]); otherwise it might use the integrated Intel graphics instead of the NVIDIA card.
  2. DmlExecutionProvider is not available on Windows unless you run pip install onnxruntime-directml.
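Both hints can be folded into a small provider-selection helper: use DirectML with an explicit adapter index when it is installed, and fall back to the CPU otherwise. This is a sketch; `choose_providers` is a hypothetical name, and which `device_id` maps to the NVIDIA card depends on the machine.

```python
def choose_providers(available, device_id=0):
    """Return (providers, provider_options) for onnxruntime.InferenceSession.

    `available` is expected to be onnxruntime.get_available_providers().
    DirectML gets an explicit adapter index; CPU stays as the fallback.
    """
    if "DmlExecutionProvider" in available:
        return (["DmlExecutionProvider", "CPUExecutionProvider"],
                [{"device_id": device_id}, {}])
    # onnxruntime-directml not installed: CPU only.
    return ["CPUExecutionProvider"], [{}]
```

Usage would be `providers, options = choose_providers(onnxruntime.get_available_providers(), device_id=0)` followed by `predictor = onnxruntime.InferenceSession(f'{output_dir}/tapir.onnx', providers=providers, provider_options=options)`; `predictor.get_providers()` then shows which providers the session actually uses.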