jipolanco / PencilArrays.jl

Distributed Julia arrays using the MPI protocol
https://jipolanco.github.io/PencilArrays.jl/dev/
MIT License

Distributing lower-dimensional datasets #54

Open · jipolanco opened this issue 2 years ago

jipolanco commented 2 years ago

Following a discussion with @glwagner, we have concluded that it would make sense to generalise localgrid to support partitioning arrays of any dimension up to N, where N is the dimensionality of the "geometry" (determined when constructing a Pencil).

As a reminder, localgrid allows one to "distribute" a set of N 1D arrays according to a given Pencil configuration:

```julia
using MPI
using PencilArrays

MPI.Init()
comm = MPI.COMM_WORLD

# Define global rectilinear grid
# (this can also represent any other quantity that varies along one direction only)
Nx_global, Ny_global, Nz_global = 65, 17, 21
xs_global = range(0, 1; length = Nx_global)
ys_global = range(-1, 1; length = Ny_global)
zs_global = range(0, 2; length = Nz_global)

pen = Pencil((Nx_global, Ny_global, Nz_global), comm)  # here N = 3

grid = localgrid(pen, (xs_global, ys_global, zs_global))

# The result can be used for example to broadcast over the local part of the "grid":
u = PencilArray{Float64}(undef, pen)
@. u = grid.x + 2 * grid.y + grid.z^2

# Note that `grid` can also be indexed:
x, y, z = grid[2, 3, 4]
```

Desirable features

Some things which would be good to have are:

  1. Support distributing M-dimensional quantities for any M <= N. For instance, it's common to have 2D fields (e.g. boundaries) embedded in 3D geometries.
  2. In particular, support the case M = N: a single process holds global data to be distributed ("scattered") across processes.
  3. Full support for broadcasting in all cases.

Some design thoughts

A possible interface

We could introduce a new partition function which would take care of all possible use cases.

I see two possible modes for partition:

  1. All processes have the global data, and partition just returns a "broadcastable" view of the local part of the data. This is similar to what localgrid already does (but for M = 1 only). This mode makes more sense for M < N (where the global data is generally lightweight), but the case M = N can also be supported. The operation is easy to implement (there is no communication; all operations are local), and the main question is what a good underlying structure for the M < N case would be. The most natural option would be to extend the PencilArray type (see below).

    Implementation detail: should this mode create views of the original data, or copy it? For the case M = N, making copies would probably be better, as this would allow Julia's GC to discard the original data.

  2. A single "root" process has the global data, and partition distributes the data to all other processes (using MPI.Scatterv! or something equivalent). For N-dimensional arrays, this is the inverse of what is already done by gather.

In both cases, one convenient way of supporting the case M < N is to allow PencilArrays with singleton dimensions, which would then naturally work with broadcasting. As noted by @glwagner, this is what is already done in Oceananigans, and also in Dedalus (after looking at their PRR paper, I guess this corresponds to their "constant" flag?).
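
To illustrate why singleton dimensions are convenient, here is a small sketch with plain Julia arrays (no MPI and no PencilArray involved; the sizes are arbitrary stand-ins for local pencil dimensions):

```julia
# Plain-Julia illustration of why singleton dimensions play nicely with broadcasting.
Nx, Ny, Nz = 8, 4, 6

u  = rand(Nx, Ny, Nz)   # "full" 3D field
bc = rand(Nx, 1, Nz)    # 2D field u(x, z) stored with a singleton y-dimension

v = @. u + 2 * bc       # the singleton dimension is expanded automatically
@assert size(v) == (Nx, Ny, Nz)
```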

Then, for mode 1, we could define

```julia
partition(p::Pencil{N}, u_global::AbstractArray{T, N})
```

and allow singleton dimensions in u_global. For example, for a 2D field u(x, z) in 3D, one would have size(u_global) = (Nx_global, 1, Nz_global).

Alternatively, we can define something like

```julia
partition(p::Pencil{N}, u_global::AbstractArray{T, M}; dims::Dims{M})
```

with size(u_global) = (Nx_global, Nz_global) and dims = (1, 3) in this example.
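
To make the intent of mode 1 concrete, here is a rough sketch of what it could do internally for the singleton-dimension variant. It only relies on the existing range_local function; partition_local is a hypothetical name, index permutations are ignored, and the wrapping of the result in a PencilArray with singleton dimensions is left out:

```julia
# Hypothetical sketch of "mode 1": every process already holds `u_global`, so we
# just return a broadcastable view of the local portion. `range_local` exists in
# PencilArrays; `partition_local` is only an illustrative name.
function partition_local(p::Pencil{N}, u_global::AbstractArray{T, N}) where {T, N}
    ranges = range_local(p)  # local index range along each of the N dimensions
    inds = map(axes(u_global), ranges) do ax, r
        length(ax) == 1 ? ax : r  # keep singleton dimensions whole, take the local range otherwise
    end
    view(u_global, inds...)
end
```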

For mode 2, we could add an extra root parameter, similar to what is already done in gather:

```julia
partition(p::Pencil{N}, u_global::Union{Nothing, AbstractArray{T, N}}, root::Integer)
```

where the u_global parameter would be ignored on MPI processes other than root (it can therefore be nothing on those processes).
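
For illustration, a very naive version of this mode could avoid the Scatterv bookkeeping entirely by broadcasting the global array from root and copying the local block on each process, at the cost of sending more data than necessary. This is only a sketch; partition_from_root is a hypothetical name and index permutations are ignored:

```julia
# Naive sketch of "mode 2": only `root` holds the global data. A real implementation
# would likely use MPI.Scatterv! so that each process receives only its own block.
function partition_from_root(p::Pencil{N}, u_global, root::Integer) where {N}
    comm = get_comm(p)
    u_global = MPI.bcast(u_global, comm; root)  # replaces `nothing` on non-root processes
    u_local = PencilArray{eltype(u_global)}(undef, p)
    copyto!(u_local, view(u_global, range_local(p)...))  # copy the local block
    return u_local
end
```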

Supporting the case M < N

As mentioned above, a convenient way to support distributing lower-dimensional data is to allow singleton dimensions in the PencilArray type. This would require some minor changes but should be quite easy to do.

Alternatively, we could allow M-dimensional PencilArrays, avoiding singleton dimensions altogether. In this case, the PencilArray type should include information mapping array dimensions to physical dimensions -- e.g. the dims = (1, 3) tuple in the example above. Ideally this information should be static. I think this alternative makes more sense if one wants to actually index the array. Concretely, we could support indexing both as u[i, k] and u[i, j, k] (the j index is simply discarded in the latter case) with optimal efficiency in both cases.
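
As a rough illustration of what such a static mapping could look like (all names here are hypothetical and independent of PencilArray itself):

```julia
# Hypothetical sketch of an M-dimensional wrapper that records, as a static type
# parameter, which physical dimensions its axes map to (e.g. dims = (1, 3) for a
# u(x, z) field in a 3D geometry). None of these names exist in PencilArrays.
struct ReducedDimArray{T, M, N, dims, A <: AbstractArray{T, M}} <: AbstractArray{T, M}
    data :: A
end

ReducedDimArray(data::AbstractArray{T, M}, ::Val{N}, ::Val{dims}) where {T, M, N, dims} =
    ReducedDimArray{T, M, N, dims, typeof(data)}(data)

Base.size(u::ReducedDimArray) = size(u.data)

# M-dimensional indexing, e.g. u[i, k]:
Base.getindex(u::ReducedDimArray{T, M}, I::Vararg{Int, M}) where {T, M} = u.data[I...]

# N-dimensional indexing, e.g. u[i, j, k]: indices along dropped dimensions are ignored.
# (Assumes M < N, so the two methods have different arities.)
Base.getindex(u::ReducedDimArray{T, M, N, dims}, I::Vararg{Int, N}) where {T, M, N, dims} =
    u.data[ntuple(d -> I[dims[d]], Val(M))...]

# Example usage:
# u2d = ReducedDimArray(rand(4, 6), Val(3), Val((1, 3)))  # u(x, z) in a 3D geometry
# u2d[2, 5]      # 2D indexing
# u2d[2, 9, 5]   # 3D indexing; the j index is discarded
```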

Changes to localgrid

Right now we define an internal RectilinearGridComponent type to describe single grid components such as grid.x. This type would no longer be needed if we reimplemented localgrid on top of partition, such that the returned object wraps a set of N PencilArrays with singleton dimensions.
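
A possible sketch of that reimplementation, reusing the hypothetical partition_local from the mode-1 sketch above (here the result is just a plain tuple of broadcastable local views rather than a RectilinearGrid-like wrapper):

```julia
# Hypothetical sketch: each 1D coordinate array is reshaped into an N-dimensional
# array with singleton dimensions, then partitioned locally.
function localgrid_via_partition(p::Pencil{N}, coords::Tuple{Vararg{AbstractVector, N}}) where {N}
    map(ntuple(identity, Val(N)), coords) do d, x
        dims = ntuple(i -> i == d ? length(x) : 1, Val(N))  # e.g. (Nx, 1, 1) for d = 1
        partition_local(p, reshape(collect(x), dims))
    end
end
```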

glwagner commented 2 years ago

> As noted by @glwagner, this is what is already done in Oceananigans, and also in Dedalus (after looking at their PRR paper, I guess this corresponds to their "constant" flag?).

The key point, I think, is that we often use "broadcastable" dimensions when working with reduced-dimension objects (because broadcasting is pretty convenient, and we want it to "just work").

So, a utility partition that can distribute broadcastable objects (i.e. partitioning along a singleton dimension => copying to all processes) might be generically useful for lots of software that's generalizing from a single-process implementation to a distributed implementation via PencilArrays... perhaps.

@kburns and @navidcy might have more thoughts on that!

kburns commented 2 years ago

Yeah in Dedalus we just have singleton dimensions rather than truly lower-dimensional datasets, and scatter when necessary for broadcasting rather than keeping data duplicates on the empty hyperplanes. We usually have so many transposes in the global spectral algorithm that I wasn't too worried about the scatter cost, and then we didn't have to worry about local duplicates becoming desynchronized, etc.