PyWavelets / pywt

PyWavelets - Wavelet Transforms in Python
http://pywavelets.readthedocs.org
MIT License

Using dwt with Tensorflow and Keras #557

Closed rohang9 closed 2 years ago

rohang9 commented 4 years ago

So I'm working on a project that uses a 2-dimensional Discrete Wavelet Transform (DWT) with a CNN. The idea is to remove the MaxPool layer, replace it with a DWT, and concatenate all the different outputs (the low-pass and high-pass branches) into different channels. But it seems pywt.dwt accepts only NumPy arrays as arguments, not TensorFlow tensors. I would like a DWT function that works on the tensors of the preceding CNN layers in the computation graph. My code looks roughly like this (tf.numpy_function() throws an error because of the string argument that specifies the wavelet, e.g. 'haar'):

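```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
from tensorflow.keras.models import Model
import pywt

def model_build(shape):
    X = Input(shape=shape)
    X1 = Conv2D(64, (3, 3), padding='same')(X)
    # pywt expects a NumPy array; X1 is a symbolic Keras tensor, so this step does not work
    XL, XH = pywt.dwt2(X1, 'haar', axes=(1, 2))  # axes=(1, 2): the spatial axes of NHWC data
    # Stack the approximation and the three detail bands along the channel axis
    X2 = tf.concat([XL, XH[0], XH[1], XH[2]], axis=-1)
    X3 = Conv2D(64, (3, 3), padding='same')(X2)
    XL_, XH_ = pywt.dwt2(X3, 'haar', axes=(1, 2))
    X4 = tf.concat([XL_, XH_[0], XH_[1], XH_[2]], axis=-1)
    # Above process repeated more times
    X5 = Flatten()(X4)
    X6 = Dense(1)(X5)
    model = Model(inputs=X, outputs=X6)
    return model
```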
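To make the "DWT instead of MaxPool" idea concrete, here is a minimal pure-NumPy sketch of a single-level 2D Haar DWT on an NHWC batch, with the four sub-bands stacked as channels. It is not pywt code: `haar_dwt2_nhwc` is a hypothetical helper, and it assumes channels-last data, even height and width, and the Haar wavelet only. The output has half the spatial size and four times the channels, so it downsamples like a 2x2 MaxPool but, being orthonormal, keeps all of the information.

```python
import numpy as np

def haar_dwt2_nhwc(x):
    """Hypothetical helper: one level of 2D Haar DWT, sub-bands stacked as channels.

    Assumes x has shape (N, H, W, C) with even H and W.
    Returns an array of shape (N, H // 2, W // 2, 4 * C).
    """
    if x.shape[1] % 2 or x.shape[2] % 2:
        raise ValueError("H and W must be even")
    # The four pixels of every non-overlapping 2x2 block
    a = x[:, 0::2, 0::2, :]  # top-left
    b = x[:, 0::2, 1::2, :]  # top-right
    c = x[:, 1::2, 0::2, :]  # bottom-left
    d = x[:, 1::2, 1::2, :]  # bottom-right
    # Orthonormal Haar filters applied along both spatial axes
    approx = (a + b + c + d) / 2       # low-pass in both directions
    detail_w = (a - b + c - d) / 2     # high-pass across the width
    detail_h = (a + b - c - d) / 2     # high-pass across the height
    detail_diag = (a - b - c + d) / 2  # high-pass in both directions
    return np.concatenate([approx, detail_w, detail_h, detail_diag], axis=-1)
```

For example, an input of shape `(2, 8, 8, 3)` becomes `(2, 4, 4, 12)`. Because the sketch uses only strided slicing, arithmetic, and concatenation, the same approach could in principle be rewritten with TensorFlow ops (e.g. `tf.concat`) so that it stays differentiable inside the graph; that port is not shown here.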

Any way of wrapping the available functions for TensorFlow would really help. A workaround to get the DWT working with TensorFlow tensors would also be helpful. Thanks!

nullkatar commented 4 years ago

Hey @rohang9, did you succeed in doing this? I'm currently also working on that and would appreciate any bit of code! Thanks in advance!

rohang9 commented 4 years ago

Hey @nullkatar, I've found a library that does this. It's not as well developed as pywt, but it was sufficient for the problem I'm working on. Check out the following link:

https://github.com/UiO-CS/tf-wavelets

Cheers!

rgommers commented 2 years ago

pywt indeed only works with NumPy arrays. One can convert a TensorFlow tensor into a NumPy array by calling the numpy() method on the tensor (this requires eager execution). This works for CPU tensors and is zero-copy (i.e., fast). So, to work around this issue, convert to a NumPy array, apply pywt, then convert the result back to a TF tensor.
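The round trip described above might look roughly like this sketch. `dwt2_eager` and `dwt2_graph` are hypothetical helper names, and `axes=(1, 2)` assumes NHWC (channels-last) data. The second variant wraps the same NumPy code in `tf.numpy_function` and binds the wavelet name in a Python closure, so only the tensor is passed as an input; that should sidestep the string-argument error mentioned in the question. Neither variant is differentiable, so gradients will not reach the layers before it.

```python
import numpy as np
import pywt
import tensorflow as tf

def dwt2_eager(t, wavelet="haar"):
    """Hypothetical helper: eager tensor -> NumPy -> pywt.dwt2 -> tensor."""
    # .numpy() only exists on eager tensors, not on symbolic Keras tensors
    cA, (cH, cV, cD) = pywt.dwt2(t.numpy(), wavelet, axes=(1, 2))
    # Stack approximation and detail bands as channels
    return tf.convert_to_tensor(np.concatenate([cA, cH, cV, cD], axis=-1))

def dwt2_graph(t, wavelet="haar"):
    """Hypothetical helper: same transform wrapped in tf.numpy_function."""
    def _np_dwt2(x):
        # The wavelet name comes from the closure, not from a tensor input
        cA, (cH, cV, cD) = pywt.dwt2(x, wavelet, axes=(1, 2))
        return np.concatenate([cA, cH, cV, cD], axis=-1).astype(x.dtype)
    # No gradient is defined for tf.numpy_function
    return tf.numpy_function(_np_dwt2, [t], Tout=t.dtype)
```

In graph mode (e.g. inside a Keras model), the static output shape of `tf.numpy_function` is unknown, so you may need to set it explicitly before feeding the result to later layers.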

There are no plans to make pywt work with TensorFlow, that'd be a lot of work. So I'll close this issue.