llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[mlir][linalg] Make RemoveOutsDependency pattern preserve destination passing style #73745

Open MaheshRavishankar opened 11 months ago

MaheshRavishankar commented 11 months ago

The pattern [RemoveOutsDependency](https://github.com/llvm/llvm-project/blob/f73844d92b36cb6801ac50ea721f4ba29b35d7a9/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp#L1802) is meant to do the following:

```mlir
%0 = <some operation>
%1 = linalg.generic {
    iterator_types = ["parallel", "parallel"],  indexing_maps = [...]}
    ins(....) outs(%0 : tensor<?x?xf32>) {...}
```

The `linalg.generic` is all parallel, so the use-def edge `%0 -> %1` is a "false" dependency: the actual values of `%0` aren't used in `%1`; the operand is only present to represent the shape of the output. So it is presumably legal to convert this to

```mlir
%0 = <some operation>
%empty = tensor.empty(...)
%1 = linalg.generic {
    iterator_types = ["parallel", "parallel"],  indexing_maps = [...]}
    ins(....) outs(%empty : tensor<?x?xf32>) {...}
```

where the dependency is now broken. This makes it easier to DCE dead producers and also helps fusion.
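Modulo MLIR specifics, the current rewrite can be sketched as a toy Python model (the `Op` class, its fields, and the driver below are illustrative assumptions for this sketch, not MLIR APIs):

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    """Toy stand-in for an operation producing one tensor result.
    Not an MLIR API; just enough structure to show the rewrite."""
    name: str
    outs: list = field(default_factory=list)  # destination (outs) operands
    all_parallel: bool = False                # all iterator_types parallel?

def remove_outs_dependency(op: Op) -> bool:
    """Current pattern: for an all-parallel op, replace each outs
    operand with a fresh tensor.empty, severing the use-def edge
    to whatever previously defined the destination."""
    if not op.all_parallel:
        return False
    op.outs = [Op("tensor.empty") for _ in op.outs]
    return True

producer = Op("some.operation")                                     # %0
generic = Op("linalg.generic", outs=[producer], all_parallel=True)  # %1
remove_outs_dependency(generic)
assert producer not in generic.outs           # false %0 -> %1 edge is gone
assert generic.outs[0].name == "tensor.empty"
```

With the edge gone, nothing in this toy IR uses `producer` anymore, which is exactly what makes it a DCE candidate.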

This pattern, though, was added before `DestinationPassingStyleOpInterface`, and one thing it does is break destination passing style. A better approach is not to replace `outs` with `tensor.empty()`, but to restrict the pattern to the following case: `%0` is not just "some operation" but "some destination passing style operation"

```mlir
%0 = <some destination passing op> init(%dest : tensor<?x?xf32>) ...
%1 = linalg.generic {
    iterator_types = ["parallel", "parallel"],  indexing_maps = [...]}
    ins(....) outs(%0 : tensor<?x?xf32>) {...}
```

where `%dest` is the operand that is "tied" to `%0`. The pattern should then convert this to

```mlir
%0 = <some destination passing op> init(%dest : tensor<?x?xf32>) ...
%1 = linalg.generic {
    iterator_types = ["parallel", "parallel"],  indexing_maps = [...]}
    ins(....) outs(%dest : tensor<?x?xf32>) {...}
```

This still breaks the false `%0 -> %1` dependency. It does introduce a `%dest -> %1` dependency, but if `%dest` is also defined by a destination-passing-style op, the pattern can apply again, propagating as far up the use-def chain as possible.
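The proposed restricted rewrite can be sketched as a toy Python model (the `Op` class and the simplified `tied_init` helper are illustrative assumptions, not MLIR APIs): reuse the producer's tied init instead of materializing a `tensor.empty`, and walk as far up the use-def chain as possible.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Op:
    """Toy stand-in for an operation producing one tensor result.
    Not an MLIR API; just enough structure to show the rewrite."""
    name: str
    outs: list = field(default_factory=list)  # destination (outs) operands
    all_parallel: bool = False                # all iterator_types parallel?
    is_dps: bool = False  # models a destination-passing-style op

def tied_init(v: Op) -> Optional[Op]:
    # The init operand tied to a DPS op's (single) result, if any.
    return v.outs[0] if v.is_dps and v.outs else None

def remove_outs_dependency_dps(op: Op) -> bool:
    """Proposed pattern: rewrite only when the destination is itself
    defined by a DPS op, replacing it with that op's tied init and
    propagating as far up the use-def chain as possible."""
    if not op.all_parallel:
        return False
    changed = False
    for i, dest in enumerate(op.outs):
        init = tied_init(dest)
        while init is not None and tied_init(init) is not None:
            init = tied_init(init)  # keep walking up the chain
        if init is not None:
            op.outs[i] = init
            changed = True
    return changed

dest = Op("block_arg")                                          # %dest
fill = Op("linalg.fill", outs=[dest], is_dps=True)              # %0
generic = Op("linalg.generic", outs=[fill], all_parallel=True)  # %1
remove_outs_dependency_dps(generic)
assert generic.outs[0] is dest   # false %0 -> %1 edge replaced by %dest
```

If the destination is not produced by a DPS op (for example, it is already a `tensor.empty` or a block argument), the pattern simply does not fire, which is the restriction being proposed.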

llvmbot commented 11 months ago

@llvm/issue-subscribers-mlir-linalg

joker-eph commented 11 months ago

> This pattern though was added before DestinationPassingStyleOpInterface. One thing the above pattern does is break destination passing style

Do you mean "this does not preserve potential bufferization hints"? Because I don't see anything "broken" with respect to "destination passing style" itself (whose definition is unrelated to bufferization in nature).

> the pattern above should convert this to
>
> ```mlir
> %0 = <some destination passing op> init(%dest : tensor<?x?xf32>) ...
> %1 = linalg.generic {
>     iterator_types = ["parallel", "parallel"],  indexing_maps = [...]}
>     ins(....) outs(%dest : tensor<?x?xf32>) {...}
> ```

This is doable, but it seems like this is a very specific pass with some very specific goals in mind (prepare-for-bufferization?): it does not seem like a generic canonicalization or optimization in itself.

MaheshRavishankar commented 11 months ago

>> This pattern though was added before DestinationPassingStyleOpInterface. One thing the above pattern does is break destination passing style
>
> Do you mean "this does not preserve potential bufferization hints"? Because I don't see anything "broken" with respect to "destination passing style" itself (whose definition is unrelated to bufferization in nature).
>
>> the pattern above should convert this to
>>
>> ```mlir
>> %0 = <some destination passing op> init(%dest : tensor<?x?xf32>) ...
>> %1 = linalg.generic {
>>     iterator_types = ["parallel", "parallel"],  indexing_maps = [...]}
>>     ins(....) outs(%dest : tensor<?x?xf32>) {...}
>> ```
>
> This is doable, but it seems like this is a very specific pass with some very specific goals in mind (prepare-for-bufferization?): it does not seem like a generic canonicalization or optimization in itself.

It isn't a canonicalization, but it is useful for breaking the false dependency between `%0` and `%1`; then `%0` can potentially be DCE-ed or fused with other operations. That is why this pattern was added in the first place. It is a very specific pattern, which is why it is run only during the Linalg elementwise fusion transformation.

MaheshRavishankar commented 11 months ago

>> This pattern though was added before DestinationPassingStyleOpInterface. One thing the above pattern does is break destination passing style
>
> Do you mean "this does not preserve potential bufferization hints"? Because I don't see anything "broken" with respect to "destination passing style" itself (whose definition is unrelated to bufferization in nature).

Yeah, it's not necessarily broken, but if you get a program which is in "destination passing style", it is probably better to try to preserve it. There are other ways these hints might not actually manifest after bufferization. There is no guarantee; it's more about best effort.

joker-eph commented 11 months ago

> if you get a program which is in "destination passing style",

What do you mean by "destination passing style"? You seem to have some very specific expectations which I don't grasp right now.

To me "destination passing style" is a property of the operation: it can't be in any other form, actually. This is fixed by the design of the operation and intrinsic to its behavior (you can't make llvm.insert_element "not in DPS").

MaheshRavishankar commented 11 months ago

>> if you get a program which is in "destination passing style",
>
> What do you mean by "destination passing style"? You seem to have some very specific expectations which I don't grasp right now.
>
> To me "destination passing style" is a property of the operation: it can't be in any other form, actually. This is fixed by the design of the operation and intrinsic to its behavior (you can't make llvm.insert_element "not in DPS").

I think it can also be a property of a program, but we don't model that today in MLIR. Take this program:

```mlir
func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.generic .... ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
```

This function is written so that `%arg2` is used as the destination for all the results. This gives bufferization all the information it needs to produce:

```mlir
func.func @dps(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>)  {
  linalg.fill .... outs(%arg2 : memref<?x?xf32>)
  linalg.generic .... ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
}
```

There is one missing piece of information in the input IR that tells bufferization that, in the tensor representation, `%arg2` can be used for the result. Upstream does this one way; IREE does it another.

This is a slight tangent, though. The pattern in question does not need to be as aggressive as it is today. It can still achieve its original objective while being more restrictive, breaking the false dependency without leaving destination passing style. There might be space for a more aggressive pattern (like the current one), but one that respects destination passing style would be a useful thing to have.

okkwon commented 11 months ago

Thanks @MaheshRavishankar for filing an issue with more explanation!

As commented (https://github.com/llvm/llvm-project/pull/73572#issuecomment-1831202864), I will try an explicit op + memref, which might not need to touch the current code and may be able to meet my needs.

joker-eph commented 11 months ago

> I think it can also be a property of a program, but we don't model that today in MLIR. ...

I actually can't figure out a clear definition of what you have in mind from what you're describing: it's not clear to me what the properties of the "program" would be.

> There might be a space for a pattern that is more aggressive (as this one), but something that respects destination passing style would be a useful thing to have.

I would keep our terminology consistent: in MLIR "destination passing style" has a clear definition that does not reflect what you're describing here. If you feel the pattern is doing something wrong, please try to describe it from first principles.

MaheshRavishankar commented 11 months ago

I think it can also a property of a program, but we dont model that today in MLIR. ...

I actually can't figure out a clear definition of what you have in mind from what you're describing: it's not clear to me what the properties of the "program" would be.

There might be a space for a pattern that is more aggressive (as this one), but something that respects destination passing style would be a useful thing to have.

I would keep our terminology consistent: in MLIR "destination passing style" has a clear definition that does not reflect what you're describing here. If you feel the pattern is doing something wrong, please try to describe it from first principles.

I find this surprising. I have explained through examples what I mean. There is a lot of literature on destination passing style, only a subset of which is defined in MLIR. Also, this is a very simple pattern that could be equally effective without breaking destination-passing-style programs. What exactly is the objection here? Are you saying that the pattern as it exists is correct and nothing needs to be done?

joker-eph commented 11 months ago

> without breaking destination passing style programs.

Still dunno what this means :)

> What exactly is the objection here? Are you saying that the pattern as exists is correct and nothing needs to be done?

From what I can tell, and without a clear definition of what a "destination passing style program" means in MLIR, then yes, there is nothing to be changed here: we can't operate based on a downstream mental model of "programs" and various people's ways of thinking about their lowering pipelines and the coupling across layers they bring with them. (DPS in MLIR is meant to be kept unrelated to bufferization right now.)

joker-eph commented 11 months ago

We can expand on your example:

```mlir
func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.generic .... ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
```

What is the property of this program that makes it qualify for DPS? Would this program also "be in DPS"?

```mlir
func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.generic .... ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
```

What about:

```mlir
func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = linalg.generic .... ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2 : tensor<?x?xf32>
}
```

Does it have to do with values having a single use? But then:

```mlir
func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %empty = tensor.empty_with_shape(%arg2) : tensor<?x?xf32>
  %empty2 = tensor.empty_with_shape(%arg2) : tensor<?x?xf32>
  %empty3 = tensor.empty_with_shape(%arg2) : tensor<?x?xf32>
  %0 = linalg.fill .... outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.fill .... outs(%empty2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = linalg.generic .... ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty3 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2 : tensor<?x?xf32>
}
```

But this one somehow does not satisfy you, and I don't know what makes one "DPS" and not the other.

MaheshRavishankar commented 11 months ago

>> without breaking destination passing style programs.
>
> Still dunno what this means :)
>
>> What exactly is the objection here? Are you saying that the pattern as it exists is correct and nothing needs to be done?
>
> From what I can tell, and without a clear definition of what a "destination passing style program" means in MLIR, then yes, there is nothing to be changed here: we can't operate based on a downstream mental model of "programs" and various people's ways of thinking about their lowering pipelines and the coupling across layers they bring with them. (DPS in MLIR is meant to be kept unrelated to bufferization right now.)

Well, just for the record, I wrote this pattern way back when :) . I am actually not considering anything w.r.t. lowering pipelines. FWIW this pattern is used as-is in IREE, and the change I am proposing here would also work as-is in IREE, but that is not the point. I'd also say IREE does not actually need this change; I think the change will make things easier for people using MLIR. (Although I hope downstream experiences, where everything is actually connected end-to-end, will inform development in MLIR.) I'd like to understand more what you mean by "DPS in MLIR is unrelated to bufferization". It very much is related in my mind; it was born out of bufferization, and DPS very much helps bufferization. I'd be curious what other aspects DPS affects.

I am happy to chat over video if a high bandwidth connection helps.

MaheshRavishankar commented 11 months ago

> We can expand on your example:
>
> ```mlir
> func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
>   %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   %1 = linalg.generic .... ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   return %1 : tensor<?x?xf32>
> }
> ```
>
> What is the property of this program that makes it qualify for DPS? Would this program also "be in DPS"?
>
> ```mlir
> func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
>   %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   %1 = linalg.generic .... ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   return %1 : tensor<?x?xf32>
> }
> ```
>
> What about:
>
> ```mlir
> func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
>   %0 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   %1 = linalg.fill .... outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   %2 = linalg.generic .... ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   return %2 : tensor<?x?xf32>
> }
> ```
>
> Does it have to do with values having a single use? But then:
>
> ```mlir
> func.func @dps(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
>   %empty = tensor.empty_with_shape(%arg2) : tensor<?x?xf32>
>   %empty2 = tensor.empty_with_shape(%arg2) : tensor<?x?xf32>
>   %empty3 = tensor.empty_with_shape(%arg2) : tensor<?x?xf32>
>   %0 = linalg.fill .... outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
>   %1 = linalg.fill .... outs(%empty2 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   %2 = linalg.generic .... ins(%0, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%empty3 : tensor<?x?xf32>) -> tensor<?x?xf32>
>   return %2 : tensor<?x?xf32>
> }
> ```
>
> But this one somehow does not satisfy you, and I don't know what makes one "DPS" and not the other.

Might be going off-track here; this is not a simple topic, and it's kind of a different point IMO. The full implications of what it means to be DPS are probably not what needs to be considered here. I am trying to say this pattern can be less aggressive and still be as effective, without introducing unnecessary `tensor.empty` ops. If I rephrase the problem as "avoid creating tensor.empty when unnecessary", will that satisfy you?

joker-eph commented 11 months ago

What does "when unnecessary" mean?

If I take your examples, you acknowledge that this is enabling DCE, CSE, etc.: we're getting close to a canonicalization!