openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Support SPMD sharding of `fft` ops. #24

Open hawkinsp opened 1 year ago

hawkinsp commented 1 year ago

In https://github.com/google/jax/issues/13081 we found that XLA doesn't support SPMD sharding of fast Fourier transform (FFT) ops. It should!
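
For concreteness, a minimal sketch of the kind of program this is about, assuming a multi-device JAX setup (the mesh and axis names are made up for illustration); the question is whether the `fft` stays partitioned along the batch dimension or the operand has to be gathered first:

```python
# Minimal sketch, assuming a multi-device JAX setup; mesh/axis names are
# illustrative only.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("batch",))
x = jnp.ones((64, 4096), dtype=jnp.complex64)
# Shard the leading (batch) dimension; the FFT runs along the unsharded axis.
x = jax.device_put(x, NamedSharding(mesh, P("batch", None)))

y = jax.jit(lambda a: jnp.fft.fft(a, axis=-1))(x)
```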

cheshire commented 1 year ago

Tracked internally in b/263023739

jon-chuang commented 1 year ago
1. Dim-1 FFT: to simplify the problem, consider the case where the number of shards and the FFT size are both powers of 2, i.e. `num_shards = 2^k`.
   - Map-reduce style: do a size-`N/num_shards` FFT on each device; subsequently, there is a shuffle per stage.
2. For a dim > 1 FFT, the problem actually becomes simpler: if `N % num_shards == 0`, one can perform the FFTs along the row dim, then do a shuffle, and perform the FFTs in parallel along the column dim (see http://dsp-book.narod.ru/FFTBB/0270_PDF_C23.pdf and the sketch below).

So we would have to restrict to certain choices of `N` and `num_shards`.
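
A single-device sketch of the row/column decomposition in item 2, assuming a 2-D transform; the transpose between the two passes is where the cross-device shuffle (presumably an all-to-all) would sit in the sharded version:

```python
# Single-device sketch of the separable row/column decomposition; in a sharded
# run, the transpose between the two passes would become the cross-device
# shuffle, while the 1-D FFTs themselves stay local to each device.
import jax.numpy as jnp

x = jnp.arange(64.0).reshape(8, 8)

rows = jnp.fft.fft(x, axis=1)          # FFTs along the row dim, one per row
shuffled = rows.T                      # the "shuffle": redistribute the data
cols = jnp.fft.fft(shuffled, axis=1)   # FFTs along the (former) column dim
result = cols.T

assert jnp.allclose(result, jnp.fft.fft2(x))
```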

Oh wait, this is parallelizing a single FFT... SPMD sharding (pmap) should be a lot simpler.

jon-chuang commented 1 year ago

Wait, there seems to be previous work on SPMD for FFT: https://github.com/openxla/xla/commit/180dbd69d2cf1885b3e3f18d71a9cc9669923f82. It seems to only support rank >= 3, though: https://github.com/openxla/xla/blob/7b562aa0d9bc54d01580f9f1a619e2e3f28df1f4/xla/service/spmd/fft_handler.cc#L350

And it does indeed do a sharded parallel FFT...

jon-chuang commented 1 year ago

Could you clarify what is meant by SPMD in the case of FFTs, @hawkinsp @cheshire?

Does this mean running multiple FFTs on tensors with rank >= 2 along a batch dimension? Or does it mean parallelizing a single FFT across the dimensions being transformed?
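
For what it's worth, a toy single-device illustration of the two readings (names are made up; this just shows which axes each interpretation transforms):

```python
# Toy illustration of the two readings above (single device, illustrative only).
import jax.numpy as jnp

x = jnp.ones((16, 1024), dtype=jnp.complex64)

# (a) Batch reading: 16 independent length-1024 FFTs; trivially parallel if
#     the leading axis is sharded, since each FFT only touches local data.
per_batch = jnp.fft.fft(x, axis=-1)

# (b) Single-FFT reading: one transform over all FFT dimensions, so sharding
#     any transformed axis requires cross-device exchange between passes.
whole_transform = jnp.fft.fftn(x)
```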