spcl / dace

DaCe - Data Centric Parallel Programming
http://dace.is/fast
BSD 3-Clause "New" or "Revised" License
487 stars 121 forks source link

Codegen: View Shadows Array in NestedSDFG #1612

Open philip-paul-mueller opened 2 months ago

philip-paul-mueller commented 2 months ago

I have a strange error. After some experiments, I have concluded the following:

In that case, the code generator generates the following for the nested SDFG:

    X = new float[N];

without declaring `X`; i.e., the declaration `float* X;`, which should precede the allocation, is missing.

Currently, I only have a large SDFG that exhibits this pattern. The error vanishes once I rename the view in the top-level SDFG or turn it into an array. I have attached the SDFG and the scripts that show how I "fixed" the error.

I will also try to come up with a smaller example. Attachment: example.zip

philip-paul-mueller commented 2 months ago

I have now generated a reproducer:

import dace

# Element type and shape shared by every array and view created in both SDFGs below.
dtype = dace.float64
shape = (10, 10)

def make_nsdfg():
    """Generate a nested SDFG that runs the following calculation:

    ```python
    def nested_comp(
            input1: dace.float64[10, 10],
            input2: dace.float64[10, 10],
    ) -> dace.float64[10, 10]:
        X = input1 + input2
        return X * input2
    ```

    The important thing here is that the transient temporary is called `X`,
    the same name as the view in the top-level SDFG built by `make_main_sdfg`.
    The result is written to the non-transient array `output`.

    Returns:
        The SDFG named ``"Nested_SDFG"``, after `validate()` and `simplify()`
        have been called on it.
    """
    sdfg = dace.SDFG("Nested_SDFG")

    # (name, is_transient): only the temporary `X` is transient; the other
    # arrays become the nested SDFG's connectors.
    array_names = [("input1", False), ("input2", False), ("output", False), ("X", True)]
    for array_name, is_transient in array_names:
        sdfg.add_array(
            array_name,
            dtype=dtype,
            shape=shape,
            transient=is_transient,
        )

    state1 = sdfg.add_state("init_state", is_start_block=True)
    state2 = sdfg.add_state_after(state1, "out_state")

    # First state: X = input1 + input2
    state1.add_mapped_tasklet(
        "first_addition",
        map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
        code="__out = __in1 + __in2",
        inputs={
            "__in1": dace.Memlet("input1[__i0, __i1]"),
            "__in2": dace.Memlet("input2[__i0, __i1]"),
        },
        outputs={"__out": dace.Memlet("X[__i0, __i1]")},
        external_edges=True,
    )
    # Second state: output = X * input2
    state2.add_mapped_tasklet(
        "second_addition",
        map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
        code="__out = __in1 * __in2",
        inputs={
            "__in1": dace.Memlet("X[__i0, __i1]"),
            "__in2": dace.Memlet("input2[__i0, __i1]"),
        },
        outputs={"__out": dace.Memlet("output[__i0, __i1]")},
        external_edges=True,
    )
    sdfg.validate()
    sdfg.simplify()

    return sdfg

