asteroid-team / asteroid-filterbanks

Asteroid's filterbanks :rocket:
https://asteroid-team.github.io/
MIT License

[src] Allow disabling script_if_tracing for ONNX export #16

Closed mpariente closed 2 years ago

mpariente commented 2 years ago

Fixes #15

However, the exported ONNX model can only process inputs with the same shape as the example that was passed to it for the conversion.

Changing the original script to the following solves the problem:


from asteroid_filterbanks.enc_dec import Encoder
from asteroid_filterbanks import torch_stft_fb
import numpy as np
import torch
import torch.onnx
import onnxruntime as ort
from asteroid_filterbanks.scripting import disable_script_if_tracing

# Disable script_if_tracing before building the model so it can be traced for ONNX export.
disable_script_if_tracing()

window = np.hanning(512 + 1)[:-1] ** 0.5
fb = torch_stft_fb.TorchSTFTFB(
    n_filters=512,
    kernel_size=512,
    center=True,
    stride=256,
    window=window
)
encoder = Encoder(fb)

nb_samples = 1
nb_channels = 2
nb_timesteps = 11111
example = torch.rand((nb_samples, nb_channels, nb_timesteps))

out = encoder(example)

torch.onnx.export(
    encoder,
    example,
    "test.onnx",
    export_params=True,
    opset_version=16,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    verbose=True,
)

ort_sess = ort.InferenceSession("test.onnx")
outputs = ort_sess.run(None, {'input': example.numpy()})

@lminer, would you like to make a PR with consistency tests with ONNX exports?
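
For reference, such a consistency test could start from something like the sketch below. It only compares two output arrays numerically. The helper names `max_abs_diff` and `assert_export_consistent` and the default tolerance are hypothetical, not part of asteroid-filterbanks, and the right tolerance for STFT outputs would need to be checked.

```python
import numpy as np


def max_abs_diff(reference, candidate):
    """Return the largest element-wise absolute difference between two arrays."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    return float(np.max(np.abs(reference - candidate)))


def assert_export_consistent(reference, candidate, atol=1e-4):
    """Raise AssertionError if the outputs differ by more than `atol`.

    `atol` is an assumed default, not a value validated on real exports.
    """
    diff = max_abs_diff(reference, candidate)
    assert diff <= atol, f"outputs differ by {diff:.3e} (atol={atol:.1e})"
    return diff
```

With the script above, this would presumably be called as `assert_export_consistent(out.detach().numpy(), outputs[0])`, where `out` is the PyTorch encoder output and `outputs[0]` is the first ONNX Runtime output.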