asteroid-team / asteroid-filterbanks

Asteroid's filterbanks :rocket:
https://asteroid-team.github.io/
MIT License

Make torch_stft_fb.py encoder onnxable #11

Closed: mpariente closed this pull request 3 years ago

mpariente commented 3 years ago

Ref #10. The fix was quite simple, but it always takes me a while to find scriptable and onnxable tricks.

mpariente commented 3 years ago

I can't make the Decoder onnxable, though. If you want to give it a try, @faroit, feel free!

faroit commented 3 years ago

> Ref #10. The fix was quite simple, but it always takes me a while to find scriptable and onnxable tricks.

Cool, how did you find out that this was causing problems?

faroit commented 3 years ago

@mpariente, shall we add separate ONNX tests for the encoder and the decoder, so that at least something passes in this PR? I can then work on the decoder.

mpariente commented 3 years ago

> Cool, how did you find out that this was causing problems?

By elimination.

> @mpariente, shall we add separate ONNX tests for the encoder and the decoder, so that at least something passes in this PR? I can then work on the decoder.

Yes, it would be nice to add tests. I think they only work with the PyTorch nightly, though; I'd have to check. Do you want to implement them?

faroit commented 3 years ago

@mpariente, good news: with your fix, torch 1.8.0, and setting the opset to 11, both the decoder and the encoder pass the ONNX export. We could go lower than opset 11 (which would allow better conversion to TensorFlow) if the .squeeze(0) were replaced...

UserWarning: This model contains a squeeze operation on dimension 0. If the model is intended to be used with dynamic input shapes, please use opset version 11 to export the model.
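One possible replacement for the .squeeze(0), sketched under an assumption: if the leading dimension is guaranteed to be 1 at that point, plain indexing x[0] returns the same tensor. To our understanding the TorchScript exporter maps this indexing (aten::select) to the ONNX Gather op, which does not need opset 11, but it is untested here whether this alone lets the whole model export at a lower opset or convert better to TensorFlow. drop_leading_singleton is a hypothetical helper name.

```python
import torch


def drop_leading_singleton(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: remove a leading dimension known to have size 1.

    Unlike ``x.squeeze(0)``, indexing always drops dim 0 (keeping only the
    first slice), so it is only equivalent when ``x.shape[0] == 1``.
    """
    # Assumption: aten::select is exported as ONNX Gather, avoiding the
    # opset-dependent Squeeze that triggers the warning above.
    return x[0]
```

With a size-1 leading dimension the result matches x.squeeze(0). With a larger first dimension, squeeze(0) is a no-op while x[0] keeps only the first slice, so the swap is only safe where the leading dimension is known to be 1.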

In any case, I think the separate tests should be added in another PR; I can do that. So I would suggest merging this one first.

mpariente commented 3 years ago

I'd love tests, thanks!