Open FlorianDeconinck opened 3 months ago
Below is a small example showing the above problem
from gt4py.cartesian.gtscript import (
    computation,
    interval,
    PARALLEL,
)
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField
from ndsl import StencilFactory, orchestrate
import numpy as np

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu(domain[0], domain[1], domain[2], 0)


def small_conditional(
    in_field: FloatField,
    out_field: FloatField
):
    with computation(PARALLEL), interval(...):
        if in_field > 5.0 and in_field < 20:
            out_field = in_field


class DaCeGT4Py_Bridge:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config)
        self.stencil = stcil_fctry.from_dims_halo(
            func=small_conditional,
            compute_dims=[X_DIM, Y_DIM, Z_DIM],
        )

    def __call__(self, in_field: FloatField, out_field: FloatField):
        self.stencil(in_field, out_field)


if __name__ == "__main__":
    I = np.arange(domain[0]*domain[1]*domain[2], dtype=np.float64).reshape(domain)
    O = np.zeros(domain)

    bridge = DaCeGT4Py_Bridge(stcil_fctry)
    bridge(I, O)

    print(f"Input:\n{I}\n")
    print(f"Output:\n{O}\n")
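For reference, the intended semantics of `small_conditional` can be written in plain Python, independent of GT4Py, NDSL and DaCe. This is only a sketch: the helper name `reference_small_conditional` is made up here, and it works on flat lists rather than on 3D fields.

```python
from typing import List


def reference_small_conditional(in_values: List[float], out_values: List[float]) -> List[float]:
    """Hypothetical reference: copy in -> out where 5.0 < in < 20, else keep out."""
    return [
        i if (i > 5.0 and i < 20) else o
        for i, o in zip(in_values, out_values)
    ]


# Same data as the reproducer, flattened: 3 * 3 * 4 = 36 points holding 0.0 .. 35.0
flat_in = [float(v) for v in range(3 * 3 * 4)]
flat_out = [0.0] * len(flat_in)

expected = reference_small_conditional(flat_in, flat_out)
```

Only the points holding values 6.0 through 19.0 are written; every other point keeps its initial 0.0.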
This generates the following SDFG (under dace:cpu), where the tasklet code carries the conditional:
mask_140660478452496_gen_0: dace.bool_
mask_140660478452496_gen_0 = ((gtIN__in_field > dace.float64(5.0)) and (gtIN__in_field < dace.float64(dace.int64(20))))
if mask_140660478452496_gen_0:
    gtOUT__out_field = gtIN__in_field
We want this conditional to be extracted out of the tasklet and evaluated before it. See the code of dace.sdfg.SDFG.add_loop for how to dynamically create an if condition using the guard mechanism.
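The target shape can be illustrated without DaCe. The sketch below is a toy state machine, not DaCe's API: `State`, `Edge`, `run` and `evaluate_point` are hypothetical names chosen for illustration. It puts the condition on the edges leaving an empty guard state, so the tasklet body is a plain assignment with no `if` inside it.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Symbols = Dict[str, float]


@dataclass
class State:
    name: str
    # The "tasklet" run in this state; None for an empty state such as the guard.
    tasklet: Optional[Callable[[Symbols], None]] = None


@dataclass
class Edge:
    src: str
    dst: str
    condition: Callable[[Symbols], bool]


def run(states: List[State], edges: List[Edge], start: str, symbols: Symbols) -> List[str]:
    """Walk the graph from `start`, taking the first outgoing edge whose condition holds."""
    by_name = {s.name: s for s in states}
    visited: List[str] = []
    current: Optional[str] = start
    while current is not None:
        visited.append(current)
        state = by_name[current]
        if state.tasklet is not None:
            state.tasklet(symbols)
        current = next(
            (e.dst for e in edges if e.src == current and e.condition(symbols)),
            None,  # no outgoing edge taken: stop
        )
    return visited


def in_range(s: Symbols) -> bool:
    # Same condition as the mask in the generated tasklet.
    return s["in_field"] > 5.0 and s["in_field"] < 20


def copy_tasklet(s: Symbols) -> None:
    # Unconditional body: the `if` now lives on the edges, not in here.
    s["out_field"] = s["in_field"]


states = [State("guard"), State("body", tasklet=copy_tasklet), State("exit")]
edges = [
    Edge("guard", "body", in_range),
    Edge("guard", "exit", lambda s: not in_range(s)),
    Edge("body", "exit", lambda s: True),
]


def evaluate_point(value: float) -> Tuple[float, List[str]]:
    symbols = {"in_field": value, "out_field": 0.0}
    path = run(states, edges, "guard", symbols)
    return symbols["out_field"], path
```

With this layout, a value inside the range visits guard, body and exit, while a value outside the range skips the body entirely. The point of the refactor is that a framework able to see the condition on the edges can reason about it, which it cannot do when the condition is buried in tasklet code.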
Known native constructs that are folded inside the tasklet instead of being described to DaCe:
GT4Py-DaCe uses an "expansion" mechanism. The stencils are dace.library_nodes while the code is being parsed (orchestration). Expansion of those nodes is then triggered and unrolls each stencil into an SDFG-compatible description. It is within this expansion code that the decision to write all code as tasklets is finalized. The refactor might need to impact the pre-expansion stage as well, in the DaceIRBuilder.
Latest work is in https://github.com/romanc/gt4py/romanc/bridge-on-top-of-cleanups, which is re-based on top of the cleanups in PR https://github.com/GridTools/gt4py/pull/1724 and should only contain changes necessary for the new bridge and no cleanups anymore.
The new bridge is currently pending on the following two DaCe issues:
Working around those two, we still (only sometimes) have an issue with writes (in k) inside a k-loop (i.e. with computation FORWARD / BACKWARD). To be investigated further at a later point.
The current bridge between gt4py.cartesian and DaCe uses a concept of expansion for stencils. This means that stencils, once they have gone through the entire GT4Py workflow, are pickled and shoved down a library node. The expansion mechanism then turns those into a single map/tasklet per Horizontal Computation. As a result, a lot of control flow (mask, while loop, etc.) is hidden from DaCe, which limits the optimizations DaCe can apply. We need to undo this and describe everything back to DaCe.
Expansion code lives around src/gt4py/cartesian/gtc/dace/expansion/*.py in GT4Py.
Parent: https://github.com/GEOS-ESM/SMT-Nebulae/issues/31
NB: This is a requirement for future DaCe AD features to kick in.