spcl / dace

DaCe - Data Centric Parallel Programming
http://dace.is/fast
BSD 3-Clause "New" or "Revised" License
487 stars 121 forks source link

Simplify raises an exception on LoopRegion #1592

Closed edopao closed 2 months ago

edopao commented 2 months ago

The simplify pass raises the exception below on GT4Py SDFGs that contain LoopRegion nodes. The error goes away if I run sdutils.inline_loop_blocks(sdfg) before simplify().

Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/envs/edopao-gt4py/lib/python3.10/site-packages/networkx/classes/graph.py", line 2010, in bunch_iter
    for n in nlist:
TypeError: 'AccessNode' object is not iterable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/epaone/repo/dace/dace/sdfg/sdfg.py", line 2384, in simplify
    return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose).apply_pass(self, {})
  File "/Users/epaone/repo/dace/dace/transformation/passes/simplify.py", line 113, in apply_pass
    result = super().apply_pass(sdfg, pipeline_results)
  File "/Users/epaone/repo/dace/dace/transformation/pass_pipeline.py", line 547, in apply_pass
    newret = super().apply_pass(sdfg, state)
  File "/Users/epaone/repo/dace/dace/transformation/pass_pipeline.py", line 502, in apply_pass
    r = self.apply_subpass(sdfg, p, state)
  File "/Users/epaone/repo/dace/dace/transformation/passes/simplify.py", line 90, in apply_subpass
    ret = p.apply_pass(sdfg, state)
  File "/Users/epaone/repo/dace/dace/transformation/passes/analysis.py", line 162, in apply_pass
    if state.in_degree(anode) > 0:
  File "/Users/epaone/repo/dace/dace/sdfg/graph.py", line 696, in in_degree
    return self._nx.in_degree(node)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/edopao-gt4py/lib/python3.10/site-packages/networkx/classes/reportviews.py", line 436, in __call__
    return self.__class__(self._graph, nbunch, weight)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/edopao-gt4py/lib/python3.10/site-packages/networkx/classes/reportviews.py", line 421, in __init__
    self._nodes = self._succ if nbunch is None else list(G.nbunch_iter(nbunch))
  File "/opt/homebrew/Caskroom/miniconda/base/envs/edopao-gt4py/lib/python3.10/site-packages/networkx/classes/graph.py", line 2025, in bunch_iter
    raise exc
networkx.exception.NetworkXError: nbunch is not a node or a sequence of nodes.

Here is a test case to reproduce the issue.

import dace

# Number of rows of ``inp`` and length of ``out``; also the extent of the ``row_loop`` map.
M = 4
# Length of each row of ``inp`` and of the nested ``A``; also the trip count of the reduction loop.
N = 10
# Element type used for every data container in this reproducer.
dtype = dace.float64

def build_reduce_sdfg() -> dace.SDFG:
    """Build an SDFG that sums ``A`` into ``B[0]`` using a ``LoopRegion``.

    Control flow is ``init`` -> ``reduce`` (loop region) -> ``exit``:

    * ``init`` writes 0 into the transient scalar accumulator ``__acc``.
    * ``reduce`` iterates ``__idx`` from 0 while ``__idx < N`` and computes
      ``__acc = __acc + A[__idx]`` in its single ``loop`` state.
    * ``exit`` copies the accumulator into ``B[0]``.
    """
    acc = "__acc"
    idx = "__idx"

    sdfg = dace.SDFG("reduce")
    sdfg.add_scalar(acc, dtype, transient=True)
    sdfg.add_array("A", (N,), dtype)
    sdfg.add_array("B", (1,), dtype)

    # Loop header: idx = 0; while idx < N; idx = idx + 1 (not inverted, i.e. a while-style check first).
    loop = dace.sdfg.state.LoopRegion(
        label="reduce",
        loop_var=idx,
        initialize_expr=f"{idx} = 0",
        condition_expr=f"{idx} < {N}",
        update_expr=f"{idx} = {idx} + 1",
        inverted=False,
    )
    sdfg.add_node(loop)
    body = loop.add_state("loop")

    adder = body.add_tasklet("reduce", {"x", "y"}, {"res"}, "res = x + y")
    # Tasklet inputs, wired in this exact order: x <- accumulator, y <- A[idx].
    input_wiring = (
        ("x", acc, "0"),
        ("y", "A", idx),
    )
    for conn, data_name, subset in input_wiring:
        source = body.add_access(data_name)
        body.add_edge(source, None, adder, conn, dace.Memlet(data=data_name, subset=subset))
    # Tasklet output goes back into a fresh access node of the accumulator.
    acc_write = body.add_access(acc)
    body.add_edge(adder, "res", acc_write, None, dace.Memlet(data=acc, subset="0"))

    init = sdfg.add_state("init", is_start_block=True)
    sdfg.add_edge(init, loop, dace.InterstateEdge())
    zero_writer = init.add_tasklet("write", {}, {"val"}, "val = 0")
    acc_init = init.add_access(acc)
    init.add_edge(zero_writer, "val", acc_init, None, dace.Memlet(f"{acc}[0]"))

    done = sdfg.add_state("exit")
    sdfg.add_edge(loop, done, dace.InterstateEdge())
    acc_final = done.add_access(acc)
    b_node = done.add_access("B")
    done.add_nedge(acc_final, b_node, dace.Memlet(data=acc, subset="0", other_subset="0"))

    return sdfg

def test_gt_loop():
    """Regression test for spcl/dace#1592: ``simplify`` on SDFGs with LoopRegions.

    Embeds the loop-based reduction SDFG from ``build_reduce_sdfg`` as a nested
    SDFG inside a map over the ``M`` rows of ``inp``. Each map iteration ``i``
    reduces row ``inp[i, 0:N]`` into ``out[i]``. Running ``simplify`` on this
    outer SDFG used to raise ``networkx.exception.NetworkXError``. The test
    checks that it now completes and leaves a valid SDFG.
    """
    reduce_sdfg = build_reduce_sdfg()

    sdfg = dace.SDFG("gt_loop")
    sdfg.add_array("inp", (M, N), dtype)
    sdfg.add_array("out", (M,), dtype)

    state = sdfg.add_state()
    me, mx = state.add_map("row_loop", dict(i=f"0:{M}"))
    nsdfg = state.add_nested_sdfg(reduce_sdfg, sdfg, {"A"}, {"B"})
    # Row i of inp feeds the nested "A" connector; the nested "B" result lands in out[i].
    state.add_memlet_path(state.add_access("inp"), me, nsdfg, dst_conn="A", memlet=dace.Memlet(f"inp[i, 0:{N}]"))
    state.add_memlet_path(nsdfg, mx, state.add_access("out"), src_conn="B", memlet=dace.Memlet("out[i]"))

    sdfg.validate()
    sdfg.simplify()
    # Not raising is necessary but not sufficient: the simplified SDFG must still be valid.
    sdfg.validate()
phschaad commented 2 months ago

The issue has been fixed in https://github.com/spcl/dace/pull/1475.

One additional note: the linked PR introduces a flag on SDFGs, dace.sdfg.sdfg.SDFG.using_experimental_blocks. This boolean property indicates that an SDFG contains still-experimental features such as LoopRegions. It is important to set this property when using LoopRegions, because certain passes are not yet compatible with them or have not yet been adapted to handle them. Setting the flag ensures that no such passes (or transformations) get executed.