FlorianDeconinck opened 8 months ago
Example showing the state of things with the SDFG passes, and a trivial application of the schedule tree (stree):
from gt4py.cartesian.gtscript import (
    computation,
    interval,
    PARALLEL,
)
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField
from ndsl import StencilFactory, orchestrate
import numpy as np

import dace
import dace.sdfg.analysis.schedule_tree.treenodes as dace_stree
from dace.transformation.dataflow import MapFusion
from dace.transformation.interstate import StateFusion

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu(
    domain[0], domain[1], domain[2], 0
)


def double_map(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3


class DaCeGT4Py_Bridge:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(obj=self, config=stencil_factory.config.dace_config)
        self.stencil = stencil_factory.from_dims_halo(
            func=double_map,
            compute_dims=[X_DIM, Y_DIM, Z_DIM],
        )

    def __call__(self, in_field: FloatField, out_field: FloatField):
        self.stencil(in_field, out_field)


if __name__ == "__main__":
    I = np.arange(domain[0] * domain[1] * domain[2], dtype=np.float64).reshape(domain)
    O = np.zeros(domain)

    # Trigger NDSL orchestration pipeline & grab cached SDFG
    bridge = DaCeGT4Py_Bridge(stcil_fctry)
    bridge(I, O)
    sdfg: dace.sdfg.SDFG = bridge.__sdfg__(I, O).csdfg.sdfg
    sdfg.save("orig.sdfg")
    stree = sdfg.as_schedule_tree()

    #### Classic SDFG transforms
    # Strategy: fuse states out of the way, then merge maps

    # State fusion:
    # no fusion occurs because of a potential write race, since
    # it would put both maps under one input.
    r = sdfg.apply_transformations_repeated(
        StateFusion,
        print_report=True,
        validate_all=True,
    )
    print(f"Fused {r} states")
    sdfg.save("state_fusion.sdfg")

    # Map fusion:
    # no fusion occurs because the maps are in different states
    # (previous failure to merge).
    r = sdfg.apply_transformations_repeated(
        MapFusion,
        print_report=True,
        validate_all=True,
    )
    print(f"Fused {r} maps")
    sdfg.save("map_fusion.sdfg")

    #### Schedule Tree transform
    # We demonstrate here a very basic usage of the Schedule Tree to merge
    # maps that have the same range, which should be the first pass
    # we write.
    print("Before merge")
    print(stree.as_string())

    first_map: dace_stree.MapScope = stree.children[2]
    second_map: dace_stree.MapScope = stree.children[3]
    if first_map.node.range == second_map.node.range:
        first_map.children.extend(second_map.children)
        first_map.parent.children.remove(second_map)

    print("After merge")
    print(stree.as_string())
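The merge above is hard-coded to stree.children[2] and children[3]. Below is a minimal sketch of how a generalized pass could look, assuming only the attributes already used in the example (children, parent, node.range): walk the tree and fuse any pair of consecutive MapScopes whose ranges match. This is a prototype, not a production pass; it does not check data dependencies and does not fix parent pointers.

def merge_consecutive_maps(scope: dace_stree.ScheduleTreeScope) -> int:
    """Sketch: fuse consecutive MapScopes with identical ranges, recursively."""
    merged = 0
    i = 0
    while i < len(scope.children) - 1:
        first, second = scope.children[i], scope.children[i + 1]
        if (
            isinstance(first, dace_stree.MapScope)
            and isinstance(second, dace_stree.MapScope)
            and first.node.range == second.node.range
        ):
            # NOTE: a real pass would also check data dependencies here
            # and re-point the parent of the moved children.
            first.children.extend(second.children)
            scope.children.remove(second)
            merged += 1  # stay on `first`: the new neighbor may merge too
        else:
            i += 1
    # Recurse into nested scopes (e.g. maps within maps).
    for child in scope.children:
        if isinstance(child, dace_stree.ScheduleTreeScope):
            merged += merge_consecutive_maps(child)
    return merged

Calling merge_consecutive_maps(stree) before printing "After merge" would reproduce the hard-coded behavior on this example while also covering deeper nesting.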
Per @romanc - state of exploration
In DaCe, a Schedule Tree is a representation designed for high-level optimizations. Schedule Trees were designed to help with optimizations like map or state fusion, which are cumbersome to achieve as passes on SDFGs.
While there are map & state fusion passes for SDFGs, they underperform our expectations. Schedule Trees were thus drafted as an alternative representation to allow such high-level optimizations. Specifically, for NDSL, we are interested in two high-level optimizations: map merging and (K-)loop reordering.
NDSL issue: https://github.com/GEOS-ESM/NDSL/issues/6
Schedule Trees can be generated from SDFGs. Rudimentary tools (e.g. printing a Schedule Tree) are provided to allow prototyping and (visual) verification of map merge and loop reordering.
The back transformation from Schedule Tree to SDFG is currently missing. There is a work-in-progress branch, https://github.com/spcl/dace/pull/1466, containing half the code needed to go back from Schedule Trees to SDFGs.
Evaluation was done on a couple of simple examples. Code for these experiments lives in that branch.
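For reference, the rudimentary tooling boils down to the two calls already used in the example above:

# Generate the schedule tree from the (cached, orchestrated) SDFG and print it
# for visual verification.
stree = sdfg.as_schedule_tree()
print(stree.as_string())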
def double_map(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3
Plain Schedule Tree representation
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        map __k in [0:4]:
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        map __k in [0:4]:
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
K-loop re-order and map merge
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
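A minimal sketch of the reorder half of this transform, assuming the nesting shown in the dumps (tile map > i/j map > k map) and the same stree attributes as in the example above; merging the two scopes afterwards would be handled by a pass like the one sketched earlier, and parent pointers and safety checks are omitted here:

def hoist_k_map(tile_map: dace_stree.MapScope) -> dace_stree.MapScope:
    # Assumed nesting from the printed trees: tile_map -> ij_map -> k_map -> body
    ij_map: dace_stree.MapScope = tile_map.children[0]
    k_map: dace_stree.MapScope = ij_map.children[0]
    ij_map.children = k_map.children  # the body now sits directly under __i, __j
    k_map.children = [tile_map]       # the __k map becomes the outermost scope
    return k_map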
def double_map_with_different_intervals(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3
Forced k-merge with over-computation (note how the if statements are pushed down through the maps, enabling map merges in I & J):
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            if ((__k >= 1) and (__k <= 3) and (((__k - 1) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
            if ((__k >= 0) and (__k <= 3) and (((__k - 0) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
TODO: We should skip the second if because it spans the full range.
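For illustration, a hypothetical helper reproducing the guard printed in the dump above; start, stop and step describe the original K interval of the merged block (the names and the helper itself are assumptions, not DaCe API):

def k_guard(start: int, stop: int, step: int = 1) -> str:
    # Mirrors the printed condition, e.g. k_guard(1, 4) ->
    # "((__k >= 1) and (__k <= 3) and (((__k - 1) % 1) == 0))"
    return f"((__k >= {start}) and (__k <= {stop - 1}) and (((__k - {start}) % {step}) == 0))"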
def mergeable_preserve_order(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3
    with computation(PARALLEL), interval(1, None):
        out_field = in_field * 4
The order of computation blocks is preserved and only consecutive blocks are merged, even in the case of force-merging in K.
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            if ((__k >= 0) and (__k <= 3) and (((__k - 0) % 1) == 0)):
                if ((__k >= 1) and (__k <= 3) and (((__k - 1) % 1) == 0)):
                    out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
                if ((__k >= 0) and (__k <= 3) and (((__k - 0) % 1) == 0)):
                    out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
            if ((__k >= 1) and (__k <= 3) and (((__k - 1) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
Note the extra if statements. This could (maybe should) be cleaned up in a post-PoC implementation.
def not_mergeable_k_dependency(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3
    with computation(PARALLEL), interval(1, None):
        in_field = out_field
The first two blocks are force-merged, but not the third, because of data dependencies (read after write). This can be tuned in a post-PoC implementation.
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            if ((__k >= 1) and (__k <= 3) and (((__k - 1) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
            if ((__k >= 0) and (__k <= 3) and (((__k - 0) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
map __k in [1:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            in_field[__i, __j, __k] = tasklet(out_field[__i, __j, __k])
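A sketch of the read-after-write check implied here, assuming TaskletNode exposes in_memlets/out_memlets dictionaries of Memlets (as in the proof-of-concept branch); this works at container-name granularity only, with no index analysis:

def _containers(scope: dace_stree.ScheduleTreeScope, outputs: bool) -> set:
    # Collect container names read (or written) by tasklets anywhere in the scope.
    names = set()
    for child in scope.children:
        if isinstance(child, dace_stree.TaskletNode):
            memlets = child.out_memlets if outputs else child.in_memlets
            names |= {memlet.data for memlet in memlets.values()}
        elif isinstance(child, dace_stree.ScheduleTreeScope):
            names |= _containers(child, outputs)
    return names


def can_merge(first: dace_stree.MapScope, second: dace_stree.MapScope) -> bool:
    # Refuse to merge if the second map reads what the first one writes.
    return _containers(first, outputs=True).isdisjoint(_containers(second, outputs=False))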
TODO
def loop_and_map(in_field: FloatField, out_field: FloatField):
    with computation(FORWARD), interval(...):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3
Sequential blocks (i.e. FORWARD, BACKWARD) are never merged with PARALLEL computation blocks.
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
for __k = 0; (__k < (4 + 0)); __k = (__k + 1):
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
TODO: Even more aggressive k-merge as above
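For completeness, a minimal sketch of the "never merge sequential with parallel" rule above, under the same assumptions as the earlier sketches: the sequential K loop shows up as a ForScope in the tree, the parallel one as a MapScope, and only pairs of parallel maps are considered for merging.

def is_parallel_merge_candidate(a: dace_stree.ScheduleTreeNode, b: dace_stree.ScheduleTreeNode) -> bool:
    # FORWARD/BACKWARD computations appear as ForScope nodes and are never
    # merged with PARALLEL MapScopes; only map/map pairs qualify.
    return isinstance(a, dace_stree.MapScope) and isinstance(b, dace_stree.MapScope)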
Can you specify the GPU launch parameters in the Schedule Tree? How does that translate?
When doing over-computation (in the case of map-merge), do we need to flag memlets as dynamic? Is it harmful if we lie to dace and hint at more read/writes in that case?
Can DaCe still re-use arrays in case we do very aggressive map merging? (Especially in the restricted space of GPU (v)RAM)
When using a temporary field, the "array definitions" are missing. How is that different from the other examples?
def tmp_field(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        tmp = in_field
    with computation(PARALLEL), interval(...):
        out_field = tmp * 3
Plain Schedule Tree representation
map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        map __k in [0:4]:
            tmp[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
map __tile_j, __tile_i in [0:3:8, 0:3:8]:
    map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
        map __k in [0:4]:
            out_field[__i, __j, __k] = tasklet(tmp[__i, __j, __k])
Horizontal regions show up as "dynamic memlets". We currently exclude map merges when dynamic memlets are involved. Is that a wise idea?
def horizontal_regions(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(...):
        out_field = in_field * 2
    with computation(PARALLEL), interval(...):
        with horizontal(region[:, :-1]):
            out_field = in_field
        with horizontal(region[:-1, :]):
            out_field = in_field
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            if ((__k >= 0) and (__k <= 3) and (((__k - 0) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            if ((__k >= 0) and (__k <= 3) and (((__k - 0) % 1) == 0)):
                out_field(dyn) [__i, __j, __k] = tasklet(in_field(dyn) [__i, __j, __k])
The Schedule Tree representation keeps its promise. The above evaluation shows that map fusion and loop reordering are possible in that representation. The proof of concept further shows that such optimization passes can be implemented elegantly due to the tree structure.
On the downside, one has to note the additional overhead of creating and maintaining the transformations between SDFGs and Schedule Trees. While it is theoretically possible to generate code directly from Schedule Trees, we would lose the possibility to run (existing) optimization passes after loop transformations and map fusion, which could unlock things like vectorization. In addition, current AI research is based on the SDFG representation. The conclusion is thus to live with the maintenance and transformation costs between Schedule Trees and SDFGs.
Forced k-merge (over-computation)
Regarding the earlier TODO ("We should skip the second if because it spans the full range"): this was addressed with a simple check of whether the original and merged ranges are equal. For the case of the simple example
def double_map_with_different_intervals(in_field: FloatField, out_field: FloatField):
    with computation(PARALLEL), interval(1, None):
        out_field = in_field
    with computation(PARALLEL), interval(...):
        out_field = in_field * 3
in_field = nview in_field[0:3, 0:3, 0:4] as (3, 3, 4)
out_field = nview out_field[0:3, 0:3, 0:4] as (3, 3, 4)
map __k in [0:4]:
    map __tile_j, __tile_i in [0:3:8, 0:3:8]:
        map __i, __j in [__tile_i:__tile_i + Min(8, 3 - __tile_i), __tile_j:__tile_j + Min(8, 3 - __tile_j)]:
            if ((__k >= 1) and (__k <= 3) and (((__k - 1) % 1) == 0)):
                out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
            out_field[__i, __j, __k] = tasklet(in_field[__i, __j, __k])
we now only see one extra/inserted if statement (as expected).
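A sketch of the check described above, assuming map ranges compare with == as in the merge example at the top of this issue; a guard is only emitted when the block's original K range differs from the merged one:

def needs_k_guard(original_range, merged_range) -> bool:
    # If the block already spanned the full merged range, no `if` is needed.
    return original_range != merged_range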
Cache misses occur in the code, even the orchestrated one, because we have IJK loops around all stencils instead of a K-loop shared by the stencils. Even after applying "map fusion", it's nowhere near enough to match the K-block/loop of Fortran. For instance, the entire code of D_SW is within a single K-loop in Fortran.
Using Tal's new stree or direct SDFG manipulation, we shall implement an aggressive transform on dace to move the K-loop as far out as possible, including by introducing if guards around the indices that should be computed.
Another code path, called partial expansion, which works on the SDFG alone and was developed by Linus, exists in GT4Py.cartesian as a branch.
stree branch: https://github.com/spcl/dace/tree/stree-to-sdfg
FvTp2d (code) or other
C_SW and/or D_SW