GEOS-ESM / NDSL

NOAA NASA Domain Specific Language middleware layer

Refactor of GT4Py-DaCe bridge to expose all control-flow to DaCe #53

Open FlorianDeconinck opened 3 months ago

FlorianDeconinck commented 3 months ago

The current bridge between gt4py.cartesian and DaCe relies on a concept of expansion for stencils. Once stencils have gone through the entire GT4Py workflow, they are pickled and stored in a library node. The expansion mechanism then turns those into a single map/tasklet per Horizontal Computation. As a result, a lot of control-flow (masks, while loops, etc.) stays hidden from DaCe inside the tasklet, so DaCe cannot analyze or optimize it.

We need to undo this and describe everything back to DaCe.

The expansion code lives around src/gt4py/cartesian/gtc/dace/expansion/*.py in GT4Py.

Parent: https://github.com/GEOS-ESM/SMT-Nebulae/issues/31

NB: This is a requirement for future DaCe AD features to kick in.


FlorianDeconinck commented 3 months ago

Below is a small example showing the above problem:

from gt4py.cartesian.gtscript import (
    computation,
    interval,
    PARALLEL,
)
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField
from ndsl import StencilFactory, orchestrate
import numpy as np

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu(domain[0], domain[1], domain[2], 0)

def small_conditional(
    in_field: FloatField,
    out_field: FloatField
):
    with computation(PARALLEL), interval(...):
        if in_field > 5.0 and in_field < 20:
            out_field = in_field

class DaCeGT4Py_Bridge:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config)
        # Use the factory passed to the constructor rather than the module-level global
        self.stencil = stencil_factory.from_dims_halo(
            func=small_conditional,
            compute_dims=[X_DIM, Y_DIM, Z_DIM],
        )

    def __call__(self, in_field: FloatField, out_field: FloatField):
        self.stencil(in_field, out_field)

if __name__ == "__main__":
    I = np.arange(domain[0]*domain[1]*domain[2], dtype=np.float64).reshape(domain)
    O = np.zeros(domain)

    bridge = DaCeGT4Py_Bridge(stcil_fctry)
    bridge(I, O)

    print(f"Input:\n{I}\n")
    print(f"Output:\n{O}\n")

This generates the following SDFG (under dace:cpu):

[Image: the generated SDFG]

where the tasklet code carries the conditional:

mask_140660478452496_gen_0: dace.bool_
mask_140660478452496_gen_0 = ((gtIN__in_field > dace.float64(5.0)) and (gtIN__in_field < dace.float64(dace.int64(20))))
if mask_140660478452496_gen_0:
    gtOUT__out_field = gtIN__in_field
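
For reference, the masked assignment above amounts to a per-element check followed by a copy. The following is a plain-Python sketch of that logic (no GT4Py or DaCe involved; the flat lists `in_values` and `out_values` are illustrative stand-ins for the flattened 3x3x4 fields of the example), useful as a sanity check on what the example should produce:

```python
# Plain-Python reference for the masked copy in the example above.
# Assumption: fields are flattened to lists; no GT4Py/DaCe is used here.
domain = (3, 3, 4)
n = domain[0] * domain[1] * domain[2]

in_values = [float(v) for v in range(n)]  # mirrors np.arange(...), flattened
out_values = [0.0] * n                    # mirrors np.zeros(domain), flattened

for idx, value in enumerate(in_values):
    # Same condition as the stencil: in_field > 5.0 and in_field < 20
    mask = value > 5.0 and value < 20
    if mask:
        out_values[idx] = value

# Only the entries that passed the mask were copied; the rest stay at zero.
copied = [v for v in out_values if v != 0.0]
print(copied)
```

Values outside the open interval (5, 20) leave `out_values` untouched, which is why the output keeps its initial zeros there.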

We want this conditional to be extracted from the tasklet and described to DaCe as control-flow in front of it. See the code of dace.sdfg.SDFG.add_loop for how to dynamically create an if condition using the guard mechanism.
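
To make the guard idea concrete, here is a toy model in plain Python. It is not the DaCe API: `ToyStateGraph`, `add_state`, `add_edge` and `run` are hypothetical names made up for this sketch. It only illustrates the shape being asked for: a guard state with two outgoing conditional edges (condition and its negation), so that the branch is visible in the graph instead of buried in the body. In DaCe the analogous pieces would presumably be SDFG states connected by interstate edges carrying conditions.

```python
# Toy model of a guard-based "if" in a state graph. Hypothetical classes,
# NOT the DaCe API; it only shows the structure of the guard pattern.
class ToyStateGraph:
    def __init__(self):
        self.actions = {}  # state name -> callable(symbols)
        self.edges = {}    # state name -> list of (condition, destination)

    def add_state(self, name, action=None):
        self.actions[name] = action or (lambda symbols: None)
        self.edges.setdefault(name, [])
        return name

    def add_edge(self, src, dst, condition=None):
        # An edge without a condition is always taken.
        self.edges[src].append((condition or (lambda symbols: True), dst))

    def run(self, start, symbols):
        visited, state = [], start
        while state is not None:
            visited.append(state)
            self.actions[state](symbols)
            # Follow the first outgoing edge whose condition holds.
            state = next(
                (dst for cond, dst in self.edges[state] if cond(symbols)), None
            )
        return visited


def build_conditional_copy():
    graph = ToyStateGraph()
    guard = graph.add_state("guard")  # empty state, only branches
    body = graph.add_state("copy", lambda s: s.update(out_value=s["in_value"]))
    end = graph.add_state("end")

    def in_range(s):
        return s["in_value"] > 5.0 and s["in_value"] < 20

    graph.add_edge(guard, body, in_range)                  # condition true
    graph.add_edge(guard, end, lambda s: not in_range(s))  # condition false
    graph.add_edge(body, end)
    return graph, guard


graph, start = build_conditional_copy()
taken = {"in_value": 7.0, "out_value": 0.0}
skipped = {"in_value": 3.0, "out_value": 0.0}
path_taken = graph.run(start, taken)
path_skipped = graph.run(start, skipped)
print(path_taken, taken["out_value"])
print(path_skipped, skipped["out_value"])
```

With this layout the body state contains only the unconditional copy, and the condition lives on the edges, where a graph-level optimizer can see and transform it.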

FlorianDeconinck commented 3 months ago

Known native constructs that are folded inside the tasklet instead of being described to DaCe:

GT4Py-DaCe uses an "expansion" mechanism. The stencils are dace.library_nodes while the code is being parsed (orchestration). Expansion of those nodes is then triggered and unrolls each stencil into an SDFG-compatible description. It is within this expansion code that the decision to write all code as tasklets is finalized. The refactor might need to touch the pre-expansion stage as well, in the DaceIRBuilder.

romanc commented 1 day ago

Latest work is in https://github.com/romanc/gt4py/romanc/bridge-on-top-of-cleanups, which is rebased on top of the cleanups in PR https://github.com/GridTools/gt4py/pull/1724. It should now contain only the changes necessary for the new bridge and no cleanups.

The new bridge is currently pending on the following two DaCe issues:

Even when working around those two, we still intermittently have an issue with writes (in k) inside a k-loop (i.e. with computation FORWARD / BACKWARD). To be investigated further at a later point.