Dear @thoschm, the toolbox is meant for machine learning use cases, where the first dimension is typically a batch dimension. The documentation at https://jax-wavelet-toolbox.readthedocs.io/en/latest/jaxwt.html#jaxwt.conv_fwt_2d.wavedec2 states that the input shape should be [batch, height, width]. The following code snippet should work as expected:
import jax.numpy as jnp
import numpy as np
import jaxwt
import pywt
import matplotlib.pyplot as plt
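# A single 16x16 test image as a NumPy array.
arr_np = np.arange(16*16).reshape((16, 16)).astype("float64")
# jaxwt expects [batch, height, width], so prepend a batch axis of size 1.
# Note: unless jax_enable_x64 is set, JAX stores this array as float32.
arr_jnp = jnp.expand_dims(jnp.array(arr_np), 0)
# pywt works on the 2D array directly; jaxwt works on the batched version.
coeffs_pywt = pywt.wavedec2(arr_np, pywt.Wavelet("db4"), mode='zero')
coeffs_jaxwt = jaxwt.wavedec2(arr_jnp, pywt.Wavelet("db4"), mode='zero')
# Flatten both coefficient lists and overlay them: the dots should sit on the line.
plt.plot(np.concatenate(jaxwt.utils.flatten_2d_coeff_lst(coeffs_pywt)))
plt.plot(np.concatenate(jaxwt.utils.flatten_2d_coeff_lst(coeffs_jaxwt)), ".")
plt.show()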
I agree your example should also work. I will add better shape checks to the 0.0.8 TODO list.
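A minimal sketch of the kind of shape check mentioned above, written as a hypothetical helper `ensure_batched_2d` (it is not part of jaxwt). It accepts either a single [height, width] image or a [batch, height, width] stack, always returns the batched form, and rejects anything else with a clear error. Since it only relies on `.shape` and `[None, ...]` indexing, it should work for NumPy and JAX arrays alike.

```python
import numpy as np


def ensure_batched_2d(data):
    """Hypothetical helper (not part of jaxwt): return data as [batch, height, width].

    A 2D input is treated as a single image and gets a batch axis of size 1;
    a 3D input is assumed to already be batched and is returned unchanged.
    """
    ndim = len(data.shape)
    if ndim == 2:
        # [height, width] -> [1, height, width]; indexing keeps the array type.
        return data[None, ...]
    if ndim == 3:
        return data
    raise ValueError(
        "expected [height, width] or [batch, height, width], "
        f"got shape {tuple(data.shape)}"
    )


image = np.arange(16 * 16, dtype=np.float64).reshape((16, 16))
print(ensure_batched_2d(image).shape)  # (1, 16, 16)
```

Failing loudly on unexpected ranks, rather than silently reinterpreting an axis as the batch dimension, is what makes a mismatch like the one in this issue easy to spot.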
All the best, Moritz
Dear Moritz,
thanks for your quick reply and for clearing this up. This is perfect: batch support is exactly what I need.
Have a great day, Thomas
PS: Thanks for the example as well.
Hi,
first of all, let me say that I really appreciate your project, amazing work!
I need some advice on the following code, which (at least for me) produces different results in jaxwt and pywt:
It seems that jaxwt computes 0 decomposition levels in this case, while pywt computes more.
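To check how many levels each library actually ran, here is a sketch under two assumptions. First, both `wavedec2` functions return a pywt-style list `[cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]`, so the level count is `len(coeffs) - 1`. Second, pywt's default level comes from `pywt.dwt_max_level`, documented as `floor(log2(data_len / (filter_len - 1)))`, with the smallest per-axis value used for 2D input. The helpers `default_max_level`, `default_max_level_2d` and `count_levels` are hypothetical and re-implement that formula in plain Python rather than calling pywt.

```python
def default_max_level(data_len: int, filter_len: int) -> int:
    """Plain-Python version of the formula documented for pywt.dwt_max_level:
    floor(log2(data_len / (filter_len - 1)))."""
    if filter_len < 2:
        raise ValueError("filter_len must be at least 2")
    quotient = data_len // (filter_len - 1)
    if quotient < 1:
        # data_len < filter_len - 1: too short for even one level.
        return 0
    # floor(log2(q)) for a positive integer q, computed exactly without floats.
    return quotient.bit_length() - 1


def default_max_level_2d(height: int, width: int, filter_len: int) -> int:
    """Assumed 2D rule: the smallest of the per-axis maximum levels."""
    return min(default_max_level(height, filter_len),
               default_max_level(width, filter_len))


def count_levels(coeffs) -> int:
    """Levels in a pywt-style wavedec2 result: one approximation plus one tuple per level."""
    return len(coeffs) - 1


# db4 has a filter length of 8.
print(default_max_level_2d(16, 16, 8))    # 1
print(default_max_level_2d(256, 256, 8))  # 5
```

Calling `count_levels` on the output of both `wavedec2` calls and comparing the results against `default_max_level_2d` for the input's height and width should show which library deviates from the expected default.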
Best, and thanks for any advice, Thomas