pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Add complex constraint and real Fourier transform. #1762

Closed: tillahoffmann closed this 6 months ago

tillahoffmann commented 6 months ago

This PR adds a complex constraint and a real fast Fourier transform (RealFastFourierTransform).

The RealFastFourierTransform interface differs slightly from jax.numpy.fft.rfftn to respect that batch dimensions precede event dimensions. That is, rather than specifying the axes along which to transform, one specifies the number of trailing dimensions to transform, akin to reinterpreted_batch_ndims in IndependentTransform. The interface does not support the norm parameter because jit-ing isn't happy with string arguments (cf. https://github.com/google/jax/issues/3045).
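To illustrate the "number of dimensions" convention described above, here is a minimal sketch, not the code from this PR. It uses NumPy so it runs without JAX (jax.numpy.fft provides rfftn and irfftn under the same names). The helper names rfft_last_dims and irfft_last_dims and the argument transform_ndims are hypothetical and chosen only for this example.

```python
import numpy as np


def rfft_last_dims(x, transform_ndims=1):
    """Real FFT over the trailing `transform_ndims` axes (hypothetical helper)."""
    # Leading axes are treated as batch dimensions and left untouched.
    axes = tuple(range(-transform_ndims, 0))
    return np.fft.rfftn(x, axes=axes)


def irfft_last_dims(y, transform_shape):
    """Inverse of `rfft_last_dims`; the real-space shape must be given explicitly."""
    axes = tuple(range(-len(transform_shape), 0))
    return np.fft.irfftn(y, s=transform_shape, axes=axes)


# Batch shape (3,), event shape (4, 8): transform the last two dimensions.
x = np.random.default_rng(0).normal(size=(3, 4, 8))
y = rfft_last_dims(x, transform_ndims=2)
x_rec = irfft_last_dims(y, x.shape[-2:])
print(y.shape, x_rec.shape)
```

Only the last transformed axis is halved (to n // 2 + 1 complex entries), which is why the inverse needs the original event shape.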

The motivation for this PR is to implement fast Gaussian processes for stationary kernels (cf. section 4 in https://arxiv.org/pdf/2301.08836v4.pdf).

Thank you for taking the time to review my recent PRs. Let me know if they create too much noise in your inbox.

fehiepsi commented 6 months ago

Awesome work!! thanks, Till!

tillahoffmann commented 6 months ago

Thank you for the fast review and merging!

I just realized we probably want to add the option to unpack the complex coefficients, which have shape (..., n // 2 + 1) for a signal of shape (..., n), into a real tensor with shape (..., n), because the rest of the library is designed for real-valued tensors. This would be equivalent to what we're doing in Stan here.
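A rough sketch of what such an unpacking could look like, assuming one particular ordering (all real parts first, then the non-redundant imaginary parts). This ordering is an assumption made for illustration and is not necessarily the one used in the Stan code or in any eventual numpyro transform. The helper names unpack_rfft and pack_rfft are hypothetical. The size n is passed explicitly because n // 2 + 1 is the same for n = 2m - 2 and n = 2m - 1.

```python
import numpy as np


def unpack_rfft(coeffs, n):
    """Hypothetical helper: complex (..., n // 2 + 1) -> real (..., n)."""
    # For a real signal the imaginary part at index 0 (and at n // 2 when n is
    # even) is zero, so those entries are dropped.
    imag = coeffs.imag[..., 1:(n + 1) // 2]
    return np.concatenate([coeffs.real, imag], axis=-1)


def pack_rfft(z):
    """Hypothetical inverse: real (..., n) -> complex (..., n // 2 + 1)."""
    n = z.shape[-1]
    n_complex = n // 2 + 1
    imag = np.zeros(z.shape[:-1] + (n_complex,))
    imag[..., 1:(n + 1) // 2] = z[..., n_complex:]
    return z[..., :n_complex] + 1j * imag


def roundtrip(n, seed=0):
    """Signal -> rfft -> unpack -> pack -> irfft, for a batch of 3 signals."""
    x = np.random.default_rng(seed).normal(size=(3, n))
    z = unpack_rfft(np.fft.rfft(x), n)
    x_rec = np.fft.irfft(pack_rfft(z), n=n)
    return x, z, x_rec


x_even, z_even, rec_even = roundtrip(8)
x_odd, z_odd, rec_odd = roundtrip(7)
```

The round trip covers both even and odd n, since the two cases drop a different number of imaginary entries.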

One option is to add an argument unpack (or some more informative name) to the transform to indicate if real or complex values should be returned. Another is to add a separate unpacking transform. What do you think?

fehiepsi commented 6 months ago

Interesting, I feel that it's better to have a separate transform for the unpack version.

tillahoffmann commented 6 months ago

That could look like

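from numpyro.distributions.transforms import Transform


class RealFastFourierUnpackTransform(Transform):
    """
    :param size: Size of the last dimension of the transform, required because the size
        cannot be inferred from the shape of the coefficients.
    """
    def __init__(self, size):
        # Stored up front: n // 2 + 1 coefficients do not determine n uniquely.
        self.size = size

    ...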

If we added it with an unpack parameter, we could infer the size from the argument x. I agree that having a separate transform seems like a good idea, but it would also be nice if one didn't have to specify the shape ahead of time, e.g., if one wants to use the same transform object on inputs of different sizes. Not sure which one is the better compromise.

Edit: Or did you mean having a transform (which includes unpacking) that inherits from the transform implemented in this PR?