iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Track resource update regions and elide/minimize unneeded fills. #6972

Status: Open. Opened by benvanik 3 years ago.

benvanik commented 3 years ago

The new stream ops give us enough information to symbolically identify discards, such as a splat whose contents are entirely overwritten by subsequent updates:

    %13 = stream.async.splat %cst_1 : f32 -> !stream.resource<*>{%c296}
    %14 = stream.async.update %12, %13[%c0 to %c256] : !stream.resource<*>{%c256} -> %13 as !stream.resource<*>{%c296}
    %15 = stream.async.update %arg4, %14[%c256 to %c296] : !stream.resource<*>{%c40} -> %14 as !stream.resource<*>{%c296}

Here we know from [%c0, %c256) + [%c256, %c296) that we are overwriting the entire resource, so the splatted value is never observed. We should be able to transform this into:

    %13 = stream.async.alloca : !stream.resource<*>{%c296}
    %14 = stream.async.update %12, %13[%c0 to %c256] : !stream.resource<*>{%c256} -> %13 as !stream.resource<*>{%c296}
    %15 = stream.async.update %arg4, %14[%c256 to %c296] : !stream.resource<*>{%c40} -> %14 as !stream.resource<*>{%c296}

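The only value that needs to be a known constant here is 0; the rest can be proven with pure SSA value equality (the end of one range is the same value as the start of the next, and the last end is the resource size).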
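
A minimal sketch of that check, assuming a toy representation rather than IREE's actual data structures: each update is a `(start, end)` pair of SSA value names, and the resource size is also an SSA name. Only the literal `%c0` is treated as the constant zero; every other adjacency test is plain name identity. The function name `covers_entire_resource` is hypothetical.

```python
from typing import List, Tuple

# Toy model (assumption): offsets are SSA value names such as "%c256".
# Only the zero offset is recognized as a constant; all other comparisons
# are pure SSA-value identity, mirroring the reasoning above.
ZERO = "%c0"

def covers_entire_resource(updates: List[Tuple[str, str]], size: str) -> bool:
    """Return True if the [start, end) ranges chain from 0 to `size` with no gaps."""
    if not updates:
        return False
    # Offsets are symbolic, so we cannot sort them numerically; instead follow
    # the chain by matching each range's start to the previous range's end.
    by_start = {start: end for start, end in updates}
    cursor = ZERO
    visited = set()
    while cursor in by_start and cursor not in visited:
        visited.add(cursor)
        cursor = by_start[cursor]
        if cursor == size:
            return True
    return False

# The first example above: [%c0, %c256) followed by [%c256, %c296).
print(covers_entire_resource([("%c0", "%c256"), ("%c256", "%c296")], "%c296"))
```

When this returns True, the splat feeding the chain can be replaced with `stream.async.alloca`; a False result is conservative and simply leaves the splat in place.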

Another case is turning the splat into fills when there are gaps, either in the interior or at the tail (common with padding). For example:

    %13 = stream.async.splat %cst_1 : f32 -> !stream.resource<*>{%c296}
    %14 = stream.async.update %12, %13[%c0 to %c256] : !stream.resource<*>{%c256} -> %13 as !stream.resource<*>{%c296}

->

    %13 = stream.async.alloca : !stream.resource<*>{%c296}
    %14 = stream.async.update %12, %13[%c0 to %c256] : !stream.resource<*>{%c256} -> %13 as !stream.resource<*>{%c296}
    %15 = stream.async.fill %cst_1, %14[%c256 to %c296] : f32 -> %14 as !stream.resource<*>{%c296}

This could be done as a canonicalization on stream.async.update.
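
The gap computation can be sketched as follows, assuming the update ranges have already been resolved to integer byte offsets (as with `%c256`/`%c296` here). It returns the `[start, end)` ranges not written by any update; under this sketch each one would become a `stream.async.fill` of the original splat value, and an empty result means the splat can become an alloca with no fills. The function name is hypothetical.

```python
from typing import List, Tuple

Range = Tuple[int, int]  # half-open [start, end) byte range

def uncovered_ranges(updates: List[Range], size: int) -> List[Range]:
    """Return the sub-ranges of [0, size) not written by any update."""
    gaps: List[Range] = []
    cursor = 0
    for start, end in sorted(updates):
        if start > cursor:
            gaps.append((cursor, start))  # interior gap before this update
        cursor = max(cursor, end)          # tolerate overlapping updates
    if cursor < size:
        gaps.append((cursor, size))        # tail gap, e.g. padding
    return gaps

# The example above: one update covers [0, 256) of a 296-byte resource.
print(uncovered_ranges([(0, 256)], 296))  # [(256, 296)]
```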

benvanik commented 2 years ago

ESRGAN uses a producer[] -> insert[] pattern that results in a lot of fills. Relates to #7729. In these cases we have enough information to track that the entire tensor is overwritten and that the splat can be dropped.

    %1350 = stream.async.splat %c0_i8 : i8 -> !stream.resource<transient>{%c4285440}
    %1351 = stream.async.update %1325, %1350[%c0 to %c1428480] : !stream.resource<transient>{%c1428480} -> %1350 as !stream.resource<transient>{%c4285440}
    %1352 = stream.async.update %1328, %1351[%c1428480 to %c2142720] : !stream.resource<transient>{%c714240} -> %1351 as !stream.resource<transient>{%c4285440}
    %1353 = stream.async.update %1334, %1352[%c2142720 to %c2856960] : !stream.resource<transient>{%c714240} -> %1352 as !stream.resource<transient>{%c4285440}
    %1354 = stream.async.update %1341, %1353[%c2856960 to %c3571200] : !stream.resource<transient>{%c714240} -> %1353 as !stream.resource<transient>{%c4285440}
    %1355 = stream.async.update %1349, %1354[%c3571200 to %c4285440] : !stream.resource<transient>{%c714240} -> %1354 as !stream.resource<transient>{%c4285440}
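
Plugging the constant offsets from the snippet above into a simple adjacency check illustrates the claim: the five updates tile [0, 4285440) exactly, so the splat of `%c0_i8` is never observed. This is a self-contained sketch; the constants are copied from the IR and the helper name `tiles_exactly` is hypothetical.

```python
# (start, end) byte ranges of the five stream.async.update ops above.
ESRGAN_UPDATES = [
    (0, 1428480),
    (1428480, 2142720),
    (2142720, 2856960),
    (2856960, 3571200),
    (3571200, 4285440),
]
RESOURCE_SIZE = 4285440

def tiles_exactly(ranges, size):
    """True if the ranges, in order, are contiguous and span exactly [0, size)."""
    cursor = 0
    for start, end in ranges:
        if start != cursor or end < start:
            return False
        cursor = end
    return cursor == size

print(tiles_exactly(ESRGAN_UPDATES, RESOURCE_SIZE))  # True
```
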

benvanik commented 2 years ago

Note to self: the benvanik-concurrent-copies branch has a WIP implementation of a general write-elision pass that does the tracking needed to find these chains of in-place operations on adjacent ranges. It should really be written as a data-flow analysis, though, and will need a rewrite. A data-flow formulation would let us track divergent execution paths that fill the same regions, as well as track across function calls.
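
One way to frame the data-flow version, sketched under stated assumptions (this is not the WIP implementation on that branch): treat "bytes known to be written" as a must-analysis fact, so at a control-flow merge the facts from the incoming paths are intersected. A region counts as overwritten only if every path overwrote it. Intervals are concrete integers here for simplicity, and all names are hypothetical.

```python
from typing import List, Tuple

Range = Tuple[int, int]  # half-open [start, end)

def normalize(ranges: List[Range]) -> List[Range]:
    """Sort and merge overlapping or adjacent ranges into a canonical list."""
    merged: List[Range] = []
    for start, end in sorted(r for r in ranges if r[0] < r[1]):
        if merged and start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged

def meet(a: List[Range], b: List[Range]) -> List[Range]:
    """Intersect two written-region sets: the merge-point rule of a must-analysis."""
    a, b = normalize(a), normalize(b)
    out: List[Range] = []
    i = j = 0
    while i < len(a) and j < len(b):
        lo = max(a[i][0], b[j][0])
        hi = min(a[i][1], b[j][1])
        if lo < hi:
            out.append((lo, hi))
        # Advance whichever range ends first.
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return out

# Two divergent paths that both write all of [0, 296), split differently.
then_path = [(0, 256), (256, 296)]
else_path = [(0, 100), (100, 296)]
print(meet(then_path, else_path))  # [(0, 296)]
```

If one path only wrote [0, 256), the meet would be [(0, 256)], so only [256, 296) would still need a fill after the merge.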