GEOS-ESM / NDSL

NOAA NASA Domain Specific Language middleware layer

K-loop aggressive fusion (dace orchestrated) #6

Open FlorianDeconinck opened 8 months ago

FlorianDeconinck commented 8 months ago

Cache misses occur in the code, even when orchestrated, because we have an IJK loop around each stencil instead of a K-loop shared by the stencils. Even after applying "map fusion", we are nowhere near matching the K-block/loop structure of the Fortran.

For instance, the entire code of D_SW is within a single K-loop in Fortran.

Using Tal's new schedule tree (stree) or direct SDFG manipulation, we shall implement an aggressive transform in dace to move the K-loop as far out as possible, including by introducing if statements to guard out indices that should not be computed.
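
Schematically (plain Python for illustration, not generated code; the sizes and the factor 3 are arbitrary), the goal is to go from one IJK nest per stencil to a single K loop shared by all stencils, with if statements guarding the indices a given stencil should not compute:

import numpy as np

ni, nj, nk = 3, 3, 4
a = np.arange(ni * nj * nk, dtype=np.float64).reshape(ni, nj, nk)
b = np.zeros_like(a)

# Today: one full IJK nest per stencil, so a and b are streamed through
# the cache once per stencil.
for i in range(ni):
    for j in range(nj):
        for k in range(nk):
            b[i, j, k] = a[i, j, k]
for i in range(ni):
    for j in range(nj):
        for k in range(1, nk):  # this stencil only applies to k >= 1
            b[i, j, k] = a[i, j, k] * 3

# Goal: a single shared K loop around all stencils (as in the Fortran K-block),
# guarding the indices a stencil should not compute.
for k in range(nk):
    for i in range(ni):
        for j in range(nj):
            b[i, j, k] = a[i, j, k]
            if k >= 1:
                b[i, j, k] = a[i, j, k] * 3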

Another approach, called partial expansion, which works on the SDFG alone and was developed by Linus, exists as a branch in GT4Py.cartesian.

stree branch: https://github.com/spcl/dace/tree/stree-to-sdfg


FlorianDeconinck commented 1 month ago

Example showing the state of things with SDFG passes and a trivial application of the stree:

from gt4py.cartesian.gtscript import (
    computation,
    interval,
    PARALLEL,
)
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField
from ndsl import StencilFactory, orchestrate
import numpy as np
import dace
import dace.sdfg.analysis.schedule_tree.treenodes as dace_stree
from dace.transformation.dataflow import MapFusion
from dace.transformation.interstate import StateFusion

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu(
    domain[0], domain[1], domain[2], 0
)

def double_map(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

class DaCeGT4Py_Bridge:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(obj=self, config=stencil_factory.config.dace_config)
        self.stencil = stencil_factory.from_dims_halo(
            func=double_map,
            compute_dims=[X_DIM, Y_DIM, Z_DIM],
        )

    def __call__(self, in_field: FloatField, out_field: FloatField):
        self.stencil(in_field, out_field)

if __name__ == "__main__":
    I = np.arange(domain[0] * domain[1] * domain[2], dtype=np.float64).reshape(domain)
    O = np.zeros(domain)

    # Trigger NDSL orchestration pipeline & grab cached SDFG
    bridge = DaCeGT4Py_Bridge(stcil_fctry)
    bridge(I, O)
    sdfg: dace.sdfg.SDFG = bridge.__sdfg__(I, O).csdfg.sdfg
    sdfg.save("orig.sdfg")
    stree = sdfg.as_schedule_tree()

    #### Classic SDFG transform
    # Strategy: fuse states out of the way, then merge maps

    # State fusion
    # No fusion occurs because of a potential write race since
    # it would be putting both maps under one input
    r = sdfg.apply_transformations_repeated(
        StateFusion,
        print_report=True,
        validate_all=True,
    )
    print(f"Fused {r} states")
    sdfg.save("state_fusion.sdfg")
    # No fusion occurs because maps are in different states
    # (previous failure to merge)
    r = sdfg.apply_transformations_repeated(
        MapFusion,
        print_report=True,
        validate_all=True,
    )
    print(f"Fused {r} maps")
    sdfg.save("map_fusion.sdfg")

    #### Schedule Tree transform
    # We demonstrate here a very basic usage of Schedule Tree to merge
    # maps that have the same range, which should be the first pass
    # we write
    print("Before merge")
    print(stree.as_string())
    first_map: dace_stree.MapScope = stree.children[2]
    second_map: dace_stree.MapScope = stree.children[3]
    if first_map.node.range == second_map.node.range:
        first_map.children.extend(second_map.children)
        first_map.parent.children.remove(second_map)
    print("After merge")
    print(stree.as_string())
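
A natural generalization of the manual merge above would be a recursive pass that folds any run of sibling maps sharing the same range. A minimal sketch, reusing only the node attributes exercised above (children, parent, node.range) and assuming scopes derive from ScheduleTreeScope; a real pass would also need a data-dependency check (see the dependency example further down):

import dace.sdfg.analysis.schedule_tree.treenodes as dace_stree

def merge_sibling_maps(scope: dace_stree.ScheduleTreeScope) -> None:
    """Recursively merge consecutive sibling MapScopes that share the same range."""
    i = 0
    while i < len(scope.children) - 1:
        first, second = scope.children[i], scope.children[i + 1]
        if (
            isinstance(first, dace_stree.MapScope)
            and isinstance(second, dace_stree.MapScope)
            and first.node.range == second.node.range
        ):
            # Move the second map's body under the first and drop the second.
            first.children.extend(second.children)
            for child in second.children:
                child.parent = first
            scope.children.remove(second)
            # Stay on `first`: it may now merge with the next sibling as well.
        else:
            i += 1
    # Recurse into the (possibly merged) children.
    for child in scope.children:
        if isinstance(child, dace_stree.ScheduleTreeScope):
            merge_sibling_maps(child)

Applied to the stree above, this should fold the two map nests (tile, IJ and K levels) into one, matching the manual two-map merge.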

FlorianDeconinck commented 5 days ago

Per @romanc - state of exploration

Schedule Tree

In DaCe, a Schedule Tree is a representation designed for high-level optimizations such as map or state fusion, which are cumbersome to implement as passes on SDFGs.

Motivation

While there are map & state fusion passes for SDFGs, they fall short of our expectations. Schedule Trees were thus drafted as an alternative representation to allow such high-level optimizations. Specifically, for NDSL, we are interested in the following high-level optimizations:

  1. Flip the iteration order from I-J-K to K-I-J, pushing the K-loop as far out as possible to gain cache locality on CPU. Note that this transformation should be optional to allow optimal iteration order based on the target hardware (CPU or GPU).
  2. Map merge on SDFGs currently fails for the most basic examples. Especially on GPUs, we'd like to merge as many maps as possible to cut down on the (currently excessive) amount of kernel launches.
  3. If possible, look at force-merging K-loops/maps by inserting if statements (as guards) where needed (over-computation) to further cut down on kernel launches on the GPU.

NDSL issue: https://github.com/GEOS-ESM/NDSL/issues/6

Current state

Schedule Trees can be generated from SDFGs. Rudimentary tools (e.g. printing a Schedule Tree) are provided to allow prototyping and (visual) verification of map merge and loop reordering.
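
For reference, this tooling is used as follows (loading the SDFG saved by the example above):

import dace

# Load the SDFG cached by the orchestration example and dump its Schedule Tree.
sdfg = dace.SDFG.from_file("orig.sdfg")
stree = sdfg.as_schedule_tree()
print(stree.as_string())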

The back transformation from Schedule Tree to SDFG is currently missing. There is a work-in-progress branch (https://github.com/spcl/dace/pull/1466) containing half the code needed to go back from Schedule Trees to SDFGs.

Evaluation

Evaluation was done on a couple of simple examples. Code for these experiments lives in that branch.

Double map - trivial merge

def double_map(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

Plain Schedule Tree representation

out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __tile_j, __tile_i in [0:3:8, 0:3:8]:
  map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
    map __k in [0:4]:
      out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
map __tile_j, __tile_i in [0:3:8, 0:3:8]:
  map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
    map __k in [0:4]:
      out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])

K-loop re-order and map merge

out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
      out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
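
The reordering shown above can be expressed as a small tree rewrite. A minimal sketch for exactly the tile -> IJ -> K nesting printed here, again assuming only the children/parent/node attributes used in the earlier example (the PoC in the branch is the authoritative version):

import dace.sdfg.analysis.schedule_tree.treenodes as dace_stree

def hoist_k_map(tile_map: dace_stree.MapScope) -> dace_stree.MapScope:
    """Rewrite a tile -> IJ -> K nesting into K -> tile -> IJ, keeping the body unchanged."""
    ij_map = tile_map.children[0]
    k_map = ij_map.children[0]
    parent = tile_map.parent
    idx = parent.children.index(tile_map)

    # The computation body moves directly under the IJ map ...
    ij_map.children = k_map.children
    for child in ij_map.children:
        child.parent = ij_map

    # ... and the K map becomes the outermost scope, in the tile map's old slot.
    k_map.children = [tile_map]
    tile_map.parent = k_map
    parent.children[idx] = k_map
    k_map.parent = parent
    return k_map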

Forced k-merge (over-computation)

def double_map_with_different_intervals(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

Forced k-merge (note how if statements are pushed down through the maps, enabling map merges in I & J):

in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      if ((__k >= 1) and (k <= 3) and (((__k - 1) % 1) == 0)):
        out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
      if ((__k >= 0) and (k <= 3) and (((__k - 0) % 1) == 0)):
        out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])

TODO: We should skip the second if because it spans the full range.

Forced k-merge (preserve order of computation blocks)

def mergeable_preserve_order(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

    with computation(PARALLEL), interval(1, None):
        out_field = in_field * 4

The order of computation blocks is preserved and only consecutive blocks are merged, even in the case of force-merging in K.

out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      if ((__k >= 0) and (k <= 3) and (((__k - 0) % 1) == 0)):
        if ((__k >= 1) and (k <= 3) and (((__k - 1) % 1) == 0)):
          out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
        if ((__k >= 0) and (k <= 3) and (((__k - 0) % 1) == 0)):
          out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
      if ((__k >= 1) and (k <= 3) and (((__k - 1) % 1) == 0)):
        out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])

Note the extra if statements. This could (maybe should) be cleaned up in a post-PoC implementation.
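
One possible cleanup, sketched here under the assumption that guards are IfScope nodes carrying a condition CodeBlock (with as_string) plus the children/parent attributes used earlier, is to inline any inner guard whose condition is textually identical to the enclosing guard's:

import dace.sdfg.analysis.schedule_tree.treenodes as dace_stree

def inline_duplicate_guards(outer: dace_stree.IfScope) -> None:
    """Splice the children of an inner if into the outer if when their conditions match."""
    for inner in list(outer.children):
        if (
            isinstance(inner, dace_stree.IfScope)
            and inner.condition.as_string == outer.condition.as_string
        ):
            # Same condition as the enclosing guard: inline its body in place.
            idx = outer.children.index(inner)
            outer.children[idx:idx + 1] = inner.children
            for node in inner.children:
                node.parent = outer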

Simple data dependency analysis

def not_mergeable_k_dependency(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

    with computation(PARALLEL), interval(1, None):
        in_field = out_field

The first two blocks are force-merged, but not the third, because of data dependencies (read after write). This can be tuned in a post-PoC implementation; a sketch of such a dependency check follows the listing below.

  out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
  in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
  map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
      map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        if ((__k >= 1) and (k <= 3) and (((__k - 1) % 1) == 0)):
          out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
        if ((__k >= 0) and (k <= 3) and (((__k - 0) % 1) == 0)):
          out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
  map __k in [1:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
      map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        in_field[__i, __j, __k] = tasklet(out_field[__i, __j, __k])
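
A minimal sketch of the kind of container-level dependency check that blocks this third merge (names and signature are illustrative; the PoC in the branch is the authoritative version):

def can_merge(first_reads: set, first_writes: set,
              second_reads: set, second_writes: set) -> bool:
    """Refuse to merge when the second block reads data written by the first
    (read-after-write) or overwrites data the first still reads (write-after-read)."""
    raw = first_writes & second_reads
    war = first_reads & second_writes
    return not (raw or war)

# Blocks 1+2 above read in_field and write out_field; block 3 reads out_field
# and writes in_field, so merging it is rejected (RAW on out_field, WAR on in_field).
print(can_merge({"in_field"}, {"out_field"}, {"out_field"}, {"in_field"}))  # False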

TODO

Loops and maps are never merged

def loop_and_map(in_field: FloatField, out_field: FloatField):
    with computation(FORWARD), interval(...):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

Sequential blocks (i.e. FORWARD, BACKWARD) are never merged with PARALLEL computation blocks.

out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
for __k = 0; (__k < (4 + 0)); __k = (__k + 1):
  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
map __k in [0:4]:
  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])

TODO: Even more aggressive k-merge as above

To be fleshed out

Questions (for Tal)

Can you specify the GPU launch parameters in the Schedule Tree? How does that translate?

When doing over-computation (in the case of map-merge), do we need to flag memlets as dynamic? Is it harmful if we lie to dace and hint at more read/writes in that case?

Can DaCe still re-use arrays in case we do very aggressive map merging? (Especially in the restricted space of GPU (v)RAM)

When using a temporary field, the "array definitions" are missing. How is that different from the other examples?

def tmp_field(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        tmp = in_field

    with computation(PARALLEL), interval(...):
        out_field = tmp * 3

Plain Schedule Tree representation

  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      map __k in [0:4]:
        tmp[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
  map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
      map __k in [0:4]:
        out_field[__i, __j, __k] = tasklet(tmp[__i, __j, __k])

Horizontal regions show up as "dynamic memlets". We currently exclude maps with dynamic memlets in them from merging. Is that a wise idea?

def horizontal_regions(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field * 2

    with computation(PARALLEL), interval(...):
        with horizontal(region[:, :-1]):
            out_field = in_field
        with horizontal(region[:-1, :]):
            out_field = in_field

  in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
  out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
  map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
      map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        if ((__k >= 0) and (k <= 3) and (((__k - 0) % 1) == 0)):
          out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
      map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        if ((__k >= 0) and (k <= 3) and (((__k - 0) % 1) == 0)):
          out_field(dyn) [__i, __j, __k] = tasklet(in_field(dyn) [__i, __j, __k])

Conclusion

The Schedule Tree representation keeps its promise. The above evaluation shows that map fusion and loop reordering are possible in that representation. The proof of concept further shows that such optimization passes can be implemented elegantly due to the tree structure.

On the downside, one has to note the additional overhead of creating and maintaining the transformations between SDFGs and Schedule Trees. While it is theoretically possible to generate code directly from Schedule Trees, we would lose the ability to run (existing) optimization passes after loop transformations and map fusion, which could unlock things like vectorization. In addition, current AI research is based on the SDFG representation. The conclusion is thus to live with the maintenance and transformation costs between Schedule Trees and SDFGs.

romanc commented 5 days ago

Forced k-merge (over-computation)

TODO: We should skip the second if because it spans the full range.

This was addressed with a simple check of whether the original and merged ranges are equal. For the simple example

def double_map_with_different_intervals(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field

    with computation(PARALLEL), interval(...):
        out_field = in_field * 3

  in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
  out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
  map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
      map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        if ((__k >= 1) and (k <= 3) and (((__k - 1) % 1) == 0)):
          out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
        out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])

we now only see one extra/inserted if statement (as expected).
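
For reference, a sketch of what such a range comparison can look like (illustrative names, assuming dace.subsets.Range equality semantics):

from dace import subsets

def needs_guard(block_k_range: subsets.Range, merged_k_range: subsets.Range) -> bool:
    # A guard is only needed when the block's own K range differs from the
    # merged map's K range (e.g. 1:4 inside a merged 0:4 map).
    return block_k_range != merged_k_range

print(needs_guard(subsets.Range.from_string("1:4"), subsets.Range.from_string("0:4")))  # True: keep the if
print(needs_guard(subsets.Range.from_string("0:4"), subsets.Range.from_string("0:4")))  # False: skip the if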