jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

convolve2d not supporting boundary="wrap" #7276

Open romanodev opened 3 years ago

romanodev commented 3 years ago

Hi JAX developers,

I am trying to filter periodic images with jax.scipy.signal.convolve2d, but it seems that boundary='wrap' is not supported, even though it is mentioned in the documentation.

Here is an MVE:

from jax.scipy.signal import convolve2d
import jax.numpy as jnp

a1 = jnp.ones((3,3))

a = convolve2d(a1, a1,mode='same',boundary='wrap')

Here is the error:

raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
NotImplementedError: convolve2d() only supports boundary='fill', fillvalue=0

Thanks!

jakevdp commented 3 years ago

Hi - thanks for the report! This is a known issue, though it is not tracked anywhere (aside from the explicit NotImplementedError you found in the code).

Regarding the documentation: it's mentioned there because the docstring is copied verbatim from scipy.signal.convolve2d, as is customary for JAX's wrapped NumPy/SciPy functions. In many cases, if you dig hard enough, you'll find unimplemented keywords like this one.

I'm going to change the label from bug to enhancement, because this is a known gap that we hope a team member or community member will implement in the future.

Thanks!

jakevdp commented 3 years ago

One more thing: this has not yet been implemented because convolutions are computed via XLA's ConvWithGeneralPadding, and I'm not certain whether it can compute the equivalent of SciPy's wrapped convolutions.
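For reference, with mode='same' and boundary='wrap' the result is a circular (periodic) convolution of the image with the kernel, shifted by the kernel's half-width. A minimal sketch that computes it with FFTs instead of a convolution primitive, and compares it against SciPy, is below. The helper name convolve2d_wrap_fft is hypothetical, the sketch assumes odd kernel dimensions no larger than the image, and it is not how JAX implements convolve2d.

```python
import jax.numpy as jnp
import numpy as np
from scipy.signal import convolve2d as convolve_scipy


def convolve2d_wrap_fft(image, kernel):
    """Hypothetical helper: mode='same', boundary='wrap' via FFT.

    Sketch only; assumes odd kernel dimensions no larger than the image.
    """
    image = jnp.asarray(image)
    kernel = jnp.asarray(kernel)
    # Zero-pad the kernel to the image shape inside the FFT; multiplying the
    # two spectra gives a circular (periodic) convolution.
    spectrum = jnp.fft.fft2(image) * jnp.fft.fft2(kernel, s=image.shape)
    circular = jnp.real(jnp.fft.ifft2(spectrum))
    # The circular result is anchored at the kernel's first element; roll by
    # the kernel half-widths to match the centering of mode='same'.
    c0, c1 = kernel.shape[0] // 2, kernel.shape[1] // 2
    return jnp.roll(circular, shift=(-c0, -c1), axis=(0, 1))


rng = np.random.default_rng(0)
image = rng.normal(size=(10, 12)).astype(np.float32)
kernel = rng.normal(size=(5, 3)).astype(np.float32)
expected = convolve_scipy(image, kernel, mode='same', boundary='wrap')
print(np.allclose(convolve2d_wrap_fft(image, kernel), expected, atol=1e-4))
```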

hawkinsp commented 3 years ago

The best way to implement this is probably to explicitly form the padding values as part of the input to the convolution. Note that jnp.pad supports wrapping padding modes, so perhaps the implementation is as simple as composing the two?
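A minimal sketch of that composition, assuming odd kernel dimensions no larger than the image (the helper name convolve2d_same_wrap is hypothetical): pad by the kernel's half-width with mode='wrap', then take a 'valid' convolution, whose output has the same shape as the original image.

```python
import jax.numpy as jnp
import numpy as np
from jax.scipy.signal import convolve2d
from scipy.signal import convolve2d as convolve_scipy


def convolve2d_same_wrap(image, kernel):
    """Hypothetical helper: mode='same', boundary='wrap' from jnp.pad + 'valid'.

    Sketch only; assumes odd kernel dimensions no larger than the image.
    """
    image = jnp.asarray(image)
    kernel = jnp.asarray(kernel)
    # Kernel half-widths along each axis.
    c0, c1 = kernel.shape[0] // 2, kernel.shape[1] // 2
    # Materialize the periodic boundary: copy c0 rows / c1 columns from the
    # opposite edges onto each side of the image.
    padded = jnp.pad(image, ((c0, c0), (c1, c1)), mode='wrap')
    # A 'valid' convolution of the padded array has exactly image.shape.
    return convolve2d(padded, kernel, mode='valid')


rng = np.random.default_rng(0)
image = rng.normal(size=(10, 12)).astype(np.float32)
kernel = rng.normal(size=(5, 3)).astype(np.float32)
expected = convolve_scipy(image, kernel, mode='same', boundary='wrap')
print(np.allclose(convolve2d_same_wrap(image, kernel), expected, atol=1e-4))
```

Padding only by the kernel half-width keeps the intermediate array small, and nothing in the sketch assumes a square image.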

romanodev commented 3 years ago

@hawkinsp, great call.

Here is a working example:

from jax.scipy.signal import convolve2d as convolve_jax
from scipy.signal     import convolve2d as convolve_scipy
from jax import random
import jax.numpy as jnp

key = random.PRNGKey(0)

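def convolve_wrap(a1, a2):
    # Assumes a square image: pad it periodically by its own size N on every
    # side, convolve, then crop back to the central N x N block.
    N = a1.shape[0]
    a1 = jnp.pad(a1, N, mode='wrap')

    return convolve_jax(a1, a2, mode='same')[N:2*N, N:2*N]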

#Filter
N = 5
a2 = jnp.ones((N,N))/N/N

#Image
N = 10
a1 = random.normal(key, (N,N))

#Scipy
a_scipy = convolve_scipy(a1, a2,mode='same',boundary='wrap')

#Jax
a_jax = convolve_wrap(a1,a2)

print(jnp.allclose(a_scipy, a_jax, atol=1e-6))

hawkinsp commented 3 years ago

@romanodev Awesome! Would you be interested in contributing a PR that adds support for boundary="wrap" to convolve2d? The implementation is here: https://github.com/google/jax/blob/9450b8f8f96cb4d57b3f6337dcc18d3d104ecf6b/jax/_src/scipy/signal.py#L73

romanodev commented 3 years ago

@hawkinsp, I can definitely take a look at that. While the above example only covers mode='same', I assume it needs to be implemented for all modes, as required by the tests in https://github.com/google/jax/blob/main/tests/scipy_signal_test.py#L50

jakevdp commented 3 years ago

Hi @romanodev - don't worry about implementing everything at once if it's blocking you. Even just implementing mode='full' would be useful, and you could adjust the tests to skip unimplemented combinations.
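One way to skip unimplemented combinations with plain unittest is sketched below: generate one test method per (mode, boundary) pair and mark the unsupported ones with unittest.skipUnless. JAX's own test suite uses a different parameterized harness, so treat this only as an illustration; the IMPLEMENTED set, the class name, and the helper _make_test are hypothetical.

```python
import itertools
import unittest

import numpy as np
from jax.scipy.signal import convolve2d
from scipy.signal import convolve2d as convolve_scipy

MODES = ('full', 'same', 'valid')
BOUNDARIES = ('fill', 'wrap', 'symm')
# Hypothetical: (mode, boundary) pairs treated as implemented; extend this
# set as new boundary modes land.
IMPLEMENTED = {(mode, 'fill') for mode in MODES}


class Convolve2dTest(unittest.TestCase):
    pass


def _make_test(mode, boundary):
    # Unsupported combinations are reported as skipped instead of failing.
    @unittest.skipUnless((mode, boundary) in IMPLEMENTED,
                         f"mode={mode!r}, boundary={boundary!r} not implemented")
    def test(self):
        rng = np.random.default_rng(0)
        image = rng.normal(size=(6, 7)).astype(np.float32)
        kernel = rng.normal(size=(3, 3)).astype(np.float32)
        expected = convolve_scipy(image, kernel, mode=mode, boundary=boundary)
        actual = convolve2d(image, kernel, mode=mode, boundary=boundary)
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
    return test


# Attach one test method per (mode, boundary) combination.
for _mode, _boundary in itertools.product(MODES, BOUNDARIES):
    setattr(Convolve2dTest, f"test_{_mode}_{_boundary}",
            _make_test(_mode, _boundary))
```

Any unittest-compatible runner discovers the generated methods; adding a pair to IMPLEMENTED turns its skip into a real comparison against SciPy.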

romanodev commented 3 years ago

@jakevdp, great! I will resume this next week.

rajasekharporeddy commented 4 months ago

Hi @romanodev

A fix for this issue is included in pull request #21241.

romanodev commented 4 months ago

@rajasekharporeddy, awesome!