brianwa84 opened this issue 4 years ago.
We would probably need rfft and irfft first. I made a little bit of progress in https://github.com/google/jax/pull/1657, but got stuck on the transpose rules. It looks like those are indeed somewhat non-trivial:
https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/signal/fft_ops.py#L218
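If you assume an even-sized dimension, here are impls for i/rfft:

import tensorflow as tf

def rfft(x):
  x = tf.convert_to_tensor(x, dtype=tf.float32)
  # Full complex FFT, keeping the n//2 + 1 non-negative frequencies.
  return tf.signal.fft(tf.cast(x, tf.complex64))[..., :tf.shape(x)[-1] // 2 + 1]

def irfft(x):
  # Append conj(X[n/2-1]), ..., conj(X[1]) to rebuild the full spectrum.
  x = tf.concat([x, tf.math.conj(x[..., ::-1][..., 1:-1])], axis=-1)
  return tf.math.real(tf.signal.ifft(x))

Brian Patton | Software Engineer | bjp@google.com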
(Sent by email in reply to Stephan Hoyer's comment of Tue, Dec 10, 2019, 6:37 PM.)
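A NumPy sketch of the same wrapper idea, useful for checking against `np.fft.rfft`/`np.fft.irfft`. It assumes the caller passes the signal length `n` (as `np.fft.irfft` does), which also makes odd lengths work; the function names are hypothetical. The forward transform is the complex FFT truncated to the non-negative frequencies, and `np.real` is only needed on the inverse side.

```python
import numpy as np


def rfft_via_fft(x):
    """rfft as a wrapper: complex FFT, keep the n // 2 + 1 non-negative frequencies."""
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[-1]
    return np.fft.fft(x.astype(np.complex128), axis=-1)[..., : n // 2 + 1]


def irfft_via_ifft(X, n=None):
    """irfft as a wrapper: rebuild the Hermitian spectrum, then take real(ifft)."""
    X = np.asarray(X, dtype=np.complex128)
    m = X.shape[-1]
    if n is None:
        n = 2 * (m - 1)  # same default as np.fft.irfft: assume an even length
    if m != n // 2 + 1:
        raise ValueError("expected n // 2 + 1 frequency bins")
    # Negative frequencies mirror the positive ones: X[n - k] = conj(X[k]).
    # X[1 : n - m + 1] skips DC, and for even n also the Nyquist bin.
    negative = np.conj(X[..., 1 : n - m + 1][..., ::-1])
    full = np.concatenate([X, negative], axis=-1)
    return np.real(np.fft.ifft(full, axis=-1))


x = np.array([1.0, -2.0, 0.5, 3.0, 4.0])  # odd length
assert np.allclose(irfft_via_ifft(rfft_via_fft(x), n=5), x)
```

This gives the API convenience but none of the speed advantage of a dedicated real FFT.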
I was thinking there is no point in adding rfft and irfft without the efficiency gain, but it's a fair point that some users might want them just for API convenience.
Hmm, do we have any further thoughts on this?
Should we target an efficient rfft implementation or just implement a np.real() call around the complex fft routines?
Would we be adding significant inertia against working on the former by implementing the latter?
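For context on the efficiency question: one classic way a real FFT saves work is to pack an even-length real signal into a complex signal of half the length, run one half-size complex FFT, and untangle the result. This is a NumPy sketch of that textbook trick, assuming an even-length last axis; it is not a claim about how XLA implements RFFT internally, and the function name is hypothetical.

```python
import numpy as np


def rfft_via_half_length_fft(x):
    """Real FFT of an even-length signal using one complex FFT of length n // 2."""
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[-1]
    if n % 2:
        raise ValueError("this sketch assumes an even-length last axis")
    m = n // 2
    # Pack even-indexed samples into the real part, odd-indexed into the imaginary part.
    z = x[..., 0::2] + 1j * x[..., 1::2]
    Z = np.fft.fft(z, axis=-1)
    k = np.arange(m + 1)
    Zk = Z[..., k % m]              # Z[k], with Z[m] wrapping around to Z[0]
    Zc = np.conj(Z[..., (-k) % m])  # conj(Z[-k mod m])
    even = 0.5 * (Zk + Zc)          # length-m FFT of the even-indexed samples
    odd = -0.5j * (Zk - Zc)         # length-m FFT of the odd-indexed samples
    twiddle = np.exp(-2j * np.pi * k / n)
    return even + twiddle * odd


x = np.arange(8.0)
assert np.allclose(rfft_via_half_length_fft(x), np.fft.rfft(x))
```

The last step is the usual decimation-in-time butterfly, so the cost is roughly one half-size complex FFT plus O(n) extra work.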
I think it would be totally reasonable to start with the easy wrapper approach based on real(fft(x)).
Conceivably we could even still use the XLA rfft functions for cases where we don't need a transpose (we could write only the transpose rule in terms of the complex FFTs). Arguably, this could even perform better than the gradient of rfft in TensorFlow, which seems to fall back to matrix multiplication.
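To make the idea of a transpose written with complex FFTs concrete: treating complex numbers as pairs of reals, the transpose of the length-`n` rfft is a single complex inverse FFT of the zero-padded cotangent, scaled by `n`. The NumPy sketch below (hypothetical name `rfft_transpose`, 1D along the last axis only) checks this with a dot-product test. It illustrates the identity; it is not JAX's or TensorFlow's actual rule.

```python
import numpy as np


def rfft_transpose(g, n):
    """Transpose of the real-linear map x -> np.fft.rfft(x) for length-n x.

    Uses the real inner product <a, b> = sum(Re(a * conj(b))), i.e. complex
    numbers are treated as pairs of reals. Only a complex FFT is needed.
    """
    g = np.asarray(g, dtype=np.complex128)
    padded = np.zeros(g.shape[:-1] + (n,), dtype=np.complex128)
    padded[..., : g.shape[-1]] = g  # zero-fill the negative frequencies
    return n * np.real(np.fft.ifft(padded, axis=-1))


# Dot-product check: <rfft(x), g> == <x, rfft_transpose(g, n)> for random data.
rng = np.random.default_rng(0)
n = 8
x = rng.standard_normal(n)
g = rng.standard_normal(n // 2 + 1) + 1j * rng.standard_normal(n // 2 + 1)
lhs = np.sum(np.real(np.fft.rfft(x) * np.conj(g)))
rhs = np.dot(x, rfft_transpose(g, n))
assert np.isclose(lhs, rhs)
```

For random `g` this differs from `n * np.fft.irfft(g, n)`: irfft assumes a Hermitian-symmetric spectrum, so it effectively counts each interior bin twice.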
What is the transpose you're talking about? Why not always do rfft and irfft in XLA directly?
(Sent by email in reply to Stephan Hoyer's comment of Fri, Jan 10, 2020, 6:50 PM.)
XLA's FFT op takes an fft_type, which can be any of FFT, IFFT, RFFT, or IRFFT.
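For reference, the four `fft_type` values correspond to the NumPy transforms below. `xla_style_fft` is a hypothetical stand-in written only to show the semantics, including that `fft_lengths` are real-domain sizes for RFFT and IRFFT (which is what lets IRFFT pick an odd or even output length). In JAX the actual op is exposed as `jax.lax.fft(x, fft_type, fft_lengths)`.

```python
import numpy as np


def xla_style_fft(x, fft_type, fft_lengths):
    """Hypothetical NumPy stand-in for XLA's single FFT op.

    fft_lengths gives the sizes of the trailing axes being transformed,
    measured in the real domain for RFFT (input) and IRFFT (output).
    """
    axes = tuple(range(-len(fft_lengths), 0))
    s = tuple(fft_lengths)
    transforms = {
        "FFT": np.fft.fftn,      # complex -> complex
        "IFFT": np.fft.ifftn,    # complex -> complex
        "RFFT": np.fft.rfftn,    # real -> complex, last axis becomes n // 2 + 1
        "IRFFT": np.fft.irfftn,  # complex -> real, last axis becomes n
    }
    if fft_type not in transforms:
        raise ValueError(f"unknown fft_type: {fft_type!r}")
    return transforms[fft_type](x, s=s, axes=axes)


x = np.random.default_rng(0).standard_normal((4, 5))
X = xla_style_fft(x, "RFFT", (4, 5))  # shape (4, 3)
assert np.allclose(xla_style_fft(X, "IRFFT", (4, 5)), x)
```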
(Follow-up to Brian Patton's email of Fri, Jan 10, 2020, 10:16 PM.)
To implement reverse-mode gradients in JAX, we need a JVP rule and a transpose rule.
The JVP rule is easy (because FFTs are already linear), but the transpose rule for rfft (equivalent to its gradient definition in TensorFlow) can't simply be written as one of the other FFTs; in particular, it is not just irfft.
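Because every FFT variant is linear, its JVP is the same transform applied to the tangent, so the real work is in the transpose. A dot-product test, checking that <f(x), g> equals <x, f_T(g)>, is a convenient way to validate a candidate transpose. As an illustration in NumPy (hypothetical names, complex numbers treated as pairs of reals, last axis only), the transpose of irfft comes out as an rfft with per-bin weights, so that direction can be written with another FFT.

```python
import numpy as np


def irfft_transpose(y, n):
    """Transpose of the real-linear map X -> np.fft.irfft(X, n), last axis only.

    Complex spectra are treated as pairs of reals, so the pairing on the
    spectrum side is sum(Re(a * conj(b))).
    """
    y = np.asarray(y, dtype=np.float64)
    m = n // 2 + 1
    # Interior bins stand in for two bins (k and n - k) of the full Hermitian
    # spectrum, so they get weight 2; DC, and Nyquist when n is even, get 1.
    weights = np.full(m, 2.0)
    weights[0] = 1.0
    if n % 2 == 0:
        weights[-1] = 1.0
    return np.fft.rfft(y, axis=-1) * weights / n


def dot_product_test(f, f_t, x, g):
    """Check <f(x), g> == <x, f_t(g)> under the real inner product."""
    lhs = np.sum(np.real(f(x) * np.conj(g)))
    rhs = np.sum(np.real(x * np.conj(f_t(g))))
    return bool(np.isclose(lhs, rhs))


rng = np.random.default_rng(0)
for n in (7, 8):
    m = n // 2 + 1
    X = rng.standard_normal(m) + 1j * rng.standard_normal(m)
    # Keep DC (and Nyquist for even n) real, as in any genuine rfft output.
    X[0] = X[0].real
    if n % 2 == 0:
        X[-1] = X[-1].real
    y = rng.standard_normal(n)
    assert dot_product_test(lambda v: np.fft.irfft(v, n),
                            lambda v: irfft_transpose(v, n), X, y)
```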
I see, so we would need to port over this helper: https://github.com/sourcecode369/tensorflow-1/blob/63f319f30302304506e1a535856a68106f396524/tensorflow/python/ops/signal/fft_ops.py#L233
(Sent by email in reply to Stephan Hoyer's comment of Fri, Jan 10, 2020, 11:41 PM.)
I got transpose rules for rfftn and irfftn working over in https://github.com/google/jax/pull/1657, which is now ready for review!
bump! :)
I implemented 1D and 2D DCT based on a single FFT of the same size here.
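One standard way to get the DCT-II from a single FFT of the same length is Makhoul's reordering: put the even-indexed samples first and the odd-indexed samples after them in reverse order, take the FFT, and multiply by a phase factor. The NumPy sketch below uses the unnormalized convention of `scipy.fft.dct(x, type=2)` and checks against the O(N^2) definition; the function names are hypothetical, and the implementation mentioned above may differ in details such as normalization.

```python
import numpy as np


def dct2_via_fft(x):
    """Unnormalized DCT-II along the last axis using one length-N complex FFT.

    Convention: y[k] = 2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N)).
    """
    x = np.asarray(x, dtype=np.float64)
    N = x.shape[-1]
    # Reorder: even-indexed samples ascending, then odd-indexed samples descending.
    v = np.concatenate([x[..., 0::2], x[..., 1::2][..., ::-1]], axis=-1)
    V = np.fft.fft(v, axis=-1)
    k = np.arange(N)
    return 2.0 * np.real(np.exp(-1j * np.pi * k / (2 * N)) * V)


def dct2_direct(x):
    """O(N^2) reference computed straight from the definition."""
    x = np.asarray(x, dtype=np.float64)
    N = x.shape[-1]
    n = np.arange(N)
    basis = np.cos(np.pi * n[:, None] * (2 * n + 1) / (2 * N))  # indexed [k, n]
    return 2.0 * x @ basis.T


x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
assert np.allclose(dct2_via_fft(x), dct2_direct(x))
```

Since the DCT-II is separable, a 2D version follows by applying it along each axis in turn.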
@juliuskunze Neat! Would you be interested in sending a PR adding them to JAX, perhaps under the jax.scipy namespace?
@hawkinsp I added full support for DCT type 2 to JAX in https://github.com/google/jax/pull/7617
Reference implementations in terms of FFT may be found in TF: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/signal/dct_ops.py#L53
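Another FFT-based formulation zero-pads the input to length 2N and uses one real FFT, at the cost of a transform twice as long as in the single-FFT approach. A NumPy sketch of that formulation (not a transcription of the TF code; the function name is hypothetical), with an optional orthonormal scaling matching `scipy.fft.dct(..., norm="ortho")`:

```python
import numpy as np


def dct2_via_padded_rfft(x, norm=None):
    """DCT-II along the last axis via one real FFT of length 2N.

    norm=None gives y[k] = 2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N));
    norm="ortho" rescales so that the transform matrix is orthogonal.
    """
    x = np.asarray(x, dtype=np.float64)
    N = x.shape[-1]
    k = np.arange(N)
    # rfft with n=2N zero-pads x to length 2N; keep the first N bins.
    spectrum = np.fft.rfft(x, n=2 * N, axis=-1)[..., :N]
    y = 2.0 * np.real(np.exp(-1j * np.pi * k / (2 * N)) * spectrum)
    if norm == "ortho":
        scale = np.full(N, np.sqrt(1.0 / (2 * N)))
        scale[0] = np.sqrt(1.0 / (4 * N))
        y = y * scale
    elif norm is not None:
        raise ValueError(f"unsupported norm: {norm!r}")
    return y


# With norm="ortho", the transform of the identity is an orthogonal matrix.
D = dct2_via_padded_rfft(np.eye(6), norm="ortho")
assert np.allclose(D @ D.T, np.eye(6))
```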