hawkinsp opened this issue 2 years ago.
Tracked internally in b/263023739
Oh wait, this is about parallelizing a single FFT... SPMD sharding (as with pmap) should be a lot simpler.
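A rough sketch of why the pmap-style case is simpler. It uses plain Python instead of JAX, with "devices" simulated as Python lists. The helpers `fft` and `batch_sharded_fft` and the parameter `num_devices` are hypothetical and are not JAX or XLA APIs. When only the batch axis is sharded, every device transforms the rows it already owns, so no cross-device communication is needed:

```python
import cmath
from typing import List

def fft(x: List[complex]) -> List[complex]:
    """Recursive radix-2 Cooley-Tukey FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    if n % 2:
        raise ValueError("length must be a power of two")
    even, odd = fft(x[0::2]), fft(x[1::2])
    twiddled = [cmath.exp(-2j * cmath.pi * k / n) * odd[k] for k in range(n // 2)]
    return ([even[k] + twiddled[k] for k in range(n // 2)]
            + [even[k] - twiddled[k] for k in range(n // 2)])

def batch_sharded_fft(batch: List[List[complex]], num_devices: int) -> List[List[complex]]:
    """FFT every row of `batch`, with rows split across `num_devices` simulated devices."""
    assert len(batch) % num_devices == 0, "batch must divide evenly across devices"
    size = len(batch) // num_devices
    shards = [batch[d * size:(d + 1) * size] for d in range(num_devices)]
    # Each device transforms only the rows it owns: no cross-device communication.
    per_device = [[fft(row) for row in shard] for shard in shards]
    # Concatenating is only needed here to compare against the unsharded result.
    return [row for shard in per_device for row in shard]
```

For any power-of-two row length, `batch_sharded_fft(batch, 2)` returns the same rows as `[fft(r) for r in batch]`. The sharding only decides which device does which rows.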
Wait, there seems to be previous work on SPMD for FFT: https://github.com/openxla/xla/commit/180dbd69d2cf1885b3e3f18d71a9cc9669923f82. It seems to support only rank >= 3: https://github.com/openxla/xla/blob/7b562aa0d9bc54d01580f9f1a619e2e3f28df1f4/xla/service/spmd/fft_handler.cc#L350
And it does indeed do a sharded parallel FFT...
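For contrast, here is a rough sketch of what a sharded parallel FFT has to do once an axis being transformed is itself sharded. It uses the textbook slab (transpose-based) decomposition of a 2D FFT whose rows are split across devices. Each device first transforms along its local axis. An all-to-all then gives each device a column slab, and a final step transforms along the other axis. This is plain Python with simulated devices, and `fft`, `transpose` and `slab_fft2` are hypothetical helpers. I'm not claiming this is exactly what `fft_handler.cc` implements. The point is only that some cross-device exchange like this all-to-all is unavoidable, because every output element depends on every input element.

```python
import cmath
from typing import List

Matrix = List[List[complex]]

def fft(x: List[complex]) -> List[complex]:
    """Recursive radix-2 Cooley-Tukey FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    if n % 2:
        raise ValueError("length must be a power of two")
    even, odd = fft(x[0::2]), fft(x[1::2])
    twiddled = [cmath.exp(-2j * cmath.pi * k / n) * odd[k] for k in range(n // 2)]
    return ([even[k] + twiddled[k] for k in range(n // 2)]
            + [even[k] - twiddled[k] for k in range(n // 2)])

def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]

def slab_fft2(shards: List[Matrix]) -> List[Matrix]:
    """2D FFT of an N x M matrix whose rows are split across len(shards) simulated devices.

    Input: device d holds a contiguous block of full-length rows.
    Output: device d holds a contiguous N x (M / D) block of columns, i.e. the
    result comes back sharded along the other axis.
    """
    d = len(shards)
    m = len(shards[0][0])
    assert m % d == 0, "columns must divide evenly across devices"
    cols = m // d
    # Step 1: FFT along axis 1, which every device holds in full.
    local = [[fft(row) for row in shard] for shard in shards]
    # Step 2: all-to-all. Device `dst` collects, from every source device in
    # order, that device's rows restricted to the column slab `dst` will own.
    received = [
        [row[dst * cols:(dst + 1) * cols] for src in range(d) for row in local[src]]
        for dst in range(d)
    ]
    # Step 3: axis 0 is now fully local, so FFT each column and transpose back.
    return [transpose([fft(col) for col in transpose(block)]) for block in received]
```

Only step 2 touches other devices, and that exchange is what the batch-only case avoids. Leaving the result sharded along columns skips a second all-to-all. Whether a real implementation returns that layout or transposes back is a design choice.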
Could you clarify what is meant by SPMD in the case of FFTs, @hawkinsp @cheshire?
Does this mean running multiple FFTs along a batch dimension of a tensor with rank >= 2? Or does it mean parallelising a single FFT by sharding the dimensions being transformed?
In https://github.com/google/jax/issues/13081 we found that XLA doesn't support SPMD sharding of fast Fourier transform ops. It should!