spcl / dace

DaCe - Data Centric Parallel Programming
http://dace.is/fast
BSD 3-Clause "New" or "Revised" License

Double index lookup in tasklet breaks under simplification #1281

Closed. gronerl closed this issue 1 week ago

gronerl commented 1 year ago

The following small graph, which contains an indirection (a double index lookup) in a tasklet, breaks under the simplification passes:

import dace

sdfg = dace.SDFG('test_')
state = sdfg.add_state()

sdfg.add_array('A', shape=(10,), dtype=dace.float64)
sdfg.add_array('table', shape=(10, 2), dtype=dace.int64)
sdfg.add_array('B', shape=(10,), dtype=dace.float64)
sdfg.add_scalar('idx', dace.int64, transient=True)
idx_node = state.add_access('idx')
set_tlet = state.add_tasklet('set_idx', code="_idx=0", inputs={}, outputs={"_idx"})
state.add_mapped_tasklet('map',
                         map_ranges={'i': "0:10"},
                         inputs={'inp': dace.Memlet("A[0:10]"),
                                 '_idx': dace.Memlet('idx[0]'),
                                 'indices': dace.Memlet('table[0:10, 0:2]')},
                         code="out = inp[indices[i,_idx]]",
                         outputs={'out': dace.Memlet("B[i]")},
                         external_edges=True,
                         input_nodes={'idx': idx_node})

state.add_edge(set_tlet, '_idx', idx_node, None, dace.Memlet('idx[0]'))

# Running the simplification passes triggers the bug
sdfg.simplify()

Namely, simplification moves the indirection onto an edge, producing the memlet subset "A[indices(i, 0)]", which is not a valid subset.

BenWeber42 commented 1 year ago

I looked into this issue and I believe this is due to the following behavior:

from dace import symbolic
symbolic.pystr_to_symbolic("a[i]").free_symbols
# gives {i}
symbolic.pystr_to_symbolic("indices[i,_idx]").free_symbols
# gives {_idx, i}

With the given tasklet code:

out = inp[indices[i,_idx]]

This means that DaCe considers the expression indices[i,_idx] to be independent of indices. Because i and _idx are defined earlier, DaCe (incorrectly) pushes the expression indices[i,_idx] into the memlet.

I think we want a behavior such that:

symbolic.pystr_to_symbolic("a[i]").free_symbols
# gives {i, a}

I am currently looking into how best to implement this. There is a risk that such a change could break other parts of the codebase...
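To illustrate the desired behavior outside of DaCe, here is a minimal sketch using Python's standard ast module rather than DaCe's symbolic module; it is not the eventual fix. The helper name referenced_names is hypothetical. It collects every identifier an expression reads, so the base of a subscript (the array) is reported alongside its indices.

```python
import ast


def referenced_names(expr: str) -> set[str]:
    """Return every identifier an expression reads, including indexed arrays.

    Unlike a purely symbolic view, the base of a subscript (e.g. ``a`` in
    ``a[i]``) is reported too, so a dependence on the array is not lost.
    """
    tree = ast.parse(expr, mode="eval")
    return {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}


print(sorted(referenced_names("a[i]")))             # ['a', 'i']
print(sorted(referenced_names("indices[i,_idx]")))  # ['_idx', 'i', 'indices']
```

With a set like this, a pass could see that indices[i,_idx] depends on indices, which is not known before the tasklet runs, and therefore should not move it into a memlet.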

tbennun commented 1 year ago

Since the issue only occurs during simplification, and the only pass that creates such memlets is ScalarToSymbolPromotion, shouldn't the fix go there? Specifically, TaskletIndirectionPromoter should avoid promoting expressions that would produce invalid symbolic expressions, e.g., ones containing an ast.Subscript as in this case (though there may be other invalid cases).
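A minimal sketch of the kind of guard suggested here, assuming the promoter has the index expression available as Python source. The function name can_promote_index is hypothetical; this is not the actual TaskletIndirectionPromoter code.

```python
import ast


def can_promote_index(expr: str) -> bool:
    """Hypothetical guard for an indirection promoter.

    An index expression may only be moved into a memlet subset if it contains
    no array access (ast.Subscript); otherwise it must stay in the tasklet.
    """
    tree = ast.parse(expr, mode="eval")
    return not any(isinstance(node, ast.Subscript) for node in ast.walk(tree))


# The index of inp[indices[i, _idx]] is itself an array access: keep it.
print(can_promote_index("indices[i, _idx]"))  # False
# A plain arithmetic index could be promoted.
print(can_promote_index("i + 2*_idx"))        # True
```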

BenWeber42 commented 1 year ago

What is considered a valid symbolic expression? And what is considered a valid expression for a memlet?

I guess memlet expressions are constrained by dace.properties.SubsetProperty, and symbolic expressions by dace.symbolic.pystr_to_symbolic?

Both currently seem to allow the double index; at least, neither gives a validation error for a double-index expression.

tbennun commented 1 year ago

A symbolic expression is a combination of symbols and operations on those symbols. The operations are arithmetic or logical unary/binary/ternary operators, or a set of predefined SymPy math functions (i.e., ones with a set of identities that can be used to simplify them) applied to symbolic expressions. Array accesses and attribute queries are not considered symbolic expressions (as arrays/data containers are not symbols).
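The definition above can be approximated as a whitelist over a Python AST. This is a minimal sketch assuming expressions are written in Python syntax; the function name and the exact set of allowed math functions are assumptions, and DaCe itself parses expressions through pystr_to_symbolic and SymPy instead. Because it whitelists node types, it rejects attribute queries as well as array accesses.

```python
import ast

# Assumed whitelist for illustration; SymPy knows many more functions.
ALLOWED_FUNCTIONS = {"min", "max", "abs", "floor", "ceiling", "sqrt", "exp", "log"}

# Node types allowed in a symbolic expression: symbols, numbers, operators.
ALLOWED_NODES = (
    ast.Expression, ast.Name, ast.Constant, ast.Load,
    ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare, ast.IfExp,
    ast.operator, ast.unaryop, ast.boolop, ast.cmpop,
)


def is_symbolic_expression(expr: str) -> bool:
    """Approximate check: symbols, constants, operators and whitelisted math calls."""
    tree = ast.parse(expr, mode="eval")
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            # Only plain calls to whitelisted functions are accepted.
            if not (isinstance(node.func, ast.Name) and node.func.id in ALLOWED_FUNCTIONS):
                return False
        elif isinstance(node, ast.Constant):
            # Numbers and booleans only; strings etc. are not symbolic.
            if not isinstance(node.value, (bool, int, float)):
                return False
        elif not isinstance(node, ALLOWED_NODES):
            # Rejects ast.Subscript (array accesses), ast.Attribute, lambdas, ...
            return False
    return True


print(is_symbolic_expression("i + 2*_idx"))        # True
print(is_symbolic_expression("max(i, 0) % 4"))     # True
print(is_symbolic_expression("indices[i, _idx]"))  # False
print(is_symbolic_expression("A.x + 1"))           # False
```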

A memlet must access a single data container and have one or two subsets (which are symbolic expressions as per the above definition).

Inter-state edges extend this definition slightly (because they can access data containers): they can also involve array accesses (represented by the array name becoming a function call, e.g., A(i, 8*j+2)) as well as attributes (Attr(A, 'x')). We did that to be able to simplify expressions on inter-state edges, though I am not entirely sure it is necessary at all.
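Purely to illustrate the call-form convention for array accesses described above, here is a sketch using Python's ast module (ast.unparse needs Python 3.9+). The transformer name SubscriptToCall is hypothetical, it is not how DaCe implements this, and it only handles element accesses, not slices.

```python
import ast


class SubscriptToCall(ast.NodeTransformer):
    """Rewrite element accesses A[i, j] into calls A(i, j)."""

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        self.generic_visit(node)  # rewrite nested accesses first
        index = node.slice
        args = list(index.elts) if isinstance(index, ast.Tuple) else [index]
        call = ast.Call(func=node.value, args=args, keywords=[])
        return ast.copy_location(call, node)


def to_call_form(expr: str) -> str:
    tree = ast.parse(expr, mode="eval")
    tree = ast.fix_missing_locations(SubscriptToCall().visit(tree))
    return ast.unparse(tree)


print(to_call_form("A[i, 8*j+2]"))         # A(i, 8 * j + 2)
print(to_call_form("inp[indices[i, 0]]"))  # inp(indices(i, 0))
```

In the second example, the inner access indices(i, 0) has the same shape as the one in the invalid subset from the original report, A[indices(i, 0)].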


For the last point, if validation doesn't fail on bad expressions, it should.
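One possible shape for such a validation check is sketched below, assuming the subset is available as a Python-syntax string and the set of symbols in scope is known. The function names and the KNOWN_FUNCTIONS set are hypothetical. The idea is that every name in a subset must be either a symbol in scope or a known function; in the reported subset, indices is a tasklet connector rather than a symbol, so it would be flagged.

```python
import ast

# Assumed set of functions that may appear in subsets (illustrative only).
KNOWN_FUNCTIONS = {"min", "max", "abs", "floor", "ceiling"}


def undefined_names_in_subset(subset: str, symbols: set[str]) -> set[str]:
    """Return names in a subset string that are neither symbols nor known functions."""
    # Wrap in a dummy subscript so ranges such as "0:10, i" parse as Python.
    tree = ast.parse(f"__subset__[{subset}]", mode="eval")
    names = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    names.discard("__subset__")
    return names - symbols - KNOWN_FUNCTIONS


def validate_subset(subset: str, symbols: set[str]) -> None:
    """Hypothetical validation step that fails on out-of-scope names."""
    bad = undefined_names_in_subset(subset, symbols)
    if bad:
        raise ValueError(f"invalid subset {subset!r}: undefined {sorted(bad)}")


symbols_in_scope = {"i"}  # e.g. the map parameter
validate_subset("0:10", symbols_in_scope)  # passes
validate_subset("i", symbols_in_scope)     # passes
try:
    # The subset from the report: 'indices' is a tasklet connector, not a symbol.
    validate_subset("indices(i, 0)", symbols_in_scope)
except ValueError as err:
    print(err)  # invalid subset 'indices(i, 0)': undefined ['indices']
```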

BenWeber42 commented 11 months ago

I am currently testing this solution: #1406

BenWeber42 commented 11 months ago

Properly fixing this issue requires improvements to how symbols are queried in symbolic expressions. However, that currently does not work well and would likely require substantial work (see #1417).

Additionally, validation is missing a check for this bug. I am tracking this here: #1418

Splitting these tasks allows us to focus on a practical workaround for now (likely #1406).