jipolanco / PencilFFTs.jl

Fast Fourier transforms of MPI-distributed Julia arrays
https://jipolanco.github.io/PencilFFTs.jl/dev/
MIT License

manypencilarrays #52

Closed Lightup1 closed 2 years ago

Lightup1 commented 2 years ago

Is it possible to implement in-place FFTs on PencilArray? The ManyPencilArray wrapper makes it difficult for me to port code that originally uses in-place FFTs on CuArray or Array. Since ManyPencilArray <: PencilArray is false and one has to access the data through functions such as first, porting is not as easy as moving Array-based code to CuArray-based code, which only requires replacing every Array type with CuArray. It would be nice if one could port Array- or CuArray-based code by simply replacing Array or CuArray with PencilArray, which does not seem possible at the moment if the code uses in-place FFTs.

jipolanco commented 2 years ago

I see. Unfortunately it's not possible to do in-place FFTs directly on PencilArrays.

Here is the problem. When working with serial FFTs on an Array or CuArray, the output size of an FFT is the same as the input size (except in the case of real-to-complex FFTs, but let's forget this for now). This is not always the case for distributed arrays (PencilArrays), since each of these arrays holds only the data associated with the local process. Since the partitioning scheme (the Pencil) is not the same for the input and the output, the local input and local output sizes may be different.
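
To make the size argument concrete, here is a minimal sketch in plain Julia (no MPI needed) that computes the local block size of one process under two different decompositions of the same global grid. The helpers `block_length` and `local_size` are hypothetical names introduced for illustration, and the even block distribution they use is an assumption; it is not necessarily the exact rule that PencilArrays applies.

```julia
# Hypothetical helper (not part of PencilArrays): number of points owned by the
# 0-based rank `r` when `N` points are split as evenly as possible among `P` ranks.
block_length(N, P, r) = div(N, P) + (r < rem(N, P) ? 1 : 0)

# Hypothetical helper: local array size for the process located at `coords` in a
# process grid of shape `procs`, where `decomp` lists the decomposed dimensions.
function local_size(dims::NTuple{N,Int}, decomp::NTuple{M,Int},
                    procs::NTuple{M,Int}, coords::NTuple{M,Int}) where {N,M}
    ntuple(N) do d
        i = findfirst(==(d), decomp)
        # Non-decomposed dimensions are held in full by every process.
        i === nothing ? dims[d] : block_length(dims[d], procs[i], coords[i])
    end
end

dims   = (5, 6, 7)   # global grid size (arbitrary choice)
procs  = (2, 3)      # 2 × 3 process grid
coords = (0, 0)      # coordinates of the process we look at

size_in  = local_size(dims, (2, 3), procs, coords)  # input: dimension 1 is local
size_out = local_size(dims, (1, 2), procs, coords)  # output: dimension 3 is local

println("local input size:  ", size_in,  " (", prod(size_in),  " elements)")
println("local output size: ", size_out, " (", prod(size_out), " elements)")
```

With these numbers the same process holds a 5×3×3 block (45 elements) on the input side and a 3×2×7 block (42 elements) on the output side, so a single array of fixed size cannot describe both.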

A second issue is that a PencilArray includes information on its associated partitioning scheme in its type. In other words, if you do typeof(u) where u::PencilArray, then the result will differ depending on whether u is the input or the output of a transform. This is of course not the case for an Array or a CuArray.

The ManyPencilArray type was created precisely to work around this problem. It includes different PencilArrays pointing to the same data, but whose sizes can differ from one partitioning scheme to the other.
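
For reference, this is roughly how an in-place transform with a ManyPencilArray looks in practice. It is a sketch based on my reading of the PencilFFTs documentation, not a definitive recipe: the grid size is arbitrary, and it assumes that `using PencilFFTs` also makes `Pencil` and `ManyPencilArray` available.

```julia
using MPI
using PencilFFTs   # assumed to re-export PencilArrays (Pencil, ManyPencilArray, ...)

MPI.Init()
comm = MPI.COMM_WORLD

dims = (16, 32, 64)       # global grid size (arbitrary choice)
pen = Pencil(dims, comm)  # partitioning of the input data

# In-place complex-to-complex transform.
plan = PencilFFTPlan(pen, Transforms.FFT!())

# For in-place plans, the allocated container is a ManyPencilArray.
A = allocate_input(plan)

u_in  = first(A)  # view of the data in the input partitioning
u_out = last(A)   # view of the same memory in the output partitioning

fill!(u_in, 1)
plan * A  # forward transform; the result is read through u_out
plan \ A  # backward transform; the data is read through u_in again

# Local sizes are per process and may differ between the two views.
println("rank ", MPI.Comm_rank(comm), ": input ", size(u_in), ", output ", size(u_out))
```

The script is meant to be launched under MPI (for instance with mpiexec and a few processes). Both views share the same memory, so only one of them holds meaningful data at any given time.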

Maybe there is another way around this issue, but for now I don't see a better solution. What I can suggest is to define a couple of functions in your code, like:

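using PencilArrays: ManyPencilArray

# Array holding the input of the transform
transform_input(A::ManyPencilArray) = first(A)
transform_input(A::AbstractArray) = A   # this accounts for all the other cases (Array, CuArray, ...)

# Array holding the output of the transform
transform_output(A::ManyPencilArray) = last(A)
transform_output(A::AbstractArray) = A
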
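
A short sketch of how such accessors keep user code generic. Only the AbstractArray methods are repeated here so that the snippet runs without MPI. The routine `apply_in_spectral_space!` and the stand-in transforms `fwd!` and `bwd!` are hypothetical: they only mimic an invertible in-place operation and are not FFTs.

```julia
# Plain-array methods of the accessors; the ManyPencilArray methods would be
# defined in addition to these when PencilArrays is loaded.
transform_input(A::AbstractArray) = A
transform_output(A::AbstractArray) = A

# Hypothetical user routine: apply the in-place forward transform, modify the
# transformed data with `f!`, transform back, and return the input-side array.
function apply_in_spectral_space!(f!, transform!, inverse!, A)
    transform!(A)
    f!(transform_output(A))
    inverse!(A)
    return transform_input(A)
end

# Stand-ins for an in-place forward/backward transform pair (not real FFTs).
fwd!(x) = (x .*= 2; x)
bwd!(x) = (x ./= 2; x)

A = [1.0, 2.0, 3.0]
result = apply_in_spectral_space!(v -> (v .+= 2), fwd!, bwd!, A)
println(result)  # [2.0, 3.0, 4.0]
```

With a ManyPencilArray, `transform!` and `inverse!` would presumably wrap `plan * A` and `plan \ A`, and the routine itself would not need to change.
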
Lightup1 commented 2 years ago

Okay, I’ll give it a try. Thanks for the detailed explanation!