GridTools / gt4py

Python library for generating high-performance implementations of stencil kernels for weather and climate modeling from a domain-specific language (DSL).
https://GridTools.github.io/gt4py
BSD 3-Clause "New" or "Revised" License

Slicing sparse fields #823

Open havogt opened 2 years ago

havogt commented 2 years ago

In field view we currently only support sparse fields in reductions. If we want to implement codes like the following, we need a way to slice fields in the sparse dimension.

dusk example

@stencil
def foo(
    v: Field[Vertex],
    sparse: Field[Edge > Cell > Vertex],
):
    res: Field[Edge]
    with domain #...
        res = sum_over(
            Edge > Cell > Vertex,
            v * sparse,
            weights=[1.0, 1.0, 0.0, 0.0],
        )
    #...

we would like to write it as

@field_operator
def foo(
    v: Field[[VertexDim], float],
    sparse: Field[[EdgeDim, E2C2VDim], float],
) ...:
    res = v(E2C2V[0]) * sparseXXX + v(E2C2V[1]) * sparseXXX
                              ^^^
    # ...
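
To make the intended semantics concrete, here is a minimal pure-Python sketch (not gt4py code). It assumes the sparse field is stored as one row of 4 values per edge and that E2C2V is available as a hypothetical connectivity table `e2c2v` with the same shape. The mesh, the values, and all variable names are made up for illustration.

```python
# Hypothetical toy mesh: 3 edges, each with 4 E2C2V neighbor vertices.
e2c2v = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]]
v = [10.0, 20.0, 30.0, 40.0, 50.0]  # plays the role of Field[[VertexDim], float]
sparse = [                          # plays the role of Field[[EdgeDim, E2C2VDim], float]
    [0.0, 1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0, 7.0],
    [8.0, 9.0, 10.0, 11.0],
]
n_edges = len(e2c2v)

# Slicing the sparse dimension: use neighbors 0 and 1 explicitly.
res_sliced = [
    v[e2c2v[e][0]] * sparse[e][0] + v[e2c2v[e][1]] * sparse[e][1]
    for e in range(n_edges)
]

# The same result written as a weighted reduction, as in the dusk example.
weights = [1.0, 1.0, 0.0, 0.0]
res_weighted = [
    sum(v[e2c2v[e][n]] * sparse[e][n] * weights[n] for n in range(4))
    for e in range(n_edges)
]
```

Both lists hold the same values, which is the equivalence the proposed slicing syntax is meant to express without carrying zero weights.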

What is the best syntax for slicing? I see 3 alternatives:

  1. sparse[E2C2VDim[X]]: slice notation with a named index (note it's a dimension not an offset)
  2. sparse[X]: slice notation with implicit dimension (probably not the solution)
  3. sparse(X): shift notation, compatible with Iterator IR idea of shifting

I prefer

havogt commented 2 years ago

If we go with 1, probably the stencil should be written as

    v(E2C2V)[E2C2VDim[0]] * sparse[E2C2VDim[0]] + v(E2C2V)[E2C2VDim[1]] * sparse[E2C2VDim[1]]

otherwise the use of () vs [] is super confusing, but on the other hand it's very verbose.

tehrengruber commented 2 years ago

There is one conceptual problem with respect to the lowering that seems rather dangerous to me. Namely, after you have selected an element in the neighbor/sparse dimension, the resulting expression can accidentally be shifted again, leading to unexpected behavior and potentially even an out-of-bounds access. Consider a sparse field vertex_orientation: Field[[Vertex, V2V], dtype] and the following snippet:

result_first_vertex:  Field[Vertex] = vertex_orientation[V2VDim[1]]
wrong_result_1 = neighbor_sum(result_first_vertex*vertex_orientation, axis=V2V)
wrong_result_2 = broadcast(result_first_vertex, (Vertex, V2V))[V2VDim[2]]

The dangerous part is that result_first_vertex poisons every expression it is used in, potentially dozens of calls away or even across field operator boundaries. With respect to the semantics of the frontend: if you promote/broadcast result_first_vertex, e.g. by multiplying by another sparse field, the value in the reintroduced sparse dimension should still be equal to the previously selected value, i.e. at V2VDim[1]. However there is no way to express that all partial shifts from now on should be ignored (which is what we want). For wrong_result_1 the iterator backend will happily just continue shifting in the sparse dimension, and the frontend will emit a shift for wrong_result_2. I am rather hesitant to implement the feature without solving this issue because the across-stencil-boundary part is a nightmare with respect to debugging & verification.

havogt commented 2 years ago

I don't fully understand the problem.

e.g. by multiplying by another sparse field, the value in the reintroduced sparse dimension should still be equal to the previously selected value, i.e. at V2VDim[1]

I agree that this is the expected behavior, but I don't see why you wouldn't get this behavior. Broadcasting (at least currently) is a pure type-checking feature and a no-op in lowering.

However there is no way to express that all partial shifts from now on should be ignored (which is what we want).

I don't understand this statement; you still want to be able to shift the field. After slicing the sparse field it should behave exactly like a non-sparse field, i.e. you want to apply shifts as usual.

Though I see a potential problem for the second point. Let me explain what I think the lowering should do. [Dim[N]] just translates to a shift(N), the additional Dim is just for checking in the frontend.

For wrong_result_2 we have

result_first_vertex = shift(1)(vertex_orientation)
wrong_result_2 = shift(2)(result_first_vertex)

In the current interpretation of partially shifted fields, we have the following chain of offsets: V2V, 1, 2, i.e. we have an additional offset 2 that should be ignored. The problem is that at Iterator IR level we don't know about the V2V; we see only 1, 2 and don't know if they can be ignored (for embedded execution it should be ok, but for the C++ backend it is less clear).
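
The offset chain described above can be illustrated with a toy model. This is only a sketch under assumptions: `ToyIterator` and `split_chain` are hypothetical helpers invented here, not Iterator IR constructs, and connectivity tags are modeled as plain strings.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyIterator:
    """Hypothetical stand-in for an iterator that records its pending offsets."""
    offsets: tuple = ()

    def shift(self, *new_offsets):
        # Shifting resolves nothing in this toy; it only appends to the chain.
        return ToyIterator(self.offsets + new_offsets)


def split_chain(offsets):
    """Pair each connectivity tag with the integer index that follows it.

    Returns (complete_shifts, dangling), where `dangling` collects integer
    offsets that have no connectivity tag in front of them.
    """
    complete, dangling = [], []
    pending_tag = None
    for off in offsets:
        if isinstance(off, str):
            pending_tag = off
        elif pending_tag is not None:
            complete.append((pending_tag, off))
            pending_tag = None
        else:
            dangling.append(off)
    return complete, dangling


vertex_orientation = ToyIterator(("V2V",))         # sparse field: partially shifted
result_first_vertex = vertex_orientation.shift(1)  # [V2VDim[1]] lowered to shift(1)
wrong_result_2 = result_first_vertex.shift(2)      # [V2VDim[2]] lowered to shift(2)

# With the full chain, the extra offset 2 is recognizable as dangling.
full_view = split_chain(wrong_result_2.offsets)

# Without the V2V tag, both integers look alike and neither can be classified.
ir_view = split_chain((1, 2))
```

In this toy, `full_view` separates the complete shift `("V2V", 1)` from the dangling `2`, while `ir_view` reports both integers as dangling, which mirrors the ambiguity described for the C++ backend.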

havogt commented 2 years ago

I discussed this question with @egparedes.

We believe the problem mentioned above should be resolved at the field view level (before lowering): If there is a broadcast we can track that the field is constant in that dimension and throw away any slicing in that dimension.
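
As a rough sketch of what such tracking could look like inside a single field operator, the following toy type checker records which dimensions were introduced by a broadcast and drops any slicing along them. `ToyFieldType`, `broadcast`, and `slice_dim` are hypothetical names for illustration only, and dimensions are modeled as strings; this is not the gt4py type system.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFieldType:
    """Hypothetical frontend type: dimensions plus those known to be constant."""
    dims: tuple
    constant_dims: frozenset = frozenset()


def broadcast(ftype, new_dims):
    # Every dimension introduced by the broadcast is constant along that axis.
    added = frozenset(d for d in new_dims if d not in ftype.dims)
    return ToyFieldType(tuple(new_dims), ftype.constant_dims | added)


def slice_dim(ftype, dim, index):
    """Return (result_type, emit_shift) for the expression field[dim(index)]."""
    remaining = tuple(d for d in ftype.dims if d != dim)
    result = ToyFieldType(remaining, ftype.constant_dims - {dim})
    # Along a constant dimension every index selects the same value,
    # so `index` is irrelevant and no shift needs to be emitted.
    emit_shift = dim not in ftype.constant_dims
    return result, emit_shift


sparse_t = ToyFieldType(("Vertex", "V2V"))
sliced_t, shift_1 = slice_dim(sparse_t, "V2V", 1)         # real slice: emit shift
rebroadcast_t = broadcast(sliced_t, ("Vertex", "V2V"))    # V2V is now constant
resliced_t, shift_2 = slice_dim(rebroadcast_t, "V2V", 2)  # slice is thrown away
```

The sketch only works because `constant_dims` travels with the type; it does not address how that information would be preserved once a field leaves the scope in which the broadcast happened.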

havogt commented 2 years ago

We continued the discussion of which syntax to use and propose the following:

The syntax for constructing an index in a dimension is Dim(int) (instead of Dim[int], see later). The example above would then look like

return v(E2C2V[0]) * sparse[E2C2VDim(0)] + v(E2C2V[1]) * sparse[E2C2VDim(1)]

or

tmp = v(E2C2V) * sparse
return tmp[E2C2VDim(0)] + tmp[E2C2VDim(1)]

Now this is just a reduction over a subset of the elements of a dimension, therefore we propose to enhance neighbor_sum:

return neighbor_sum(v(E2C2V)*sparse, over=E2C2VDim[0:2])

(axis renamed to over, no strong opinion on that name) with the semantics of selecting a range of indices.

Note the use of [] vs ():

E2C2VDim[0:2] == [E2C2VDim(0), E2C2VDim(1)]
E2C2VDim[0] == [E2C2VDim(0)]
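
The two identities above can be reproduced with a small stand-alone class. This is a sketch only: `ToyDimension` and `DimIndex` are hypothetical names, and the explicit `size` argument is an assumption made here so that slices can be resolved; the real dimension objects may look quite different.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DimIndex:
    """Hypothetical named index: one position within one dimension."""
    dim: "ToyDimension"
    value: int


@dataclass(frozen=True)
class ToyDimension:
    """Hypothetical dimension with a known size (needed to resolve slices)."""
    name: str
    size: int

    def __call__(self, value):
        # Dim(int) constructs a single named index.
        return DimIndex(self, value)

    def __getitem__(self, key):
        # Dim[int] or Dim[slice] constructs a list of named indices.
        if isinstance(key, slice):
            return [DimIndex(self, i) for i in range(*key.indices(self.size))]
        return [DimIndex(self, key)]


E2C2VDim = ToyDimension("E2C2V", size=4)

range_selection = E2C2VDim[0:2]  # two named indices
single_selection = E2C2VDim[0]   # a one-element list, not a bare index
```

Keeping `()` for a single index and `[]` for a selection means a reduction such as `neighbor_sum` could always receive a list, while field slicing always receives exactly one index.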

havogt commented 2 years ago

@tehrengruber pointed out that

We believe the problem mentioned above should be resolved at the field view level (before lowering): If there is a broadcast we can track that the field is constant in that dimension and throw away any slicing in that dimension.

is not straightforward, maybe impossible, because across field_operator boundaries we lose the information about whether something was broadcast.

We discussed the following approaches which are all more or less workarounds:

tehrengruber commented 2 years ago

Another (third) approach (that also feels like a workaround) is to introduce a new built-in at the Iterator IR level that ignores all following partial shifts. This built-in could be emitted by the broadcast method whenever there is a broadcast in a local dimension.