data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

RFC: Explicitly forbid any axes being repeated in the `axes` of the multi-dimensional FFT APIs? #747

Closed: leofang closed this issue 7 months ago

leofang commented 7 months ago

While working on #746, I realized there's one NumPy behavior we've overlooked so far. A few NumPy APIs explicitly state (the exact wording varies):

Repeated indices in axes means that the inverse transform over that axis is performed multiple times.

This applies to fft2, ifft2, fftn, ifftn, and irfftn. I propose adding a clause that explicitly forbids `axes` from containing repeated entries, for two reasons:

  1. Semantics: It is unclear to me what the semantics would be for, say, fftn(..., axes=(2,1,0,1,2)). Is it a 3D FFT over (2,1,0) followed by a 2D FFT over (1,2), a 2D FFT over (2,1) followed by a 3D FFT over (0,1,2), or something else? From the perspective of a low-level numerical library, it also makes little sense to offer accelerated routines for such an awkward case.
  2. Consistency: AFAIK we've never discussed what happens in such cases for any array API function that accepts an `axes` argument. My mental model has been that a valid `axes` must contain unique entries. If that expectation is correct, I would like to call it out explicitly, particularly for FFT, where a different precedent exists.
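To make the semantics question concrete, here is a small check of what NumPy itself does with the example call, compared against the two readings listed above. It assumes a NumPy version whose `np.fft.fftn` still accepts repeated axes, as its docs describe. For complex-to-complex transforms the readings are expected to agree, since 1D DFTs along different axes commute; that says nothing about how a plan-based library should interpret the call, nor about real transforms such as irfftn, whose last axis is treated specially.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 5, 6)) + 1j * rng.standard_normal((4, 5, 6))

# NumPy's own result for the repeated-axes call from the example above.
direct = np.fft.fftn(a, axes=(2, 1, 0, 1, 2))

# Reading 1: a 3D FFT over (2, 1, 0) followed by a 2D FFT over (1, 2).
reading_1 = np.fft.fftn(np.fft.fftn(a, axes=(2, 1, 0)), axes=(1, 2))

# Reading 2: a 2D FFT over (2, 1) followed by a 3D FFT over (0, 1, 2).
reading_2 = np.fft.fftn(np.fft.fftn(a, axes=(2, 1)), axes=(0, 1, 2))

# Compare each reading against NumPy's direct result.
print(np.allclose(direct, reading_1), np.allclose(direct, reading_2))
```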

If people dislike an explicit clause, we should at least add a note to state this is implementation defined.

rgommers commented 7 months ago

If people dislike an explicit clause, we should at least add a note to state this is implementation defined.

I'd go for this option: for the purposes of the standard it doesn't really matter which one we choose, and I don't want to have to think about changing NumPy here. Repeating axes seems like a very niche thing to do, but presumably it's used in the wild somewhere.

leofang commented 7 months ago

OK, I'll piggyback the note on #746.

This is an implementation detail that leaked into the NumPy spec. It works for NumPy because pocketfft and its predecessors only handle 1D FFTs, so NumPy computes an N-D FFT by looping over the axes and applying a 1D transform to each one in turn. Inside that loop it is irrelevant whether an axis entry is repeated: NumPy simply transforms that axis again and gets the job done right.
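A minimal sketch of that loop-over-axes model, assuming only NumPy's public `np.fft.fft`; the helper `fftn_by_looping` is hypothetical and is not NumPy's actual internal code. Each entry of `axes` triggers one 1D transform, so a repeated axis needs no special handling.

```python
import numpy as np

def fftn_by_looping(a, axes):
    """Hypothetical helper: an N-D FFT built only from 1D FFTs, one per entry of `axes`."""
    out = np.asarray(a, dtype=complex)
    # One 1D transform per entry; a repeated axis is simply transformed again.
    for ax in axes:
        out = np.fft.fft(out, axis=ax)
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))
repeated = (2, 1, 0, 1, 2)

# Compare against NumPy's fftn with the same repeated axes.
print(np.allclose(fftn_by_looping(a, repeated), np.fft.fftn(a, axes=repeated)))
```

Note that the repeated call is not the same as a plain 3D FFT over (0, 1, 2): transforming an axis twice applies the 1D DFT twice along it.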

But this is not the case for accelerated libraries built on cuFFT, FFTW, or MKL. During pre-processing, we need to interpret the provided axes to decide whether to create a 1D, 2D, 3D, ... plan, and it would be hard to get that right when entries repeat.
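As an illustration of that pre-processing step, here is a sketch of how a plan-based backend might validate `axes` before choosing a plan rank, assuming the backend opts to reject repeats rather than support them. `normalize_fft_axes` is a hypothetical helper, not part of cuFFT, FFTW, MKL, or the standard. Repeats are checked only after negative indices are resolved, since `-1` and `2` name the same axis of a 3D array.

```python
from typing import Sequence, Tuple

def normalize_fft_axes(ndim: int, axes: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical pre-processing: resolve negative axes, then reject repeats."""
    normalized = []
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for an array of dimension {ndim}")
        normalized.append(ax % ndim)
    # A repeated axis has no single plan rank, so refuse it instead of guessing.
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"axes {tuple(axes)} contains repeated entries")
    return tuple(normalized)

# With unique axes, the plan rank is just the number of transformed axes.
plan_axes = normalize_fft_axes(3, (2, 1, 0))
print(plan_axes, len(plan_axes))
```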