spcl / dace

DaCe - Data Centric Parallel Programming
http://dace.is/fast
BSD 3-Clause "New" or "Revised" License

Memlet Error in RedundantArray (simplify) #1595

Open philip-paul-mueller opened 2 months ago

philip-paul-mueller commented 2 months ago

I found a bug in simplify; as far as I can tell, it is located inside RedundantArray. The error concerns a Memlet that performs a reshape. What complicates matters is that the error depends on the processing order: the bug appears only in some cases, which makes it hard to debug. I was able to find a minimal example that triggers the bug, but not on every run, so you may have to run it several times.
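Because the failure is intermittent, rerunning the script by hand gets tedious. The following is only a sketch of a helper that reruns a command in a fresh interpreter until it exits with a non-zero status. It assumes that the processing order can differ between fresh runs, as described above. The name `rerun_until_failure` is made up for illustration and is not part of DaCe.

```python
import subprocess
import sys


def rerun_until_failure(cmd, max_attempts=20):
    """Run `cmd` repeatedly; return the 1-based attempt that failed, or None."""
    for attempt in range(1, max_attempts + 1):
        # Each attempt starts a new process, so no state is shared between runs.
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            # Print the tail of stderr so the validation error is visible.
            print(f"attempt {attempt} failed:\n{result.stderr[-2000:]}")
            return attempt
    return None
```

If the script from this issue is saved under a hypothetical name such as `reproducer.py`, then `rerun_until_failure([sys.executable, "reproducer.py"])` would stop at the first run in which the final `sdfg.validate()` raises.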

import dace

"""
Minimal example for bug in `RedundantArray`

Essentially, the SDFG performs the following computation:
def foo(A: dace.float64[6, 6, 6]) -> dace.float64[36, 1, 6]:
    return A.reshape((36, 1, 6))

I located the error in the `RedundantArray` transformation, which corrupts the
Memlets. Furthermore, the input array is in FORTRAN order, but the bug should also
occur with C order, since it happens before code generation.

I M P O R T A N T
=================
The bug is not deterministic; it depends on the processing order!
It happens if node `a` is removed instead of node `b`, even though that node could
also be removed. If you run the script and it does not fail, try again until it does.
"""

sdfg = dace.SDFG("invalid_sdfg")

_, input_desc = sdfg.add_array(
        "input",
        shape=(6, 6, 6),
        transient=False,
        strides=(1, 6, 36),
        dtype=dace.float64,
)
_, a_desc = sdfg.add_array(
        "a",
        shape=(6, 6, 6),
        transient=True,
        strides=(36, 6, 1),
        dtype=dace.float64,
)
_, b_desc = sdfg.add_array(
        "b",
        shape=(36, 1, 6),
        transient=True,
        strides=(6, 6, 1),
        dtype=dace.float64,
)
_, output_desc = sdfg.add_array(
        "output",
        shape=(36, 1, 6),
        transient=False,
        strides=(6, 6, 1),
        dtype=dace.float64,
)

state = sdfg.add_state("state", is_start_block=True)
input_an = state.add_access("input")
a_an = state.add_access("a")
b_an = state.add_access("b")
output_an = state.add_access("output")

state.add_edge(
        input_an,
        None,
        a_an,
        None,
        dace.Memlet.from_array("input", input_desc),
)

state.add_edge(
        a_an,
        None,
        b_an,
        None,
        dace.Memlet.simple(
            "a",
            subset_str="0:6, 0:6, 0:6",
            other_subset_str="0:36, 0, 0:6",
        )
)

state.add_edge(
        b_an,
        None,
        output_an,
        None,
        dace.Memlet.from_array("b", b_desc),
)

sdfg.validate()
sdfg.simplify(validate=False)

if len(sdfg.arrays) == 2:
    print("All transients were removed, this is new, will it fail?")

elif "a" not in sdfg.arrays:
    print("Array `a` was removed, in the past this indicated that the SDFG is invalid."
          "\nLet's start validation to see what happens.")

elif "b" not in sdfg.arrays:
    print("Array `b` was removed, in the past such an SDFG was valid, try again.")

else:
    print("Something is fishy.")

sdfg.validate()
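
As a sanity check on what the reshaping Memlet between `a` and `b` is meant to do, the strides from the script can be evaluated in plain Python, without DaCe. This is a sketch under one assumption: the copy pairs the n-th element of the source subset with the n-th element of the destination subset, both visited in row-major order. The helper `offsets` is illustrative only.

```python
from itertools import product


def offsets(ranges, strides):
    """Linear memory offsets of a subset, visited in row-major (C) order."""
    return [
        sum(i * s for i, s in zip(idx, strides))
        for idx in product(*ranges)
    ]


# Subset of `a` read by the Memlet: a[0:6, 0:6, 0:6] with strides (36, 6, 1).
src = offsets([range(6), range(6), range(6)], (36, 6, 1))
# Subset of `b` written by the Memlet: b[0:36, 0, 0:6] with strides (6, 6, 1).
dst = offsets([range(36), range(1), range(6)], (6, 6, 1))
# `input` uses FORTRAN strides (1, 6, 36), so it visits memory in another order.
inp = offsets([range(6), range(6), range(6)], (1, 6, 36))

# Under the stated assumption, a -> b copies element n to element n.
assert src == dst == list(range(216))
# input -> a covers the same 216 elements but is a real layout change.
assert sorted(inp) == src and inp != src
```

Under that assumption, `a` and `b` have the same linear layout, so either transient could in principle be eliminated, whereas the copy from `input` to `a` changes the memory order.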

philip-paul-mueller commented 2 months ago

A deterministic test can be found here, but the bug should be fixed in PR1603.