NOAA-GFDL / NDSL

NOAA NASA Domain Specific Language middleware layer
6 stars 8 forks source link

[Orchestrated] Signature & empty code issues #70

Open FlorianDeconinck opened 1 week ago

FlorianDeconinck commented 1 week ago

Description

There is a series of bugs related to the signature at the GT4Py level and to the symbols/scalars/fields declared in the DaCe wrapper SDFG. All of these bugs live in the bridge between GT4Py and DaCe.

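These can be classified into three groups, each illustrated by one of the tests in the comments below: empty stencil code (corners), an unused scalar parameter, and an unused field.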

Most of these behaviors are linked to GT4Py's `prune_unused_parameters` pass, which is called at the very beginning of GTIR. This is clearly not the intended design, since passes should be pushed down to OIR or the backend IR. It was done this way to deal with some of these issues. Simply removing the prune pass does not fix them, even though removal would be acceptable since the pass gives little to no performance improvement.

In the comments below we put down three examples that showcase the issues (either plain NDSL or relying on pyFV3). We also include some patches that fix some bugs but create others.

To Reproduce: see the comment below.

FlorianDeconinck commented 1 week ago

Test file to be dropped into NDSL (once the reference to pyFV3 is removed)

from ndsl.stencils.corners import fill_corners_dgrid_defn
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM, X_INTERFACE_DIM, Y_INTERFACE_DIM
from ndsl.dsl.typing import Float, FloatField
from ndsl import orchestrate, StencilFactory, DaceConfig
from gt4py.cartesian.gtscript import computation, PARALLEL, interval

class OrchestratedCorner:
    """Orchestrated wrapper around ``fill_corners_dgrid_defn``.

    Used by ``test_empty_corners`` to reproduce the "empty code" issue.
    """

    def __init__(self, stencil_factory: StencilFactory) -> None:
        # Register this object with DaCe orchestration. Falls back to a new
        # DaceConfig (no communicator, the factory's backend) when the
        # factory's config carries no dace_config.
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(communicator=None, backend=stencil_factory.backend),
        )
        # Origin/domain over the cell-centered (X, Y, Z) dimensions.
        origin, domain = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )
        # Axis offsets become stencil externals. They are read from
        # grid_indexing here, so any edge flags must be set before this
        # object is built (NOTE(review): presumably the offsets depend on
        # those flags -- confirm in ndsl).
        axes_offsets = stencil_factory.grid_indexing.axis_offsets(origin, domain)

        self.corner_stencil = stencil_factory.from_origin_domain(
            fill_corners_dgrid_defn,
            externals=axes_offsets,
            origin=origin,
            domain=domain,
        )

    def __call__(self, x, y):
        # Each field is passed twice, followed by the scalar 1.0; presumably
        # (x_in, x_out, y_in, y_out, sign) -- confirm against the signature
        # of fill_corners_dgrid_defn.
        self.corner_stencil(x, x, y, y, 1.0)

def test_empty_corners():
    """Orchestrate the D-grid corner fill on a tile that touches no edge.

    With every edge flag cleared, the corner stencil presumably has nothing
    to compute, which is the "empty code" case of this issue (confirm in
    ndsl).
    """
    stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
        12, 12, 5, 0
    )
    # Make the tile touch no edge. This must happen before OrchestratedCorner
    # is constructed, because its __init__ reads the axis offsets from
    # grid_indexing.
    stencil_factory.grid_indexing.south_edge = False
    stencil_factory.grid_indexing.north_edge = False
    stencil_factory.grid_indexing.west_edge = False
    stencil_factory.grid_indexing.east_edge = False

    x = quantity_factory.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    y = quantity_factory.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    orch_corner = OrchestratedCorner(stencil_factory)

    orch_corner(x, y)

# GTScript stencil with a scalar parameter (``weight``) that the computation
# never reads; used to reproduce the unused-parameter signature issue.
# NOTE(review): "unusued" is a typo for "unused"; kept because the name is
# referenced by OrchestratedUnusedParameter.
def unusued_parameter_stencil(
    field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
    weight: Float,  # type: ignore
):
    with computation(PARALLEL), interval(...):
        # Sum of the four horizontal neighbours of ``field``; the +/-1
        # offsets in I and J require a halo of at least 1.
        result = field[1, 0, 0] + field[0, 1, 0] + field[-1, 0, 0] + field[0, -1, 0]

class OrchestratedUnusedParameter:
    """Orchestrated wrapper around ``unusued_parameter_stencil``.

    Used by ``test_unused_parameters`` to reproduce the issue where a scalar
    stencil argument is never read.
    """

    def __init__(self, stencil_factory: StencilFactory):
        # Register this object with DaCe orchestration, falling back to a new
        # DaceConfig when the factory's config carries no dace_config.
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(communicator=None, backend=stencil_factory.backend),
        )
        origin, domain = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )

        self.unused_stencil = stencil_factory.from_origin_domain(
            unusued_parameter_stencil,
            origin=origin,
            domain=domain,
        )

    def __call__(self, x, y):
        # 1.0 binds to ``weight``, which the stencil body never reads.
        self.unused_stencil(x, y, 1.0)

def test_unused_parameters():
    """Reproducer: orchestrated stencil call with an unused scalar parameter."""
    stencils, quantities = get_factories_single_tile_orchestrated_cpu(12, 12, 5, 2)

    field_in = quantities.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    field_out = quantities.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    # Build the orchestrated wrapper and invoke it in one step.
    OrchestratedUnusedParameter(stencils)(field_in, field_out)

# GTScript stencil with a field parameter (``other_field``) that the
# computation never reads; used to reproduce the unused-field signature issue.
# NOTE(review): "unusued" is a typo for "unused"; kept because the name is
# referenced by OrchestratedunusedField.
def unusued_field_stencil(
    field: FloatField,  # type: ignore
    other_field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
):
    with computation(PARALLEL), interval(...):
        # Sum of the four horizontal neighbours of ``field``; the +/-1
        # offsets in I and J require a halo of at least 1.
        result = field[1, 0, 0] + field[0, 1, 0] + field[-1, 0, 0] + field[0, -1, 0]

class OrchestratedunusedField:
    """Orchestrated wrapper around ``unusued_field_stencil``.

    Used by ``test_unused_field`` to reproduce the issue where a field
    stencil argument is never read.
    """

    # NOTE(review): the lowercase "u" in the class name is inconsistent with
    # OrchestratedUnusedParameter; kept because test_unused_field uses it.

    def __init__(self, stencil_factory: StencilFactory):
        # Register this object with DaCe orchestration, falling back to a new
        # DaceConfig when the factory's config carries no dace_config.
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(communicator=None, backend=stencil_factory.backend),
        )
        origin, domain = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )

        self.unused_stencil = stencil_factory.from_origin_domain(
            unusued_field_stencil,
            origin=origin,
            domain=domain,
        )

    def __call__(self, x, unused_field, y):
        # ``unused_field`` binds to ``other_field``, which the stencil body
        # never reads.
        self.unused_stencil(
            x,
            unused_field,
            y,
        )

def test_unused_field():
    """Reproducer: orchestrated stencil call with an unused field argument."""
    stencils, quantities = get_factories_single_tile_orchestrated_cpu(12, 12, 5, 2)

    field_in = quantities.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    field_ignored = quantities.empty(
        dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a"
    )
    field_out = quantities.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    # Build the orchestrated wrapper and invoke it in one step.
    OrchestratedunusedField(stencils)(field_in, field_ignored, field_out)

if __name__ == "__main__":
    # Run every reproducer, in the original order, when executed as a script.
    for _reproducer in (test_unused_parameters, test_empty_corners, test_unused_field):
        _reproducer()
FlorianDeconinck commented 1 week ago

Patches to be applied

diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py
index 7608fcd5..e4699b3b 100644
--- a/src/gt4py/cartesian/backend/dace_backend.py
+++ b/src/gt4py/cartesian/backend/dace_backend.py
@@ -234,6 +234,24 @@ def _sdfg_add_arrays_and_edges(
                     None,
                     dace.Memlet(name, subset=dace.subsets.Range(ranges)),
                 )
+        elif isinstance(array, dace.data.Scalar):
+            wrapper_sdfg.add_scalar(name, dtype=array.dtype, storage=array.storage)
+            if name in inputs:
+                state.add_edge(
+                    state.add_read(name),
+                    None,
+                    nsdfg,
+                    name,
+                    dace.Memlet(name),
+                )
+            if name in outputs:
+                state.add_edge(
+                    nsdfg,
+                    name,
+                    state.add_write(name),
+                    None,
+                    dace.Memlet(name),
+                )

 def _sdfg_specialize_symbols(wrapper_sdfg, domain: Tuple[int, ...]):
diff --git a/src/gt4py/cartesian/backend/dace_lazy_stencil.py b/src/gt4py/cartesian/backend/dace_lazy_stencil.py
index 2b3cf6fe..0c614ad8 100644
--- a/src/gt4py/cartesian/backend/dace_lazy_stencil.py
+++ b/src/gt4py/cartesian/backend/dace_lazy_stencil.py
@@ -15,6 +15,7 @@ from gt4py.cartesian.backend.dace_backend import SDFGManager
 from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject, add_optional_fields
 from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir
 from gt4py.cartesian.lazy_stencil import LazyStencil
+from gt4py.cartesian.gtc.passes.gtir_prune_unused_parameters import prune_unused_parameters

 if TYPE_CHECKING:
@@ -26,6 +27,7 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
         if "dace" not in builder.backend.name:
             raise ValueError("Trying to build a DaCeLazyStencil for non-dace backend.")
         super().__init__(builder=builder)
+        self.signature = []

     @property
     def field_info(self) -> Dict[str, Any]:
@@ -47,7 +49,8 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
     def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
         sdfg_manager = SDFGManager(self.builder)
         args_data = make_args_data_from_gtir(self.builder.gtir_pipeline)
-        arg_names = [arg.name for arg in self.builder.gtir.api_signature]
+        assert self.signature != []
+        arg_names = self.signature
         assert args_data.domain_info is not None
         norm_kwargs = DaCeStencilObject.normalize_args(
             *args,
@@ -69,5 +72,9 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
         return {}

     def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]:
-        args = [arg.name for arg in self.builder.gtir.api_signature]
-        return (args, [])
+        if self.signature == []:
+            self.signature = [
+                str(p)
+                for p in self.builder.gtir_pipeline.apply([prune_unused_parameters]).param_names
+            ]
+        return (self.signature, [])
diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
index dba6c5a7..9dd57290 100644
--- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
+++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
@@ -150,7 +150,7 @@ class OirSDFGBuilder(eve.NodeVisitor):
                     debuginfo=dace.DebugInfo(0),
                 )
             else:
-                ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
+                ctx.sdfg.add_scalar(param.name, dtype=data_type_to_dace_typeclass(param.dtype))

         for decl in node.declarations:
             dim_strs = [d for i, d in enumerate("IJK") if decl.dimensions[i]] + [
FlorianDeconinck commented 1 week ago

A working solution seems to be updating the `symbol_mapping` of the library node (`StencilComputation`) before it is expanded.

Branch under test: https://github.com/FlorianDeconinck/gt4py/tree/cartesian/fix/missing_parameter