Closed tillahoffmann closed 6 months ago
Awesome work!! thanks, Till!
Thank you for the fast review and merging!
I just realized we probably want to add the option to unpack the complex coefficients with shape `(..., n // 2 + 1)` for a signal of shape `(..., n)` to a real tensor with shape `(..., n)`, because the rest of the library is designed for real-valued tensors. This would be equivalent to what we're doing in Stan here.

One option is to add an argument `unpack` (or some more informative name) to the transform to indicate whether real or complex values should be returned. Another is to add a separate unpacking transform. What do you think?
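For concreteness, such an unpacking might look like the following sketch (illustrated with NumPy; `jax.numpy` behaves the same, and `unpack_rfft` is a hypothetical helper, not part of this PR):

```python
import numpy as np

def unpack_rfft(z, n):
    # Hypothetical helper: map complex rfft coefficients with shape
    # (..., n // 2 + 1) to a real tensor with shape (..., n). For real
    # input of even length n, the coefficients at index 0 (DC) and
    # n // 2 (Nyquist) have zero imaginary part, so the real parts plus
    # the interior imaginary parts carry exactly n real degrees of freedom.
    assert n % 2 == 0, "sketch assumes even n for brevity"
    return np.concatenate([z.real, z[..., 1:-1].imag], axis=-1)

x = np.random.normal(size=(7, 8))
z = np.fft.rfft(x)
assert unpack_rfft(z, 8).shape == (7, 8)
```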
Interesting, I feel that it's better to have a separate transform for the unpacked version. That could look like:
```python
class RealFastFourierUnpackTransform(Transform):
    """
    :param size: Size of the last dimension of the transform, required because the
        size cannot be inferred from the shape of the coefficients.
    """

    def __init__(self, size):
        self.size = size

    ...
```
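The docstring's claim that the size cannot be inferred from the coefficients holds because an even-length and an odd-length signal can produce the same number of `rfft` coefficients; a quick check with NumPy:

```python
import numpy as np

# Both n = 8 and n = 9 yield m = n // 2 + 1 = 5 complex coefficients,
# so the signal length is ambiguous given only the coefficient shape.
even = np.fft.rfft(np.zeros(8))
odd = np.fft.rfft(np.zeros(9))
assert even.shape == odd.shape == (5,)
```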
If we added it with an `unpack` parameter, we could infer the size from the argument `x`. I agree that having a separate transform seems like a good idea, but it would also be nice if one didn't have to specify the shape ahead of time, e.g., if one wants to use the same transform object on inputs of different sizes. Not sure which is the better compromise.

Edit: Or did you mean having a transform (which includes unpacking) that inherits from the transform implemented in this PR?
This PR

- adds a `_Complex` constraint,
- changes the `_Real` constraint to inherit from `_Complex` and check that values are indeed real,
- adds a `RealFastFourierTransform`.

The `RealFastFourierTransform` interface differs slightly from `jax.numpy.fft.rfftn` to respect that batch dimensions precede event dimensions. I.e., rather than specifying the axes along which to transform, one specifies the number of dimensions to transform, akin to `reinterpreted_batch_ndims` in the `IndependentTransform`. The interface does not support specifying the `norm` parameter because `jit`-ing isn't happy with string arguments (cf. https://github.com/google/jax/issues/3045).

The motivation for this PR is to implement fast Gaussian processes for stationary kernels (cf. section 4 in https://arxiv.org/pdf/2301.08836v4.pdf).
Thank you for taking the time to review my recent PRs. Let me know if they create too much noise in your inbox.