v0lta / Jax-Wavelet-Toolbox

Differentiable and gpu enabled fast wavelet transforms in JAX.
European Union Public License 1.2

dtype of input array to wavedec2 #6

Closed: thoschm closed this issue 1 year ago

thoschm commented 1 year ago

Hi @v0lta ,

the following call raises an exception for me:

import jax.numpy as jnp, jaxwt, pywt
jaxwt.wavedec2(data=jnp.zeros((10, 16, 16), dtype="float32"), wavelet=pywt.Wavelet("haar"))

Exception:

TypeError: lax.conv_general_dilated requires arguments to have the same dtypes, got float32, float64.

Are there plans to support dtypes other than float64?

Best, Thomas
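The error means the two operands passed to lax.conv_general_dilated had different dtypes: the float32 input and, presumably, float64 filter coefficients. Below is a minimal sketch of the usual remedy, which is to cast the filters to the input's dtype before convolving. It is not jaxwt's code. The helper `haar_analysis_1d` is hypothetical, and turning on 64-bit mode via `jax_enable_x64` is an assumption made here only so that a float64 filter can exist at all.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

# Assumption for illustration: 64-bit mode is on, so a NumPy float64 filter
# stays float64 when it is converted to a JAX array.
jax.config.update("jax_enable_x64", True)


def haar_analysis_1d(data):
    """Single-level 1D Haar analysis for signals shaped (batch, length).

    Hypothetical helper, not jaxwt's implementation. Assumes an even length.
    """
    s = 1.0 / np.sqrt(2.0)
    # Low-pass and high-pass filters stacked as (out_channels, in_channels, width).
    # NumPy builds this array in float64.
    filters = np.array([[[s, s]], [[s, -s]]])
    # Cast the filters to the input dtype so both conv operands match.
    filters = jnp.asarray(filters, dtype=data.dtype)
    # Add a channel axis: (batch, 1, length).
    lhs = data[:, None, :]
    # Stride 2 with "VALID" padding yields one coefficient per pair of samples.
    out = lax.conv_general_dilated(lhs, filters, window_strides=(2,), padding="VALID")
    # out has shape (batch, 2, length // 2): channel 0 is approx, channel 1 is detail.
    return out[:, 0, :], out[:, 1, :]


x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
approx, detail = haar_analysis_1d(x)
print(approx.dtype)  # float32
```

Under these assumptions, dropping the `jnp.asarray(..., dtype=data.dtype)` cast should bring back the same kind of dtype-mismatch TypeError, because the filters would then stay float64 while the input is float32.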

v0lta commented 1 year ago

Dear @thoschm, the v0.0.8 branch now addresses this. After running:

pip install git+ssh://git@github.com/v0lta/Jax-Wavelet-Toolbox@v0.0.8

the image example below should work:

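import pywt, scipy.datasets
import jaxwt as jwt
import jax.numpy as jnp
# scipy.datasets.face() (needs the optional pooch package) returns a
# (768, 1024, 3) uint8 RGB image; move the color channels to the front.
face = jnp.transpose(scipy.datasets.face(), [2, 0, 1])
face = face.astype(jnp.float32)
# Decompose the (3, 768, 1024) float32 batch, then invert the transform.
transformed = jwt.wavedec2(face, pywt.Wavelet("haar"))
reconstruction = jwt.waverec2(transformed, pywt.Wavelet("haar"))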
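To sanity-check a float32 round trip like this one without installing jaxwt, the hypothetical helpers below implement one level of an orthonormal 2D Haar transform over the last two axes using plain jax.numpy slicing. This is an illustrative sketch, not jaxwt's implementation, which handles multiple levels and arbitrary pywt wavelets. The band names and signs are an arbitrary choice and may differ from pywt's conventions.

```python
import jax.numpy as jnp
import numpy as np


def haar_dec2(x):
    """One level of an orthonormal 2D Haar transform over the last two axes.

    Hypothetical helper; assumes even height and width.
    """
    # The four samples of every 2x2 block.
    a = x[..., 0::2, 0::2]
    b = x[..., 0::2, 1::2]
    c = x[..., 1::2, 0::2]
    d = x[..., 1::2, 1::2]
    # Dividing by a Python scalar is weakly typed, so x.dtype is preserved.
    ll = (a + b + c + d) / 2
    lh = (a + b - c - d) / 2
    hl = (a - b + c - d) / 2
    hh = (a - b - c + d) / 2
    return ll, (lh, hl, hh)


def haar_rec2(ll, details):
    """Inverse of haar_dec2."""
    lh, hl, hh = details
    a = (ll + lh + hl + hh) / 2
    b = (ll + lh - hl - hh) / 2
    c = (ll - lh + hl - hh) / 2
    d = (ll - lh - hl + hh) / 2
    # Interleave columns within each row pair, then interleave the rows.
    top = jnp.stack([a, b], axis=-1).reshape(*a.shape[:-1], -1)
    bottom = jnp.stack([c, d], axis=-1).reshape(*c.shape[:-1], -1)
    return jnp.stack([top, bottom], axis=-2).reshape(*a.shape[:-2], -1, top.shape[-1])


# A small (channels, height, width) float32 batch, like the face example.
rng = np.random.default_rng(0)
image = jnp.asarray(rng.random((3, 8, 8)), dtype=jnp.float32)
ll, details = haar_dec2(image)
rec = haar_rec2(ll, details)
print(ll.dtype, rec.dtype)  # float32 float32
print(bool(jnp.allclose(rec, image, atol=1e-5)))  # True
```

Because the four 2x2 combinations form an orthonormal basis, the transform also preserves the total energy of the input, which gives a second cheap correctness check.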

All the best, Moritz

thoschm commented 1 year ago

Dear @v0lta,

this is amazing! Thanks for incorporating this into 0.0.8 👍

You're the best! Have a great day, Thomas