asteroid-team / asteroid-filterbanks

Asteroid's filterbanks :rocket:
https://asteroid-team.github.io/
MIT License

Support ONNX export of selected filterbank modules #10

Open faroit opened 3 years ago

faroit commented 3 years ago

One of the benefits of 1D-conv-based filterbanks is that they can be exported for deployment more easily.

Testing TorchSTFTFB reveals that ONNX export doesn't currently work, and the error message doesn't make it clear where the failure stems from.

Example: exporting the traced Encoder module to ONNX:

    import torch.onnx
    from asteroid_filterbanks.enc_dec import Encoder
    from asteroid_filterbanks import torch_stft_fb

    nb_samples = 1
    nb_channels = 2
    nb_timesteps = 11111

    example = torch.rand((nb_samples, nb_channels, nb_timesteps))

    fb = torch_stft_fb.TorchSTFTFB(n_filters=512, kernel_size=512)
    enc = Encoder(fb)
    torch_out = enc(example)
    # Export the model
    torch.onnx.export(
        enc,
        example,
        "umx.onnx",
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        verbose=True
    )

This results in:

Traceback (most recent call last):
  File "onnx.py", line 28, in <module>
    verbose=False
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 230, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 91, in export
    use_external_data_format=use_external_data_format)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 639, in _export
    dynamic_axes=dynamic_axes)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 421, in _model_to_graph
    dynamic_axes=dynamic_axes, input_names=input_names)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 203, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 968, in _run_symbolic_function
    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 979, in _run_symbolic_function
    operator_export_type)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 888, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "/Users/faro/repositories/open-unmix-pytorch/env-cpu/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 111, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator prim_Uninitialized to ONNX opset version 10 is not supported. Please open a bug to request ONNX export support for the missing operator.

mpariente commented 3 years ago

Thanks for the issue. Did you try the other filterbanks?

faroit commented 3 years ago

Yes, not much luck:

mpariente commented 3 years ago

OK. Not sure we have the bandwidth to sort this out without outside help; if you'd like to work on this, please do!

faroit commented 3 years ago

> Not sure we have the bandwidth to sort this out without outside help; if you'd like to work on this, please do!

Me neither. I was just excited to get a filterbank frontend working with onnxruntime for a desktop app demo. For now, we'd better watch PyTorch upstream and re-check once the error messages become more precise. I suggest leaving this open as a reminder.

jonashaag commented 3 years ago

I recently had this problem too and simply switched to a SciPy-based STFT, i.e., keeping the filterbank out of the PyTorch model.

mpariente commented 3 years ago

> I recently had this problem too and simply switched to a SciPy-based STFT, i.e., keeping the filterbank out of the PyTorch model.

Then the problem arises when you want to have the iSTFT inside the network, for time-domain losses, right?

mpariente commented 3 years ago

Maybe it's actually not that complicated to fix this. We could make a new class, as simple as possible, that just applies the conv or transposed conv with fixed filters, and add an export method to Encoder and Decoder.
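The proposal above could be sketched as two tiny modules that only hold fixed filters and call the functional conv ops, so that nothing but `conv1d` / `conv_transpose1d` ends up in the traced graph. `FixedConvEncoder` and `FixedConvDecoder` are hypothetical names, not part of asteroid-filterbanks, and this sketch has not been checked against the ONNX exporter versions discussed in this thread.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FixedConvEncoder(nn.Module):
    """Hypothetical export-only encoder: a strided conv1d with frozen filters."""

    def __init__(self, filters: torch.Tensor, stride: int):
        super().__init__()
        # filters: (n_filters, 1, kernel_size). A buffer is saved in the
        # state_dict and traced as a constant, but is never trained.
        self.register_buffer("filters", filters)
        self.stride = stride

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: (batch, 1, time) -> (batch, n_filters, n_frames)
        return F.conv1d(wav, self.filters, stride=self.stride)


class FixedConvDecoder(nn.Module):
    """Hypothetical export-only decoder: the matching transposed conv1d."""

    def __init__(self, filters: torch.Tensor, stride: int):
        super().__init__()
        # Same (n_filters, 1, kernel_size) layout; conv_transpose1d reads it as
        # (in_channels, out_channels, kernel_size).
        self.register_buffer("filters", filters)
        self.stride = stride

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # spec: (batch, n_filters, n_frames) -> (batch, 1, time)
        return F.conv_transpose1d(spec, self.filters, stride=self.stride)
```

With the same filters, this decoder is the adjoint of the encoder, not its inverse; a perfect-reconstruction pair (e.g. an STFT/iSTFT) would need synthesis filters derived for that purpose.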

mpariente commented 3 years ago

@faroit, have you tried with a simple nn.Conv1d?

jonashaag commented 3 years ago

> Then the problem arises when you want to have the iSTFT inside the network, for time-domain losses, right?

Yes, but if you put the loss outside the model, there's no problem. I don't need the loss in the ONNX export.

faroit commented 3 years ago

> I recently had this problem too and simply switched to a SciPy-based STFT, i.e., keeping the filterbank out of the PyTorch model.

How did you export the SciPy STFT to ONNX?

jonashaag commented 3 years ago

I didn't; SciPy has to be installed by whoever uses the exported ONNX model.

faroit commented 3 years ago

> I didn't; SciPy has to be installed by whoever uses the exported ONNX model.

@jonashaag Okay, that wasn't the use case I had in mind. Having the full end-to-end model in ONNX gives you the flexibility to perform audio processing, e.g., from Node.js/Electron, without having to reimplement the pre-/post-processing pipeline in JS. Python would not be an option in that case.

faroit commented 3 years ago

> @faroit, have you tried with a simple nn.Conv1d?

Yes, that works. There are also other STFT variants that can be exported. I guess it's a trivial thing, but we probably won't be able to track it down without much effort until error tracing improves.

faroit commented 3 years ago

nnAudio's implementation seems to be ONNX-exportable. We might want to check the differences: https://github.com/KinWaiCheuk/nnAudio/issues/23#issuecomment-768954731

KinWaiCheuk commented 3 years ago

> nnAudio's implementation seems to be ONNX-exportable. We might want to check the differences: KinWaiCheuk/nnAudio#23 (comment)

I used two nn.Conv1d layers to write my STFT class in nnAudio: one Conv1d for the real part (cosine kernels) and another for the imaginary part (sine kernels). How do you implement the STFT in asteroid-filterbanks?
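The two-Conv1d approach described above can be sketched as follows. `ConvSTFT` is a hypothetical name, not nnAudio's actual class; the sketch assumes a periodic Hann window, no centering or padding, and a one-sided spectrum, which is what lets it be compared directly with `torch.stft(..., center=False)`.

```python
import math

import torch
import torch.nn as nn


class ConvSTFT(nn.Module):
    """Hypothetical STFT front end built from two fixed nn.Conv1d layers."""

    def __init__(self, n_fft: int = 512, hop_length: int = 256):
        super().__init__()
        n_freqs = n_fft // 2 + 1  # one-sided spectrum
        n = torch.arange(n_fft)
        k = torch.arange(n_freqs).unsqueeze(1)
        # Reduce k * n modulo n_fft in integer arithmetic so cos/sin see
        # small, accurate float32 arguments.
        phase = 2 * math.pi * ((k * n) % n_fft) / n_fft  # (n_freqs, n_fft)
        window = torch.hann_window(n_fft)  # periodic Hann
        self.real = nn.Conv1d(1, n_freqs, n_fft, stride=hop_length, bias=False)
        self.imag = nn.Conv1d(1, n_freqs, n_fft, stride=hop_length, bias=False)
        with torch.no_grad():
            # Real part: windowed cosines; imaginary part: negated windowed
            # sines, matching the exp(-j * phase) convention of torch.stft.
            self.real.weight.copy_((window * torch.cos(phase)).unsqueeze(1))
            self.imag.weight.copy_((-window * torch.sin(phase)).unsqueeze(1))
        for p in self.parameters():
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor):
        # x: (batch, time) -> two tensors of shape (batch, n_freqs, n_frames)
        x = x.unsqueeze(1)
        return self.real(x), self.imag(x)
```

Because both layers are plain convolutions, the exported graph should contain only Conv nodes; the cost is storing `n_fft // 2 + 1` full-length kernels twice instead of calling an FFT.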

mpariente commented 3 years ago

We use the functional API; have a look at Encoder.forward.
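A small illustration of why the functional route and nn.Conv1d should compute the same thing when given the same filters. This is not asteroid's actual `Encoder.forward`, only a sketch assuming the filters live in a plain tensor; if the two outputs match, any difference in exportability would have to come from the code around the convolution rather than from the convolution itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
filters = torch.randn(16, 1, 8)  # (n_filters, in_channels, kernel_size)
x = torch.randn(2, 1, 100)       # (batch, channels, time)

# Functional style: the filters are a plain tensor passed at call time.
y_functional = F.conv1d(x, filters, stride=4)

# Module style (as in nnAudio): the same filters stored as nn.Conv1d weights.
conv = nn.Conv1d(1, 16, kernel_size=8, stride=4, bias=False)
with torch.no_grad():
    conv.weight.copy_(filters)
y_module = conv(x)
```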

mpariente commented 3 years ago

I can convert the free, param, and STFT filterbanks with PyTorch nightly, but not the torch_stft version yet. Regarding the analytic filterbank, there is some hope here.

mpariente commented 3 years ago

I checked all the hooks; the only one that doesn't pass is pre_analysis, which does the padding. This is the function.
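The exact body of `pre_analysis` isn't shown here. Assuming it mirrors `torch.stft`'s `center=True` behaviour (reflect-padding by `n_fft // 2` on both sides), a shape-static version can be written with a single `F.pad` call and no data-dependent control flow. `center_pad` is a hypothetical helper, and the check below only confirms it matches `torch.stft`'s own centering, not that it exports.

```python
import torch
import torch.nn.functional as F


def center_pad(x: torch.Tensor, n_fft: int) -> torch.Tensor:
    """Hypothetical helper: reflect-pad (batch, time) by n_fft // 2 per side."""
    pad = n_fft // 2
    # Reflect padding of the last axis is applied to a (batch, channels, time)
    # tensor, so add a channel axis and remove it again afterwards.
    return F.pad(x.unsqueeze(1), (pad, pad), mode="reflect").squeeze(1)
```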

Now that this is much narrower, Fabian, would you like to have a look?

faroit commented 3 years ago

@mpariente Next issue: to address the decoder, torch.fold is not supported by the ONNX exporter...

https://github.com/asteroid-team/asteroid-filterbanks/blob/351029209556bdccf6d27edbebfc6c8988c663ff/asteroid_filterbanks/torch_stft_fb.py#L172-L176

What's a good replacement?
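One possible fold-free replacement for the overlap-add step: `conv_transpose1d` with an identity kernel places sample `i` of frame `m` at output index `m * hop + i`, which is what `fold` does for 1D frames. `overlap_add` is a hypothetical helper; the iSTFT window normalization is left out, and ONNX export of it has not been verified here, although ConvTranspose is a standard ONNX operator.

```python
import torch
import torch.nn.functional as F


def overlap_add(frames: torch.Tensor, hop: int) -> torch.Tensor:
    """Hypothetical helper: overlap-add (batch, frame_len, n_frames) frames.

    Returns a (batch, 1, (n_frames - 1) * hop + frame_len) signal.
    """
    frame_len = frames.shape[1]
    eye = torch.eye(frame_len, dtype=frames.dtype, device=frames.device)
    # conv_transpose1d weight layout is (in_channels, out_channels, kernel_size);
    # weight[i, 0, j] = 1 if i == j routes frame sample i to output offset i.
    return F.conv_transpose1d(frames, eye.unsqueeze(1), stride=hop)
```

The identity kernel costs `frame_len ** 2` weights (about 262k for 512-sample frames), which is a trade-off for avoiding the unsupported op.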

mpariente commented 3 years ago

Are you sure about that? Replacing it will be very cumbersome.

faroit commented 3 years ago

> Are you sure about that? Replacing it will be very cumbersome.

Seems so: https://github.com/pytorch/pytorch/issues/41423

mpariente commented 3 years ago

Did you try with the functional API? It's probably the same, but worth a try.

On Fri, 12 Mar 2021 at 15:08, Fabian-Robert Stöter @.***> wrote:

> Are you sure about that? Replacing it will be very cumbersome

> seems so: https://github.com/pytorch/pytorch/issues/41423


faroit commented 2 years ago

@mpariente @jonashaag It works with torch 1.12 and opset > 11 now!

Should I still add some tests?

mpariente commented 2 years ago

Cool, thanks Fabian!

If you could, that'd be great!