def make_main_sdfg(
    apply_fix: int | None = None,
):
    """Generate the failing SDFG.

    Essentially the computation:

    ```
    def comp(
            input1: dace.float64[10, 10],
            input2: dace.float64[10, 10],
            input3: dace.float64[10, 10],
    ) -> dace.float64[10, 10]:
        nested_output = nested_comp(input1, input2)
        output = np.zeros_like(nested_output)
        X = output.view()
        X[:] = nested_output * input3

        return output
    ```

    The important thing is that the view in the above computation has the
    same name as the temporary that is used inside the nested computation.
    This SDFG passes validation and code generation, but it cannot be
    compiled. The available fixes (see below), especially `2`, suggest that
    this is a bug in the code generator.

    Args:
        apply_fix: Selects a workaround for the issue; ``None`` applies none.
            - `1`: The nested SDFG is put into a separate state.
            - `2`: The view in the top-level SDFG gets a different name.
            - `3`: A transient array is used instead of a view in the
              top-level SDFG.

    Returns:
        The validated top-level SDFG named ``"main_SDFG"``.
    """
    sdfg = dace.SDFG("main_SDFG")

    # (name, is_transient): `nested_output` is the transient that receives
    # the result of the nested SDFG.
    array_names = [
        ("input1", False),
        ("input2", False),
        ("input3", False),
        ("output", False),
        ("nested_output", True),
    ]
    array_descs = {}
    for array_name, is_transient in array_names:
        _, desc = sdfg.add_array(
            array_name,
            dtype=dtype,
            shape=shape,
            transient=is_transient,
        )
        array_descs[array_name] = desc

    # Without fix `2`, reuse the name of the nested SDFG's transient on purpose:
    # this name clash is what triggers the bug.
    view_name = "not_X" if apply_fix == 2 else "X"

    # Now generate the view that we need (fix `3` uses an array instead).
    if apply_fix == 3:
        sdfg.add_array(
            view_name,
            dtype=dtype,
            shape=shape,
            transient=True,
        )
    else:
        sdfg.add_view(
            view_name,
            shape=shape,
            dtype=dtype,
        )

    state1 = sdfg.add_state("nested_host_state_init", is_start_block=True)
    nested_sdfg = make_nsdfg()

    # Parent data name -> connector name of the nested SDFG.
    nested_inputs = {
        "input1": "input1",
        "input2": "input2",
    }
    nested_outputs = {
        "nested_output": "output",
    }
    nsdfg = state1.add_nested_sdfg(
        nested_sdfg,
        parent=sdfg,
        inputs=set(nested_inputs.values()),
        outputs=set(nested_outputs.values()),
    )

    for in_parent, in_nested in nested_inputs.items():
        state1.add_edge(
            state1.add_read(in_parent),
            None,
            nsdfg,
            in_nested,
            dace.Memlet.from_array(in_parent, array_descs[in_parent]),
        )
    nested_outputs_ac = []
    for out_parent, out_nested in nested_outputs.items():
        nested_outputs_ac.append(state1.add_access(out_parent))
        state1.add_edge(
            nsdfg,
            out_nested,
            nested_outputs_ac[-1],
            None,
            dace.Memlet.from_array(out_parent, array_descs[out_parent]),
        )
    assert len(nested_outputs_ac) == 1

    if apply_fix == 1:
        # Fix `1`: consume the nested result in a separate, subsequent state.
        state2 = sdfg.add_state_after(state1, "second_main_state")
        nested_output_ac = state2.add_access("nested_output")
    else:
        state2 = state1
        nested_output_ac = nested_outputs_ac[0]

    # view_name = nested_output * input3
    state2.add_mapped_tasklet(
        "second_addition_in_map",
        map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
        code="__out = __in1 * __in2",
        inputs={
            "__in1": dace.Memlet("nested_output[__i0, __i1]"),
            "__in2": dace.Memlet("input3[__i0, __i1]"),
        },
        input_nodes={"nested_output": nested_output_ac, "input3": state2.add_access("input3")},
        outputs={"__out": dace.Memlet(view_name + "[__i0, __i1]")},
        external_edges=True,
    )

    # Find the access node of the view that add_mapped_tasklet created.
    all_access_nodes = [node for node in state2.nodes() if isinstance(node, dace.nodes.AccessNode)]
    view_access_node = next(node for node in all_access_nodes if node.data == view_name)

    # The "views" connector marks `output` as the data viewed by `view_name`.
    state2.add_edge(
        view_access_node,
        "views",
        state2.add_write("output"),
        None,
        dace.Memlet.from_array("output", array_descs["output"]),
    )
    sdfg.validate()

    return sdfg

def can_be_compiled(
    sdfg: dace.SDFG,
) -> bool:
    """Return whether `sdfg` can be compiled.

    Warnings raised during compilation are suppressed. Only DaCe's
    `CompilationError` and `CodegenError` count as "cannot be compiled";
    any other exception propagates to the caller.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            # The compiled program itself is not needed, only whether it builds.
            _ = sdfg.compile()
        except (dace.codegen.exceptions.CompilationError, dace.codegen.exceptions.CodegenError):
            return False
    return True

# Now test everything.

# Build the SDFG once without any fix (must fail to compile) and once per fix
# (each must compile).
for fix in [None, 1, 2, 3]:
    sdfg = make_main_sdfg(fix)

    if fix is None:
        # No fix, so we expect compilation to fail.
        assert not can_be_compiled(sdfg), "It seems the bug has vanished, did you fix it?"
    else:
        assert can_be_compiled(sdfg), f"Expected that fix `{fix}` circumvents the problem."