v0lta / Jax-Wavelet-Toolbox

Differentiable and GPU-enabled fast wavelet transforms in JAX.
European Union Public License 1.2

Advice on results needed. #5

Closed: thoschm closed this issue 1 year ago.

thoschm commented 1 year ago

Hi,

first of all, let me say that I really appreciate your project, amazing work!

I need some advice on the following code, which (at least for me) produces different results with jaxwt and pywt:

import jax.numpy as jnp
import numpy as np
import jaxwt
import pywt

arr_np = np.arange(16*16).reshape((16, 16)).astype("float32")
arr_jnp = jnp.array(arr_np)

coeffs_pywt = pywt.wavedec2(arr_np, pywt.Wavelet("db4"))
coeffs_jaxwt = jaxwt.wavedec2(arr_jnp, pywt.Wavelet("db4"))

print(coeffs_pywt)
print(coeffs_jaxwt)

It seems that jaxwt performs zero decomposition levels in this case, while pywt performs at least one.
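The level counts can be checked directly. pywt's wavedec2 returns a list of the form [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)], so the number of levels is len(coeffs) - 1. When no level is passed, pywt uses the maximum from pywt.dwt_max_level, which is floor(log2(data_len / (filter_len - 1))) for the shortest transformed axis. The sketch below re-implements that rule in plain Python for illustration only; `pywt_style_max_level` and `num_levels` are hypothetical helpers (not part of pywt or jaxwt), and applying `num_levels` to jaxwt output assumes jaxwt returns a list with the same layout.

```python
import math


def pywt_style_max_level(data_len: int, filter_len: int) -> int:
    """Hypothetical re-implementation of the rule pywt.dwt_max_level uses:
    floor(log2(data_len / (filter_len - 1))), or 0 for very short signals."""
    if filter_len < 2:
        raise ValueError("filter_len must be at least 2")
    if data_len < filter_len - 1:
        return 0
    return int(math.floor(math.log2(data_len / (filter_len - 1))))


def num_levels(coeffs) -> int:
    """Levels in a wavedec2-style list [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]."""
    return len(coeffs) - 1


# db4 has a decomposition filter of length 8; the example image is 16 x 16.
print(pywt_style_max_level(16, 8))  # 1
```

With the code from the question, `num_levels(coeffs_pywt)` and `num_levels(coeffs_jaxwt)` report the two counts side by side. For a 16 x 16 input and db4, the rule allows only a single level, so even pywt does not go deep here.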

Best, and thanks for any advice, Thomas

v0lta commented 1 year ago

Dear @thoschm, the toolbox is meant for machine learning use cases, where the first dimension is typically a batch dimension. The documentation at https://jax-wavelet-toolbox.readthedocs.io/en/latest/jaxwt.html#jaxwt.conv_fwt_2d.wavedec2 states that the input shape should be [batch, height, width]. The following code snippet should work as expected:

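import jax
import jax.numpy as jnp
import numpy as np
import jaxwt
import pywt
import matplotlib.pyplot as plt

# JAX defaults to 32-bit floats; enable 64-bit so the float64 input is not downcast.
jax.config.update("jax_enable_x64", True)

arr_np = np.arange(16*16).reshape((16, 16)).astype("float64")
# jaxwt expects [batch, height, width], so add a leading batch dimension of size 1.
arr_jnp = jnp.expand_dims(jnp.array(arr_np), 0)

# Use the same boundary mode in both libraries so the coefficients are comparable.
coeffs_pywt = pywt.wavedec2(arr_np, pywt.Wavelet("db4"), mode='zero')
coeffs_jaxwt = jaxwt.wavedec2(arr_jnp, pywt.Wavelet("db4"), mode='zero')

# Concatenate the flattened coefficients and overlay both curves; matching results coincide.
plt.plot(np.concatenate(jaxwt.utils.flatten_2d_coeff_lst(coeffs_pywt)))
plt.plot(np.concatenate(jaxwt.utils.flatten_2d_coeff_lst(coeffs_jaxwt)), ".")
plt.show()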
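Comparing the overlaid plots by eye works, but a numeric check is easier to automate. The sketch below is an illustration under stated assumptions, not part of jaxwt: `flatten_coeffs` and `coeffs_match` are hypothetical helpers written with plain NumPy (standing in for jaxwt.utils.flatten_2d_coeff_lst), and they assume the [cA, (cH, cV, cD), ...] list layout, squeezing away a leading batch axis of size 1 so pywt output and batched jaxwt output line up.

```python
import numpy as np


def flatten_coeffs(coeffs) -> np.ndarray:
    """Hypothetical helper: flatten a wavedec2-style list [cA, (cH, cV, cD), ...]
    into one 1-D NumPy array, dropping a leading batch axis of size 1."""
    parts = [np.asarray(coeffs[0])]
    for detail in coeffs[1:]:
        parts.extend(np.asarray(band) for band in detail)
    # Remove a size-1 batch axis so [1, h, w] and [h, w] bands compare equally.
    parts = [p[0] if p.ndim == 3 and p.shape[0] == 1 else p for p in parts]
    return np.concatenate([p.ravel() for p in parts])


def coeffs_match(coeffs_a, coeffs_b, rtol: float = 1e-5, atol: float = 1e-4) -> bool:
    """True if both coefficient lists have the same total size and close values."""
    a, b = flatten_coeffs(coeffs_a), flatten_coeffs(coeffs_b)
    # Check shapes first: a different level count or padding mode changes the sizes.
    return a.shape == b.shape and bool(np.allclose(a, b, rtol=rtol, atol=atol))
```

With the snippet above, `coeffs_match(coeffs_pywt, coeffs_jaxwt)` replaces the visual comparison. Both calls must use the same boundary mode (the snippet passes mode='zero' to each); otherwise the coefficient shapes or values differ and the check fails. If 64-bit mode is off, the tolerances may need loosening for float32 results.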

I agree that your example should also work. I will add better input shape checks to the 0.0.8 TODO list.
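One way such a shape check could look is a thin wrapper that accepts both [height, width] and [batch, height, width] inputs. This is a minimal sketch, not the toolbox's planned implementation: `wavedec2_with_optional_batch` is a hypothetical helper, and it assumes the wrapped function takes the data as its first argument and returns a [cA, (cH, cV, cD), ...] list whose arrays keep the batch axis first.

```python
import numpy as np


def wavedec2_with_optional_batch(x, wavedec2_fn, *args, **kwargs):
    """Hypothetical wrapper: accept [height, width] or [batch, height, width].

    A 2-D input gets a batch axis of size 1 before calling `wavedec2_fn`,
    and that axis is stripped from every returned coefficient afterwards.
    """
    if x.ndim not in (2, 3):
        raise ValueError(
            f"expected [height, width] or [batch, height, width], got shape {x.shape}"
        )
    unbatched = x.ndim == 2
    if unbatched:
        x = x[None, ...]  # add the leading batch dimension
    coeffs = wavedec2_fn(x, *args, **kwargs)
    if not unbatched:
        return coeffs
    # Remove the size-1 batch axis from the approximation and every detail band.
    return [coeffs[0][0]] + [tuple(band[0] for band in detail) for detail in coeffs[1:]]


def fake_wavedec2(x):
    """Stand-in with the right output structure: one level of 2x downsampling."""
    half = x[:, ::2, ::2]
    return [half, (half, half, half)]


out = wavedec2_with_optional_batch(np.zeros((16, 16)), fake_wavedec2)
print(out[0].shape)  # (8, 8)
```

Following the call pattern in the snippet above, the wrapper would be used as `wavedec2_with_optional_batch(jnp.array(arr_np), jaxwt.wavedec2, pywt.Wavelet("db4"), mode='zero')`. Since it only uses `.ndim` and indexing, it works the same on NumPy and JAX arrays.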

All the best, Moritz

thoschm commented 1 year ago

Dear Moritz,

thanks for your quick reply and for clearing this up. This is perfect; batch support is exactly what I need.

Have a great day, Thomas

PS: Thanks for the example as well.