astro-informatics / s2wav

Differentiable and accelerated wavelet transform on the sphere with JAX
https://astro-informatics.github.io/s2wav/
MIT License

wavelet transform demo error #73

Closed: flying-gwx closed this issue 1 year ago

flying-gwx commented 1 year ago

Hello,

I want to try the s2wav library, following https://github.com/astro-informatics/s2let/issues/50#issuecomment-1553105667. When I run the demo code, an error occurs in s2wav/transforms/jax_wavelets.py, line 260. It seems that the filters argument is required to perform the wavelet transform (but it should be optional). Did I do anything wrong?

Here is my test code:

import s2wav
import numpy as np
L = 128
N = 1
f = np.ones((L, 2*L-1))
f_wav, f_scal = s2wav.analysis(f, L, N)
f = s2wav.synthesis(f_wav, f_scal, L, N)

Here is the error:

Traceback (most recent call last):
  File "test_wav.py", line 7, in <module>
    f_wav, f_scal = s2wav.analysis(f, L, N)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _, jaxpr = infer_params_fn(
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/api.py", line 300, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 499, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 961, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/pjit.py", line 914, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/s2wav/transforms/jax_wavelets.py", line 260, in analysis
    jnp.conj(filters[0]),
jax._src.traceback_util.UnfilteredStackTrace: TypeError: 'NoneType' object is not subscriptable

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test_wav.py", line 7, in <module>
    f_wav, f_scal = s2wav.analysis(f, L, N)
  File "/home/gaowenxuan/anaconda3/envs/py38/lib/python3.8/site-packages/s2wav/transforms/jax_wavelets.py", line 260, in analysis
    jnp.conj(filters[0]),
TypeError: 'NoneType' object is not subscriptable

I have run the s2wav test suite with pytest, and all tests pass.

========================================================== test session starts ===========================================================
platform linux -- Python 3.8.16, pytest-7.3.1, pluggy-1.0.0
rootdir: /home/gaowenxuan/Code/s2wav
configfile: pytest.ini
collected 288 items                                                                                                                      

tests/test_filters.py ........................................                                                                     [ 13%]
tests/test_gradients.py ........                                                                                                   [ 16%]
tests/test_wavelets.py ..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss                            [ 44%]
tests/test_wavelets_base.py ..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss.. [ 79%]
ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss..ss                                                                         [100%]

============================================== 168 passed, 120 skipped in 447.66s (0:07:27) ==============================================

Thanks.

jasonmcewen commented 1 year ago

@CosmoMatt @alicjapolanska @JessWhitney Could someone please get back to @flying-gwx about this?

CosmoMatt commented 1 year ago

Hi @flying-gwx, I think this is a little oversight on our part (this code is still under development). Basically, the docstrings list filters as an optional argument. That was our intention: if the user doesn't provide custom filters, a simple check catches that none have been provided and we construct a default set.

However, we have not yet implemented JAX functions to generate the filters, so when no filters are provided the code ends up trying to convolve your signal with None (!). For the time being, you'll want to run something like this:

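import numpy as np
import s2wav
import s2wav.filter_factory.filters as filters

L = 128
N = 1
f = np.ones((L, 2*L-1))

# Build the filters explicitly, since no default is generated yet.
wavelet_filters = filters.filters_directional_vectorised(L, N)

# Pass the same filters to both analysis and synthesis.
f_wav, f_scal = s2wav.analysis(f, L, N, filters=wavelet_filters)
f = s2wav.synthesis(f_wav, f_scal, L, N, filters=wavelet_filters)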

which should work, I think. When I get a chance I'll add a warning for this, and we'll try to add a check that automatically generates the filters where needed.
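The planned behaviour (treat filters as optional, warn, and build a default when it is missing) is the usual None-default pattern in Python. Below is a minimal sketch of that pattern, assuming a hypothetical build_default_filters helper standing in for a real generator such as filters_directional_vectorised; analysis_sketch is likewise hypothetical and its "transform" is a placeholder, not s2wav's actual implementation.

```python
import warnings
from typing import List, Optional, Sequence, Tuple

Filters = Sequence[Sequence[float]]


def build_default_filters(L: int, N: int) -> List[List[float]]:
    # Hypothetical stand-in for a real generator such as
    # filters.filters_directional_vectorised(L, N); N is unused here and
    # the values are placeholders (one wavelet row, one scaling row).
    return [[1.0] * L, [0.5] * L]


def analysis_sketch(
    f: Sequence[float],
    L: int,
    N: int,
    filters: Optional[Filters] = None,
) -> Tuple[List[float], List[float]]:
    # Guard the optional argument: without it, filters[0] on None raises
    # TypeError: 'NoneType' object is not subscriptable.
    if filters is None:
        warnings.warn(
            "No filters provided; generating default filters.",
            UserWarning,
            stacklevel=2,
        )
        filters = build_default_filters(L, N)
    wavelet_filter, scaling_filter = filters[0], filters[1]
    # Placeholder "transform": elementwise products, only to show that
    # both filters are actually consumed downstream.
    f_wav = [x * w for x, w in zip(f, wavelet_filter)]
    f_scal = [x * s for x, s in zip(f, scaling_filter)]
    return f_wav, f_scal
```

Because the check is ordinary Python control flow on None rather than on an array value, the same guard should also work inside a jax.jit-compiled function, where a None argument is passed through as an empty pytree.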

Sorry for any confusion here!

flying-gwx commented 1 year ago

Thanks for your reply. I tested the script and it worked. By the way, here is the code I used (with a few small differences):

import s2wav
import numpy as np
import s2wav.filter_factory.filters as filters
L = 128
N = 1
f = np.ones((L, 2*L-1))
# Compute wavelet coefficients
filter = filters.filters_directional_vectorised(L, N)
f_wav, f_scal = s2wav.analysis(f, L, N, filters=filter, reality=True)

# Map back to signal on the sphere 
f_build = s2wav.synthesis(f_wav, f_scal, L, N, filters=filter, reality=True)
print(np.abs(f_build - f).mean())  # mean absolute reconstruction error

CosmoMatt commented 1 year ago

Thanks @flying-gwx, I'll open a separate issue for automatic filter computation now. If you run into any further issues, definitely let the team know!
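flying-gwx's script checks the analysis/synthesis round trip by printing the mean absolute difference between f and f_build. A minimal sketch of turning that into a reusable check is below; max_round_trip_error is a hypothetical helper, and the s2wav usage in the comments simply reuses the calls from the script above (it is not executed here). The runnable part is demonstrated with NumPy's FFT pair as a stand-in for any forward/inverse transform.

```python
from typing import Callable

import numpy as np


def max_round_trip_error(
    forward: Callable[[np.ndarray], object],
    inverse: Callable[[object], np.ndarray],
    x: np.ndarray,
) -> float:
    """Return the largest absolute difference between x and inverse(forward(x))."""
    x_rec = inverse(forward(x))
    return float(np.max(np.abs(x_rec - x)))


# Hypothetical usage with the s2wav calls from the script above (not executed here):
# err = max_round_trip_error(
#     lambda x: s2wav.analysis(x, L, N, filters=filter, reality=True),
#     lambda c: s2wav.synthesis(c[0], c[1], L, N, filters=filter, reality=True),
#     f,
# )

# Stand-in demonstration: the FFT followed by the inverse FFT should recover x.
x = np.linspace(0.0, 1.0, 64)
print(max_round_trip_error(np.fft.fft, np.fft.ifft, x) < 1e-12)
```

Taking the maximum rather than the mean is a deliberate choice in this sketch: a mean can hide a handful of badly reconstructed samples, while the maximum reports the worst one.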