GEOS-ESM / NDSL

NOAA NASA Domain Specific Language middleware layer

Debug loop merge failings in orchestration #47

Open FlorianDeconinck opened 3 months ago

FlorianDeconinck commented 3 months ago

Orchestration splits the code into a very high number of kernels. The resulting kernel fragmentation is very damaging to CPU performance (and not great on GPU either).

We need to understand in depth why this happens and plan how to resolve it, relying on the SPCL crew to help develop that understanding.
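To make the cost of kernel fragmentation concrete, here is a minimal pure-Python sketch (plain loops, not NDSL or DaCe code) of the `double_loop` pattern used further down in this issue. It is written once as two separate loops, standing in for two kernels, and once as a single fused loop. The `loads` counter is a hypothetical proxy for memory traffic: the unfused version has to read the intermediate `field_A` back from memory, while the fused version reuses the value it just produced. The counter ignores stores and per-kernel launch overhead.

```python
from typing import List, Tuple


def unfused(result: List[float], weight: float) -> Tuple[List[float], List[float], int]:
    """Two separate loops, standing in for two kernels."""
    n = len(result)
    field_a = [0.0] * n
    field_b = [0.0] * n
    loads = 0
    # Kernel 1: copy result into the intermediate field_a.
    for i in range(n):
        field_a[i] = result[i]
        loads += 1
    # Kernel 2: field_a has to be read back from memory.
    for ii in range(n):
        field_b[ii] = field_a[ii] * weight
        loads += 1
    return field_a, field_b, loads


def fused(result: List[float], weight: float) -> Tuple[List[float], List[float], int]:
    """One loop: the value written to field_a is reused directly."""
    n = len(result)
    field_a = [0.0] * n
    field_b = [0.0] * n
    loads = 0
    for i in range(n):
        value = result[i]  # a single load per point
        loads += 1
        field_a[i] = value
        field_b[i] = value * weight
    return field_a, field_b, loads


if __name__ == "__main__":
    data = [float(k) for k in range(10)]
    a0, b0, loads0 = unfused(data, 2.0)
    a1, b1, loads1 = fused(data, 2.0)
    assert (a0, b0) == (a1, b1)  # both versions compute the same fields
    print(loads0, loads1)  # 20 10
```

Both versions produce identical fields; fusion only halves the number of passes over memory in this toy count.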

What we know

The way gt4py cartesian is linked to DaCe produces SDFGs that are very heavy in states, and the state-merging transform (StateFusion) does not seem to perform well on them. Likewise, gt4py's stencilization produces a large number of nested SDFGs, which also seems to impair state and map merging.
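The interaction between the two transforms can be pictured with a deliberately simplified toy model: map fusion only looks inside a single state, so any pair of states that fails to fuse keeps its maps out of reach. This is not how DaCe implements StateFusion or MapFusion. `Map`, `State`, `fuse_states` and `fuse_maps` are hypothetical names, and the state-fusion check (refuse when both states write the same container) is a crude stand-in for a real write-race analysis.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Map:
    """Toy parallel map: an iteration range and the containers it writes."""
    rng: Tuple[int, int]
    writes: Tuple[str, ...]


@dataclass
class State:
    """Toy state: an ordered list of maps."""
    maps: List[Map] = field(default_factory=list)


def fuse_states(states: List[State]) -> List[State]:
    """Fuse consecutive states unless both write a common container.

    The shared-write test is a crude stand-in for a real write-race check.
    """
    out: List[State] = []
    for st in states:
        if out:
            prev_writes = {w for m in out[-1].maps for w in m.writes}
            cur_writes = {w for m in st.maps for w in m.writes}
            if not prev_writes & cur_writes:
                out[-1].maps.extend(st.maps)
                continue
        out.append(State(list(st.maps)))
    return out


def fuse_maps(states: List[State]) -> int:
    """Fuse adjacent maps with equal ranges, but only inside one state.

    Mutates the states and returns the remaining number of maps ("kernels").
    """
    kernels = 0
    for st in states:
        fused: List[Map] = []
        for m in st.maps:
            if fused and fused[-1].rng == m.rng:
                fused[-1] = Map(m.rng, fused[-1].writes + m.writes)
            else:
                fused.append(m)
        st.maps = fused
        kernels += len(fused)
    return kernels


if __name__ == "__main__":
    # Both maps write out_field, as in the double_map stencil.
    blocked = fuse_states([State([Map((0, 10), ("out_field",))]),
                           State([Map((0, 10), ("out_field",))])])
    print(len(blocked), fuse_maps(blocked))  # 2 2
    # Disjoint writes: the states fuse, which lets the maps fuse too.
    ok = fuse_states([State([Map((0, 10), ("a",))]),
                      State([Map((0, 10), ("b",))])])
    print(len(ok), fuse_maps(ok))  # 1 1
```

In the toy, a single refused state fusion is enough to leave two kernels where one would do, which is the cascade the orchestration example below runs into.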

The solutions


FlorianDeconinck commented 2 months ago

Phillip Mueller's work on his gt4py.next fork

FlorianDeconinck commented 2 months ago

Schedule Tree PR and branch

In the current state of the work, the SDFG -> schedule tree (stree) conversion is implemented. The reverse is not (partial work has been done). Our job is to evaluate whether this makes merging easier (including for MW003.1) by trying to write a transform on it. If it does, we will then push for finishing the reverse conversion (and potentially help with it).

Here is the simplest example showing how to get a schedule tree and print it:

import dace
import numpy as np

N: int = 10

@dace.program
def double_loop(
    field_A: dace.float64[N],
    field_B: dace.float64[N],
    result: dace.float64[N],
    weight: float,
):
    for i in dace.map[0:N]:
        field_A[i] = result[i]
    for ii in dace.map[0:N]:
        field_B[ii] = field_A[ii] * weight

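if __name__ == "__main__":
    r = np.zeros(N)
    f0 = np.ones(N)
    f1 = np.ones(N)
    # Run once on concrete data: f0 <- r, then f1 <- f0 * 2.0
    double_loop(f0, f1, r, 2.0)
    # Build the SDFG and convert it to a Schedule Tree
    stree = double_loop.to_sdfg().as_schedule_tree()
    print(stree.as_string())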

This prints a pseudo-Python representation:

  map i in [0:10]:
    field_A[i] = tasklet(result[i])
  map ii in [0:10]:
    __tmp1[0] = tasklet(field_A[ii], weight[0])
    field_B[ii] = tasklet(__tmp1[0])

FlorianDeconinck commented 1 month ago

Here is a full, simple example of stencils that should merge but don't under NDSL orchestration, followed by a hard-coded way to merge them easily under the Schedule Tree paradigm:

from gt4py.cartesian.gtscript import (
    computation,
    interval,
    PARALLEL,
)
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField
from ndsl import StencilFactory, orchestrate
import numpy as np
import dace
import dace.sdfg.analysis.schedule_tree.treenodes as dace_stree
from dace.transformation.dataflow import MapFusion
from dace.transformation.interstate import StateFusion

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu(
    domain[0], domain[1], domain[2], 0
)

def double_map(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

class DaCeGT4Py_Bridge:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(obj=self, config=stencil_factory.config.dace_config)
        self.stencil = stencil_factory.from_dims_halo(
            func=double_map,
            compute_dims=[X_DIM, Y_DIM, Z_DIM],
        )

    def __call__(self, in_field: FloatField, out_field: FloatField):
        self.stencil(in_field, out_field)

if __name__ == "__main__":
    I = np.arange(domain[0] * domain[1] * domain[2], dtype=np.float64).reshape(domain)
    O = np.zeros(domain)

    # Trigger the NDSL orchestration pipeline and grab the cached SDFG
    bridge = DaCeGT4Py_Bridge(stcil_fctry)
    bridge(I, O)
    sdfg: dace.sdfg.SDFG = bridge.__sdfg__(I, O).csdfg.sdfg
    sdfg.save("orig.sdfg")
    stree = sdfg.as_schedule_tree()

    #### Classic SDFG transform
    # Strategy: fuse states out of the way, then merge maps

    # State fusion
    # No fusion occurs: StateFusion reports a potential write race, since
    # both maps write out_field and fusing would put them under one input
    r = sdfg.apply_transformations_repeated(
        StateFusion,
        print_report=True,
        validate_all=True,
    )
    print(f"Fused {r} states")
    sdfg.save("state_fusion.sdfg")
    # No fusion occurs because maps are in different states
    # (previous failure to merge)
    r = sdfg.apply_transformations_repeated(
        MapFusion,
        print_report=True,
        validate_all=True,
    )
    print(f"Fused {r} maps")
    sdfg.save("map_fusion.sdfg")

    #### Schedule Tree transform
    # A very basic use of the Schedule Tree: merge maps that have the
    # same range. This should be the first pass we write.
    print("Before merge")
    print(stree.as_string())
    first_map: dace_stree.MapScope = stree.children[2]
    second_map: dace_stree.MapScope = stree.children[3]
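    if first_map.node.range == second_map.node.range:
        # Move the second map's body under the first map, keeping parent
        # pointers consistent. Assumes both maps share parameter names.
        for child in second_map.children:
            child.parent = first_map
        first_map.children.extend(second_map.children)
        first_map.parent.children.remove(second_map)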
    print("After merge")
    print(stree.as_string())
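The hard-coded merge above picks `children[2]` and `children[3]` by index. As a sketch of what that first general pass might look like, here is a self-contained toy version that walks every scope and merges each run of adjacent sibling maps with equal ranges. `Leaf`, `Scope`, `MapNode`, `merge_adjacent_maps` and `dump` are hypothetical names, not DaCe's `treenodes` API. The pass keeps parent pointers consistent and renames the second map's parameter (`ii` to `i` in the `double_loop` tree), here by whole-word string substitution on toy statements where a real pass would substitute symbolically. Like the hard-coded merge, it does not check data dependencies between the two bodies, so it is only safe for pointwise bodies such as these.

```python
import re
from typing import List, Optional

# Toy schedule-tree classes (hypothetical; not DaCe's treenodes classes).


class Node:
    def __init__(self) -> None:
        self.parent: Optional["Scope"] = None


class Leaf(Node):
    """Toy tasklet: one pseudo-Python statement."""

    def __init__(self, text: str) -> None:
        super().__init__()
        self.text = text

    def rename(self, old: str, new: str) -> None:
        # Whole-word substitution, so renaming `i` leaves `field_A` untouched.
        self.text = re.sub(rf"\b{re.escape(old)}\b", new, self.text)


class Scope(Node):
    """Toy scope owning an ordered list of children."""

    def __init__(self, children: Optional[List[Node]] = None) -> None:
        super().__init__()
        self.children: List[Node] = []
        for child in children or []:
            child.parent = self
            self.children.append(child)

    def rename(self, old: str, new: str) -> None:
        for child in self.children:
            child.rename(old, new)


class MapNode(Scope):
    """Toy map scope: `map param in [start:stop]`."""

    def __init__(self, param: str, rng: range, children: Optional[List[Node]] = None) -> None:
        super().__init__(children)
        self.param = param
        self.rng = rng


def merge_adjacent_maps(scope: Scope) -> int:
    """Recursively merge adjacent sibling maps with equal ranges.

    Returns the number of merges. Does NOT check data dependencies.
    """
    merged = 0
    kept: List[Node] = []
    for child in scope.children:
        prev = kept[-1] if kept else None
        if isinstance(child, MapNode) and isinstance(prev, MapNode) and prev.rng == child.rng:
            if child.param != prev.param:
                child.rename(child.param, prev.param)  # e.g. ii -> i
            for grandchild in child.children:
                grandchild.parent = prev  # keep parent pointers consistent
            prev.children.extend(child.children)
            merged += 1
        else:
            kept.append(child)
    scope.children = kept
    for child in kept:
        if isinstance(child, Scope):
            merged += merge_adjacent_maps(child)
    return merged


def dump(node: Node, indent: int = 0) -> List[str]:
    pad = " " * indent
    if isinstance(node, Leaf):
        return [pad + node.text]
    lines: List[str] = []
    if isinstance(node, MapNode):
        lines.append(f"{pad}map {node.param} in [{node.rng.start}:{node.rng.stop}]:")
        indent += 2
    for child in node.children:
        lines += dump(child, indent)
    return lines


if __name__ == "__main__":
    # Toy copy of the double_loop tree printed earlier in this issue.
    root = Scope([
        MapNode("i", range(0, 10), [Leaf("field_A[i] = tasklet(result[i])")]),
        MapNode("ii", range(0, 10), [Leaf("__tmp1[0] = tasklet(field_A[ii], weight[0])"),
                                     Leaf("field_B[ii] = tasklet(__tmp1[0])")]),
    ])
    print(merge_adjacent_maps(root))  # 1
    print("\n".join(dump(root)))
```

Run on the toy `double_loop` tree, this yields a single `map i in [0:10]:` scope containing all three statements, with `ii` renamed to `i`.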