fbcotter / pytorch_wavelets

PyTorch implementation of 2D Discrete Wavelet (DWT) and Dual-Tree Complex Wavelet Transforms (DTCWT) and a DTCWT-based ScatterNet

export to onnx #30

Open carr123 opened 3 years ago

carr123 commented 3 years ago

I built a neural network that uses ScatLayer as one of its layers. When I export the model from PyTorch to ONNX format, the following error occurs:

```
pytorch_wavelets\utils.py", line 162, in reflect
    out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
TypeError: '>=' not supported between instances of 'numpy.ndarray' and 'Tensor'
```

fbcotter commented 3 years ago

Ah, bummer. Yeah, I remember I had to do some funky things to get the padding to work the right way. I know it's several months later, but what command did you use to export to ONNX? It might be possible to do the padding with torch tensors instead, so that the export works.
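A minimal sketch of what torch-only padding might look like, assuming the failing `reflect` helper folds a range of indices back and forth between `minx` and `maxx`, as the `np.where` line in the traceback suggests. `reflect_torch` and `symm_pad_last_dim` are hypothetical names, not part of the pytorch_wavelets API; the idea is simply to replace each numpy call with its torch equivalent (`torch.fmod`, `torch.where`) so that no numpy array is ever compared against a traced Tensor.

```python
import torch


def reflect_torch(x: torch.Tensor, minx: float, maxx: float) -> torch.Tensor:
    """Fold the values of x back and forth between minx and maxx.

    Hypothetical torch-only counterpart of the numpy `reflect` helper named in
    the traceback: every step is a torch op, so no numpy array is ever
    compared against a Tensor.
    """
    rng = maxx - minx
    rng_by_2 = 2 * rng
    # Distance from minx, wrapped into one full period of length 2 * rng.
    mod = torch.fmod(x - minx, rng_by_2)
    # fmod keeps the sign of the dividend, so shift negatives into [0, 2 * rng).
    normed_mod = torch.where(mod < 0, mod + rng_by_2, mod)
    # The second half of each period runs backwards; this mirrors the
    # np.where line that raised the TypeError.
    return torch.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx


def symm_pad_last_dim(x: torch.Tensor, pad: int) -> torch.Tensor:
    """Half-sample symmetric padding of the last dimension, e.g. abcd -> baabcddc."""
    n = x.shape[-1]
    idx = torch.arange(-pad, n + pad)
    # Reflecting about -0.5 and n - 0.5 repeats the edge samples.
    idx = reflect_torch(idx, -0.5, n - 0.5).round().long()
    # Gather along the last dimension with the computed index tensor.
    return x[..., idx]


x = torch.arange(4.0).view(1, 4)
print(symm_pad_last_dim(x, 2))
```

Padding by indexing with a computed index tensor keeps the whole step inside the traced graph instead of round-tripping through numpy; whether a particular PyTorch version and opset export it cleanly is untested here.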