google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add DCT ops in fftpack #1839

Open · brianwa84 opened 4 years ago

brianwa84 commented 4 years ago

Reference implementations in terms of FFT may be found in TF: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/signal/dct_ops.py#L53
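
For illustration, here is a minimal sketch of one such FFT-based formulation in jax.numpy (Makhoul's same-size-FFT reordering, which is not necessarily the exact construction in the TF code; the function name is made up):

```python
import jax.numpy as jnp

def dct2_via_fft(x):
    # Unnormalized DCT-II along the last axis via one same-size complex FFT,
    # matching scipy.fft.dct(x, type=2, norm=None).
    n = x.shape[-1]
    # Makhoul's reordering: even-indexed samples ascending,
    # then odd-indexed samples descending.
    v = jnp.concatenate([x[..., ::2], x[..., 1::2][..., ::-1]], axis=-1)
    V = jnp.fft.fft(v)
    k = jnp.arange(n)
    return 2 * jnp.real(jnp.exp(-1j * jnp.pi * k / (2 * n)) * V)
```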

shoyer commented 4 years ago

We would probably need rfft and irfft first. I made a little bit of progress in https://github.com/google/jax/pull/1657, but got stuck on the transpose rules. It looks like those are indeed somewhat non-trivial: https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/signal/fft_ops.py#L218

brianwa84 commented 4 years ago

If you assume an even-sized dimension, here are implementations for rfft/irfft:

```python
import tensorflow as tf

def rfft(x):
  x = tf.convert_to_tensor(x, dtype=tf.float32)
  # Full complex FFT, then keep the first n // 2 + 1 (non-negative-frequency) bins.
  return tf.signal.fft(tf.cast(x, tf.complex64))[..., :tf.shape(x)[-1] // 2 + 1]

def irfft(x, n=None):
  # Rebuild the negative-frequency half by conjugate symmetry (even length assumed).
  x = tf.concat([x, tf.math.conj(x[..., ::-1][..., 1:-1])], axis=-1)
  return tf.math.real(tf.signal.ifft(x))
```
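
As a quick sanity check (assuming TF2 eager execution), these agree with NumPy's rfft/irfft for even lengths:

```python
import numpy as np

x = np.random.randn(8).astype(np.float32)
np.testing.assert_allclose(rfft(x).numpy(), np.fft.rfft(x), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(irfft(rfft(x)).numpy(), x, rtol=1e-5, atol=1e-5)
```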

shoyer commented 4 years ago

I was thinking that there is no point in adding rfft without the efficiency gain, but it's a fair point that some users might just want them for the API convenience.

joglekara commented 4 years ago

Hmm, do we have any further thoughts on this?

Should we target an efficient rfft implementation or just implement a np.real() call around the complex fft routines?

Would implementing the latter add significant inertia against ever working on the former?

shoyer commented 4 years ago

I think it would be totally reasonable to start with the easy wrapper approach based on real(fft(x)).

Conceivably we could even still use the XLA rfft functions for cases where we don't need a transpose (we could write only the transpose rule in terms of the complex FFTs). Arguably, this could even have better performance than the gradient of rfft in TensorFlow, which seems to fall back to matrix multiplication.
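
A minimal sketch of that easy wrapper approach (the helper name is illustrative):

```python
import jax.numpy as jnp

def rfft_via_fft(x):
    # Full complex FFT of the real input, keeping only the
    # n // 2 + 1 non-negative-frequency bins that np.fft.rfft returns.
    n = x.shape[-1]
    return jnp.fft.fft(x.astype(jnp.complex64))[..., :n // 2 + 1]
```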

brianwa84 commented 4 years ago

What is the transpose you're talking about? Why not always do rfft and irfft in XLA directly?

brianwa84 commented 4 years ago

XLA's fft takes an fft_type, which can be any of FFT, IFFT, RFFT, or IRFFT.
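
For reference, this is how the XLA op surfaces in today's JAX via jax.lax.fft (a sketch assuming the current API, where fft_type can be passed as a string):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)
y = jax.lax.fft(x, fft_type='rfft', fft_lengths=(8,))   # 5 complex bins
z = jax.lax.fft(y, fft_type='irfft', fft_lengths=(8,))  # back to 8 real samples
```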

shoyer commented 4 years ago

To implement backward-mode gradients in JAX, we need a JVP rule and a transpose rule.

The JVP rule is easy (because FFTs are already linear), but the transpose rule for rfft (the same as the gradient definition in TensorFlow) can't be defined in terms of the other FFTs.
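
To make the first point concrete, here is a tiny check (a sketch using complex inputs) that the JVP of a linear map like the FFT is just the map applied to the tangent:

```python
import jax
import jax.numpy as jnp

f = jnp.fft.fft
x = jnp.arange(4.0).astype(jnp.complex64)  # primal
t = jnp.ones(4, dtype=jnp.complex64)       # tangent
y, dy = jax.jvp(f, (x,), (t,))
assert jnp.allclose(dy, f(t))  # linearity: the pushforward of t is f(t)
```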

brianwa84 commented 4 years ago

I see, so we would need to port over this helper: https://github.com/sourcecode369/tensorflow-1/blob/63f319f30302304506e1a535856a68106f396524/tensorflow/python/ops/signal/fft_ops.py#L233

shoyer commented 4 years ago

I got transpose rules for rfftn and irfftn working over in https://github.com/google/jax/pull/1657, which is now ready for review!

nmheim commented 4 years ago

bump! :)

juliuskunze commented 3 years ago

I implemented 1D and 2D DCT based on a single FFT of the same size here.

hawkinsp commented 3 years ago

@juliuskunze Neat! Are you interested in sending a PR adding them to JAX, perhaps under the jax.scipy namespace?

juliuskunze commented 3 years ago

@hawkinsp I added full support for DCT type 2 to JAX in https://github.com/google/jax/pull/7617
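
For reference, the merged API mirrors scipy.fft.dct; a minimal usage sketch:

```python
import jax.numpy as jnp
from jax.scipy.fft import dct

x = jnp.arange(8.0)
y = dct(x, type=2, norm='ortho')  # type-2 DCT along the last axis
```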