ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Reduce compile time by reducing calls to `compute_shape()` for each IR transformation #3233

Open umangyadav opened 1 month ago

hgaspar commented 1 month ago

The proposal is to tag a node as "modified" without immediately calling compute_shape() for the whole graph. Shape propagation (propagate_shapes) then becomes an explicit pass.
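A minimal Python sketch of this idea. It uses a toy IR, not MIGraphX's C++ classes: `Instr`, `Graph`, `mark_modified()` and `propagate_shapes()` are hypothetical names, and `shape_fn` stands in for `op.compute_shape()`. Mutations only set a dirty flag. A single topological sweep later recomputes tagged nodes and tags a node's users only when its shape actually changed.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

Shape = Tuple[int, ...]

@dataclass(eq=False)
class Instr:
    name: str
    inputs: List["Instr"]
    shape_fn: Callable[[List[Shape]], Shape]  # stand-in for op.compute_shape()
    shape: Optional[Shape] = None
    dirty: bool = True

class Graph:
    def __init__(self):
        self.instrs: List[Instr] = []  # kept in topological order
        self.compute_calls = 0

    def add(self, name, inputs, shape_fn):
        ins = Instr(name, list(inputs), shape_fn)
        self.instrs.append(ins)
        return ins

    def users(self, ins):
        return [u for u in self.instrs if ins in u.inputs]

    def mark_modified(self, ins):
        # Only tag the node; no shape is computed at mutation time.
        ins.dirty = True

    def propagate_shapes(self):
        # Explicit pass: one topological sweep. Users come later in the list,
        # so anything tagged here is still visited in the same sweep.
        for ins in self.instrs:
            if not ins.dirty:
                continue
            ins.dirty = False
            self.compute_calls += 1
            new_shape = ins.shape_fn([i.shape for i in ins.inputs])
            if new_shape != ins.shape:
                ins.shape = new_shape
                for u in self.users(ins):
                    u.dirty = True

g = Graph()
a = g.add("a", [], lambda s: (4, 2))
b = g.add("b", [], lambda s: (2, 3))
d = g.add("dot", [a, b], lambda s: (s[0][0], s[1][1]))
t = g.add("transpose", [d], lambda s: s[0][::-1])
r = g.add("relu", [t], lambda s: s[0])
g.propagate_shapes()            # first sweep computes every shape once

a.shape_fn = lambda s: (5, 2)   # two mutations before any propagation
g.mark_modified(a)
g.mark_modified(t)
g.propagate_shapes()            # one sweep handles both
```

Both mutations in the second round are absorbed by a single sweep, which recomputes only `a`, `dot`, `transpose` and `relu`; `b` is never touched.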

Interesting example:

Consider a Gemm feeding into pathA and pathB. Assume that pathA contains a transposition, and that while traversing the Gemm we decide to implement that transposition by modifying the Gemm's output strides. At the same time, we should add a node on pathB, call it lazy_reshape (or whatever), so that pathB still sees the original layout. That is enough of an action; there is no need to propagate the shape change all the way through immediately.

CharlieL7 commented 1 month ago

I had a similar idea to this where we have a lazy_add_instruction() or lazy_replace_instruction() that doesn't propagate the op.compute_shape() calls until the end of the matcher. Another way to do this would be to make a way to add a series of instructions in one go.

bpickrel commented 1 month ago

How many times is compute_shape() being called now? I didn't think it was very computation-intensive. Is it called for every instruction in the graph each time any instruction is changed?

umangyadav commented 1 month ago

> Another way to do this would be to make a way to add a series of instructions in one go.

It may not be amenable to the way MIGraphX matchers are currently written. Instructions need to be present at the time of matching.

umangyadav commented 1 month ago

> I didn't think it was very computation-intensive.

It is if you look at the window computation for convolution or, say, ROIAlign or NMS operations. For most other operations it is not compute-heavy; the problem lies in how many times compute_shape() is called. MIGraphX currently calls compute_shape() each time it mutates the IR in any way.

krzysz00 commented 1 month ago

IIRC, the issue I was pointing at is that compute_shape() calls cascade in O(n): if I replace @X with @Y, then its users (say @A and @B) need to recompute their shapes, and if those shapes change, their users need to recompute theirs, and so on, every time the IR is mutated.
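To make that cost concrete, here is a toy Python model of eager propagation on a chain of shape-preserving ops. It is not MIGraphX's code: `build_chain` and `eager_update` are hypothetical names, and each loop iteration stands in for one `compute_shape()` call.

```python
from collections import deque
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

def build_chain(n: int, shape: Shape):
    """Toy IR: x -> op1 -> ... -> opn, all shape-preserving (elementwise)."""
    inputs = {f"op{i}": "x" if i == 1 else f"op{i - 1}" for i in range(1, n + 1)}
    users: Dict[str, List[str]] = {"x": ["op1"]}
    users.update({f"op{i}": [f"op{i + 1}"] for i in range(1, n)})
    shapes = {"x": shape, **{f"op{i}": shape for i in range(1, n + 1)}}
    return inputs, users, shapes

def eager_update(root, new_shape, inputs, users, shapes) -> int:
    """Change root's shape, then eagerly recompute users transitively for as
    long as their shapes keep changing. Returns the number of recomputations."""
    shapes[root] = new_shape
    calls = 0
    work = deque(users.get(root, []))
    while work:
        node = work.popleft()
        calls += 1                   # one compute_shape() call
        new = shapes[inputs[node]]   # elementwise: output shape = input shape
        if new != shapes[node]:
            shapes[node] = new
            work.extend(users.get(node, []))
    return calls

inputs, users, shapes = build_chain(100, (4, 3))
total = sum(eager_update("x", (4, 3 + k), inputs, users, shapes) for k in range(1, 11))
```

With a 100-op chain, ten shape-changing mutations of the input cost 1000 recomputations, i.e. O(n) per mutation. A mutation that leaves the shape unchanged stops after a single call.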

krzysz00 commented 1 month ago

To write the proposal out again, names subject to bikeshedding:

Currently, let's say I have the following (I'll use MLIR syntax for the shapes):

@3 = dot(@1, @2) : <4x2xf16, 2x1>, <2x3xf16, 3x1> -> <4x3xf16, 3x1>
@4 = transpose(@3) : <4x3xf16, 3x1> -> <3x4xf16, 4x1>
@5 = relu(@4) : <3x4xf16, 4x1> -> <3x4xf16, 4x1>

I can rewrite this to

// Note the different logical shape and strides. This is an equivalent phrasing of the transpose.
@3 = dot(@1, @2) : <4x2xf16, 2x1>, <2x3xf16, 3x1> -> <4x3xf16, 1x4>
@5 = relu(@3) : <4x3xf16, 1x4> -> <4x3xf16, 1x4>

This is, as far as I'm aware, a perfectly legal, semantics-preserving rewrite in MIGraphX. However, the performance issue is that, as soon as you make this mutation, you touch off a massive pile of compute_shape calls.
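Why `<4x3xf16, 1x4>` is an equivalent phrasing of the transpose comes down to stride arithmetic. Here is a small Python check that models a shape as (lens, strides), as in the annotations above. `offset` and `transpose_view` are hypothetical helper names.

```python
from itertools import product
from typing import Tuple

def offset(index: Tuple[int, ...], strides: Tuple[int, ...]) -> int:
    """Linear element offset of `index` under the given strides."""
    return sum(i * s for i, s in zip(index, strides))

def transpose_view(lens, strides, perm):
    """A transpose expressed purely as a permutation of lens and strides."""
    return tuple(lens[p] for p in perm), tuple(strides[p] for p in perm)

# @4 in the original program: 3x4, row-major (strides 4x1).
lens4, strides4 = (3, 4), (4, 1)
# Rewritten @3: logical 4x3 with strides 1x4.
lens3, strides3 = (4, 3), (1, 4)

# Element (i, j) of the 4x3 result lands at the same offset where the
# materialized 3x4 transpose stores element (j, i): same bytes, no copy.
assert all(
    offset((i, j), strides3) == offset((j, i), strides4)
    for i, j in product(range(4), range(3))
)
assert transpose_view(lens4, strides4, (1, 0)) == (lens3, strides3)
```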

I propose extending the MIGraphX IR so that the rewrite I'd have to do instead is:

@3 = dot(@1, @2) : <4x2xf16, 2x1>, <2x3xf16, 3x1> -> <4x3xf16, 1x4>
@3b = pending_shape_change(@3) : <4x3xf16, 1x4> -> <4x3xf16, 3x1>
@4 = transpose(@3b) : <4x3xf16, 3x1> -> <3x4xf16, 4x1>
@5 = relu(@4) : <3x4xf16, 4x1> -> <3x4xf16, 4x1>

Note that, because we've done this, the original user of @3 (the transpose) doesn't need to recompute its shape yet; only the value it operates on has changed, not that value's shape.

Then, at a future point, you run rewrites that "apply" pending_shape_change operations until a fixpoint is reached.

In this example, the shape-update logic for transpose will be able to fold the transpose away, because it's now trivial:

@3 = dot(@1, @2) : <4x2xf16, 2x1>, <2x3xf16, 3x1> -> <4x3xf16, 1x4>
@4b = pending_shape_change(@3) : <4x3xf16, 1x4> -> <3x4xf16, 4x1>
@5 = relu(@4b) : <3x4xf16, 4x1> -> <3x4xf16, 4x1>

Then relu just propagates the update forward:

@3 = dot(@1, @2) : <4x2xf16, 2x1>, <2x3xf16, 3x1> -> <4x3xf16, 1x4>
@4b = pending_shape_change(@3) : <4x3xf16, 1x4> -> <3x4xf16, 4x1>
@5 = relu(@3) : <4x3xf16, 1x4> -> <4x3xf16, 1x4>
@5b = pending_shape_change(@5) : <4x3xf16, 1x4> -> <3x4xf16, 4x1>

The advantage is that propagation happens lazily, after all rewrite patterns in a collection have run.

On top of that, you get a useful invariant you could start enforcing: when replacing @x with @y, if the shape of @x doesn't match the shape of @y, the replacement either automatically inserts a pending_shape_change or, more interestingly, is an error (with pending_shape_change auto-insertion provided as a helper wrapper).
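A minimal sketch of that invariant in a toy Python IR. `Module`, `Ins`, `insert_after`, `replace_with_pending` and `ShapeMismatch` are hypothetical; `replace_instruction` is only named after MIGraphX's `module::replace_instruction` and does not reproduce its real C++ behaviour. The strict form raises on a layout mismatch. The wrapper inserts a pending_shape_change whose output is the old layout, which replays the @3/@3b rewrite:

```python
from dataclasses import dataclass
from typing import List, Tuple

Layout = Tuple[Tuple[int, ...], Tuple[int, ...]]  # (lens, strides)

@dataclass(eq=False)
class Ins:
    op: str
    args: List["Ins"]
    layout: Layout

class ShapeMismatch(Exception):
    pass

class Module:
    def __init__(self):
        self.instrs: List[Ins] = []  # kept in topological order

    def add(self, op, args, layout):
        ins = Ins(op, list(args), layout)
        self.instrs.append(ins)
        return ins

    def insert_after(self, pos, op, args, layout):
        ins = Ins(op, list(args), layout)
        self.instrs.insert(self.instrs.index(pos) + 1, ins)
        return ins

    def replace_instruction(self, old, new):
        # Strict form of the invariant: a shape-changing replacement is an error.
        if old.layout != new.layout:
            raise ShapeMismatch(f"{old.op}: {old.layout} -> {new.layout}")
        for ins in self.instrs:
            if ins is not new:
                ins.args = [new if a is old else a for a in ins.args]

    def replace_with_pending(self, old, new):
        # Helper wrapper: bridge the mismatch with a pending_shape_change that
        # converts back to old's layout, so old's users see no shape change and
        # nothing downstream recomputes yet. Assumes `new` does not use `old`.
        if old.layout != new.layout:
            new = self.insert_after(new, "pending_shape_change", [new], old.layout)
        self.replace_instruction(old, new)
        return new

m = Module()
a = m.add("param", [], ((4, 2), (2, 1)))
b = m.add("param", [], ((2, 3), (3, 1)))
dot = m.add("dot", [a, b], ((4, 3), (3, 1)))
tr = m.add("transpose", [dot], ((3, 4), (4, 1)))
relu = m.add("relu", [tr], ((3, 4), (4, 1)))

dot_t = m.insert_after(dot, "dot", [a, b], ((4, 3), (1, 4)))  # strides rewritten
pend = m.replace_with_pending(dot, dot_t)
```

The transpose now reads `pend`, whose layout is exactly what it read before, so its own layout is left untouched. The old `dot` is dead and would be removed by a later dead-code pass.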

bpickrel commented 1 month ago

This implies that pending_shape_change isn't really a MIGraphX operation in any sense defined until now, because it's not something for the GPU to do or to be translated to HIP. Instead, it's a sort of placeholder that fits in the program graph but is consumed and removed at compile time. It's too early for me to give an opinion on whether Krzysztof's implementation is a good idea; I wonder if the same idea could be used for other handling enhancements. But the logically equivalent behavior could also be coded as additions to the existing MIGraphX operation class or its contents, as we've done with enhancements like dynamic shapes. In other words, tag certain ops as being "lazy shape" and add an automatic compiler pass to resolve them.

krzysz00 commented 1 month ago

Well, yes, the MIGraphX IR is your IR and you're allowed to define any operation you'd like - including something like pending_shape_change whose semantics are "you either have to fold this into its users or materialize it as a reshape/transpose/... if they don't know how to deal with it before you try to run the program".

I'm taking this from MLIR's builtin.unrealized_conversion_cast, which is used for similar deferred resolution of type changes during IR mutations.

bpickrel commented 1 month ago

Re your last paragraph ("On top of that, you get a useful invariant..."): do you think we can write a handler smart enough to identify which shape changes are both legal and limited enough in scope that not resolving them right away is safe? This sounds similar in concept to our current matchers, but the matcher language is very script-like, inflexible, and requires intensive hand-tweaking (as we all know; after all, that's what pays the rent for many of us). In the worst case, the checking would be so complex that it's just as expensive as the current eager shape updates.

krzysz00 commented 1 month ago

Well, the point is that every time you change the type ("shape") of a value while doing a replacement, you add a pending_shape_change, or perhaps a pending_type_change. Then you have a resolvePendingShapeChanges() that runs until a fixpoint and rewrites each such operation either by folding it into its consumers or into some concrete shape change like a transpose or a reshape.
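One way to sketch such a resolve pass in a toy Python IR. All names are hypothetical, and the rule set is an assumption chosen to match the example in this thread:

* `transpose` is absorbed into a pending change.
* A unary `relu` lets a pending change pass through it.
* No-op and dead pending changes are dropped.
* Whatever remains is materialized as a `transpose` when a stride permutation suffices, otherwise as a hypothetical `copy_to_layout` op.

Run on the dot/transpose/relu example from earlier in the thread, it reproduces the rewrite sequence walked through there.

```python
from dataclasses import dataclass
from itertools import permutations
from typing import List, Tuple

Layout = Tuple[Tuple[int, ...], Tuple[int, ...]]  # (lens, strides)
VIEW_OPS = {"transpose"}  # assumption: layout-only ops a pending change can absorb
ELEMENTWISE = {"relu"}    # assumption: unary ops whose output layout equals their input's

@dataclass(eq=False)
class Ins:
    op: str
    args: List["Ins"]
    layout: Layout

def users(prog, ins):
    return [u for u in prog if ins in u.args]

def replace_uses(prog, old, new):
    for u in prog:
        if u is not new:
            u.args = [new if a is old else a for a in u.args]

def rewrite_once(prog) -> bool:
    """Apply one folding rule to one pending_shape_change; False at fixpoint.
    The last instruction in `prog` is treated as the program's result."""
    for p in prog:
        if p.op != "pending_shape_change":
            continue
        src, is_result = p.args[0], p is prog[-1]
        if src.layout == p.layout and not is_result:  # no-op change
            replace_uses(prog, p, src)
            prog.remove(p)
            return True
        if not users(prog, p) and not is_result:      # dead change
            prog.remove(p)
            return True
        for u in users(prog, p):
            if u.op in VIEW_OPS:
                # transpose(pending(src)) -> pending(src) with the transpose's layout
                p2 = Ins("pending_shape_change", [src], u.layout)
                prog.insert(prog.index(u), p2)
                replace_uses(prog, u, p2)
                prog.remove(u)
                return True
            if u.op in ELEMENTWISE:
                # relu(pending(src)) -> pending(relu(src)); relu adopts src's layout
                u.args = [src if a is p else a for a in u.args]
                old_layout, u.layout = u.layout, src.layout
                p2 = Ins("pending_shape_change", [u], old_layout)
                prog.insert(prog.index(u) + 1, p2)
                replace_uses(prog, u, p2)
                return True
    return False

def resolve_pending_shape_changes(prog):
    while rewrite_once(prog):  # fold until fixpoint
        pass
    for p in prog:             # materialize whatever could not be folded
        if p.op == "pending_shape_change":
            lens, strides = p.args[0].layout
            is_perm = any(
                (tuple(lens[i] for i in q), tuple(strides[i] for i in q)) == p.layout
                for q in permutations(range(len(lens))))
            p.op = "transpose" if is_perm else "copy_to_layout"  # hypothetical op
    return prog

a = Ins("param", [], ((4, 2), (2, 1)))
b = Ins("param", [], ((2, 3), (3, 1)))
dot = Ins("dot", [a, b], ((4, 3), (1, 4)))                   # @3, strides rewritten
pend = Ins("pending_shape_change", [dot], ((4, 3), (3, 1)))  # @3b
tr = Ins("transpose", [pend], ((3, 4), (4, 1)))              # @4
relu = Ins("relu", [tr], ((3, 4), (4, 1)))                   # @5
prog = resolve_pending_shape_changes([a, b, dot, pend, tr, relu])
```

The program ends as `param, param, dot, relu, transpose`. The surviving `transpose` is the materialized form of the final pending change `@5b`, and in this model it only permutes lens and strides.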

As an efficiency extension: if you know you're doing a bunch of global type changes (say, a pass that changes all floats to halves), you'd tell the matcher to look through the pending_shape_change on your inputs to get the half outputs directly, making all the pending_shape_change ops dead code.

If you're going bottom-up with that, you'd be emitting converts for your inputs anyway, so a sequence like

@0 = ... -> half, {...}, {...}
@1 = pending_shape_change(@0) -> float, {...}, {...}
@2 = convert(@1) -> half, {...}, {...}

can have @1 and @2 folded away.
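Here is a sketch of that folding rule in a toy Python IR with hypothetical names (`Ins`, `fold_convert_of_pending`). The concrete lens and strides below are made up, since the thread elides them as `{...}`.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass(eq=False)
class Ins:
    op: str
    args: List["Ins"]
    dtype: str
    lens: Tuple[int, ...] = ()
    strides: Tuple[int, ...] = ()

def fold_convert_of_pending(ins: Ins) -> Optional[Ins]:
    """If `ins` is convert(pending_shape_change(x)) and its result has exactly
    x's type and layout, the pair is a no-op: return x. Otherwise None."""
    if ins.op != "convert":
        return None
    p = ins.args[0]
    if p.op != "pending_shape_change":
        return None
    x = p.args[0]
    same = (x.dtype, x.lens, x.strides) == (ins.dtype, ins.lens, ins.strides)
    return x if same else None

x0 = Ins("param", [], "half", (4, 3), (3, 1))                      # @0
x1 = Ins("pending_shape_change", [x0], "float", (4, 3), (3, 1))    # @1
x2 = Ins("convert", [x1], "half", (4, 3), (3, 1))                  # @2
```

Returning `@0` lets the caller redirect `@2`'s users to it, after which `@1` and `@2` are both dead.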