asteroid-team / asteroid-filterbanks

Asteroid's filterbanks :rocket:
https://asteroid-team.github.io/
MIT License

script_if_tracing breaks onnxruntime #15

Closed lminer closed 2 years ago

lminer commented 2 years ago

It appears that onnxruntime does not like the script_if_tracing decorator introduced for backwards compatibility. If you remove the decorator from here, everything works fine; with the decorator included, we get a reshape error.

If the decorator is only needed for torch 1.6.0 support, maybe it should only be used for that version of torch?
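For illustration, here is roughly what version-gating the decorator could look like. This is only a sketch: `multishape_conv1d` stands in for whatever helper the decorator actually wraps, its body is made up, and the version cutoff is an assumption.

```python
import torch
import torch.nn.functional as F
from packaging import version


def _maybe_script_if_tracing(fn):
    # Sketch only: wrap with script_if_tracing on older torch versions,
    # leave the function plain otherwise so ONNX export traces a simple graph.
    if version.parse(torch.__version__) < version.parse("1.8"):
        return torch.jit.script_if_tracing(fn)
    return fn


@_maybe_script_if_tracing
def multishape_conv1d(waveform: torch.Tensor, filters: torch.Tensor, stride: int) -> torch.Tensor:
    # Illustrative body: flatten leading dims into a batch, convolve, restore.
    batch_shape = waveform.shape[:-1]
    out = F.conv1d(waveform.reshape(-1, 1, waveform.shape[-1]), filters, stride=stride)
    return out.reshape(batch_shape + out.shape[-2:])
```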

Here's a reproducible example that you can run on colab:

%pip install asteroid_filterbanks
%pip install onnx
%pip install onnxruntime
%pip install torch

from asteroid_filterbanks.enc_dec import Encoder
from asteroid_filterbanks import torch_stft_fb
import numpy as np
import torch
import torch.onnx
import onnxruntime as ort

window = np.hanning(512 + 1)[:-1] ** 0.5
fb = torch_stft_fb.TorchSTFTFB(
    n_filters=512,
    kernel_size=512,
    center=True,
    stride=256,
    window=window
)
encoder = Encoder(fb)

nb_samples = 1
nb_channels = 2
nb_timesteps = 11111
example = torch.rand((nb_samples, nb_channels, nb_timesteps))

out = encoder(example)

torch.onnx.export(
    encoder,
    example,
    "test.onnx",
    export_params=True,
    opset_version=16,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    verbose=True,
)

ort_sess = ort.InferenceSession("test.onnx")
outputs = ort_sess.run(None, {'input': example.numpy()})

Produces:

---------------------------------------------------------------------------

RuntimeException                          Traceback (most recent call last)

<ipython-input-29-acc756eada00> in <module>()
     37 
     38 ort_sess = ort.InferenceSession("test.onnx")
---> 39 outputs = ort_sess.run(None, {'input': example.numpy()})

/usr/local/lib/python3.7/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    198             output_names = [output.name for output in self._outputs_meta]
    199         try:
--> 200             return self._sess.run(output_names, input_feed, run_options)
    201         except C.EPFail as err:
    202             if self._enable_fallback:

RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_115' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,2,44}, requested shape:{1,1,514,44}
mpariente commented 2 years ago

Thanks for the issue @lminer

I don't think we keep this for backward compatibility reasons, but to be able to export filterbanks using torch.jit.trace.

Do the tests run if you remove the decorator?
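For context, torch.jit.script_if_tracing only has an effect when the decorated function is called during torch.jit.trace: the function is scripted at that point, so shape-dependent control flow stays dynamic instead of being frozen to the example input. A tiny standalone illustration (not the asteroid code):

```python
import torch
import torch.nn.functional as F


@torch.jit.script_if_tracing
def pad_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # The branch depends on the input length; because the function is
    # scripted during tracing, it keeps working for other lengths too.
    remainder = x.shape[-1] % multiple
    if remainder != 0:
        x = F.pad(x, [0, multiple - remainder])
    return x


class Wrapper(torch.nn.Module):
    def forward(self, x):
        return pad_to_multiple(x, 256)


traced = torch.jit.trace(Wrapper(), torch.rand(1, 2, 1000))
print(traced(torch.rand(1, 2, 999)).shape)  # still padded correctly: (1, 2, 1024)
```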

lminer commented 2 years ago

I haven't been able to run the tests, but I have been able to run torch.jit.trace without an issue.
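For reference, the tracing call that works is just the standard one, reusing `encoder` and `example` from the repro script above:

```python
import torch

# Reusing `encoder` and `example` from the repro script; tracing succeeds
# and the traced module matches the eager output.
traced_encoder = torch.jit.trace(encoder, example)
assert torch.allclose(traced_encoder(example), encoder(example))
```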

mpariente commented 2 years ago

So, the tests are failing without it. We want the encoder to work in the same way for all input dimensions; that's why we use script_if_tracing.

What do you suggest doing?

lminer commented 2 years ago

I'm not entirely clear on what script_if_tracing does and how it helps with shapes. What if we put a flag in the encoder constructor that defaults to true but can be disabled?
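A rough sketch of that idea (the class, flag name, and internals below are hypothetical, not the actual asteroid_filterbanks Encoder):

```python
import torch
import torch.nn.functional as F


@torch.jit.script_if_tracing
def _scripted_conv1d(waveform: torch.Tensor, filters: torch.Tensor, stride: int) -> torch.Tensor:
    return F.conv1d(waveform, filters, stride=stride)


class SketchEncoder(torch.nn.Module):
    # Hypothetical flag-based encoder, not the real asteroid_filterbanks API.
    def __init__(self, filters: torch.Tensor, stride: int, script_for_tracing: bool = True):
        super().__init__()
        self.register_buffer("filters", filters)
        self.stride = stride
        # True keeps the current behaviour; set False before torch.onnx.export
        # so the conv is traced as a plain op instead of a scripted subgraph.
        self.script_for_tracing = script_for_tracing

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        if self.script_for_tracing:
            return _scripted_conv1d(waveform, self.filters, self.stride)
        return F.conv1d(waveform, self.filters, stride=self.stride)
```

Exporting with the flag set to false would sidestep the scripted subgraph entirely; whether the tests that rely on arbitrary input dimensions still pass in that mode is a separate question.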
