iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[RFC] Supporting Armv9 Scalable Matrix Extension (SME) Streaming SVE (SSVE) mode in MLIR / IREE #13556

Closed c-rhodes closed 1 year ago

c-rhodes commented 1 year ago

The Armv9 Scalable Matrix Extension (SME) defines a new "Streaming SVE" (SSVE) execution mode [1].

The purpose of this RFC is to evaluate what is required to support this mode in MLIR / IREE and to propose next steps. The focus of this RFC is SSVE only, not full SME support, but this is an important first step towards it.

SSVE is controlled by a processor state bit called PSTATE.SM (see section B1.1 in [1]). From [2]:

At any given point in time, the processor is either in streaming mode (PSTATE.SM==1) or in non-streaming mode (PSTATE.SM==0), also referred to as "normal" mode. There is an instruction called SMSTART to enter streaming mode and an instruction called SMSTOP to return to non-streaming mode.

From [3]:

When changing PSTATE.SM the execution of FP/vector operations may be transferred to another processing element. This has three important implications:

  • The runtime SVE vector length may change.
  • The contents of FP/AdvSIMD/SVE registers are zeroed.
  • The set of allowable instructions changes.

To support this feature there are two key problems to address:

  • Managing PSTATE.SM, i.e. entering and leaving streaming mode.
  • Generating scalable (SVE) vector code that is legal in streaming mode.

This RFC is concerned with the former. The latter entails support for scalable vector enablement in MLIR / IREE, which is outside the scope of this RFC, but my colleague Andrzej Warzynski from Arm and Diego Caballero from Google posted an RFC on scalable vectorisation in Linalg last week that's relevant to this and worth a read.

In the LLVM backend PSTATE.SM is managed at the function boundary with function and callsite attributes. These were added as part of enabling the AArch64 SME Arm C Language Extensions (ACLE) [4]. The LLVM AArch64 SME docs [3] go into more detail on this.

To support SSVE in MLIR / IREE we can leverage these existing attributes to manage PSTATE.SM. [5] provides a list of attributes; for the purposes of this RFC only the SSVE attributes are relevant. These are:

  • aarch64_pstate_sm_enabled
  • aarch64_pstate_sm_body
  • aarch64_pstate_sm_compatible

The next section goes into detail on each of them and provides examples.

aarch64_pstate_sm_enabled

Calls to functions with this attribute from code running in non-streaming mode are wrapped with smstart sm / smstop sm. For example, take the following LLVM IR:

define void @streaming_callee() #0 {
  ret void
}

define void @normal_caller() {
  call void @streaming_callee()
  ret void
}

attributes #0 = { "aarch64_pstate_sm_enabled" }

Compiled with:

  llc -mtriple=aarch64 -mattr=+sve,+sme

Produces the following asm:

streaming_callee:                       // @streaming_callee
        ret
normal_caller:                          // @normal_caller
        stp     d15, d14, [sp, #-80]!           // 16-byte Folded Spill
        stp     d13, d12, [sp, #16]             // 16-byte Folded Spill
        stp     d11, d10, [sp, #32]             // 16-byte Folded Spill
        stp     d9, d8, [sp, #48]               // 16-byte Folded Spill
        str     x30, [sp, #64]                  // 8-byte Folded Spill
        smstart sm
        bl      streaming_callee
        smstop  sm
        ldp     d9, d8, [sp, #48]               // 16-byte Folded Reload
        ldp     d11, d10, [sp, #32]             // 16-byte Folded Reload
        ldp     d13, d12, [sp, #16]             // 16-byte Folded Reload
        ldr     x30, [sp, #64]                  // 8-byte Folded Reload
        ldp     d15, d14, [sp], #80             // 16-byte Folded Reload
        ret

Streaming mode is enabled before the call to the streaming function and disabled after.

aarch64_pstate_sm_body

The backend will insert smstart sm / smstop sm into the prologue/epilogue for functions with this attribute.

For example, using the IR from the previous example with this attribute:

define void @streaming_callee() #0 {
  ret void
}

define void @normal_caller() {
  call void @streaming_callee()
  ret void
}

attributes #0 = { "aarch64_pstate_sm_body" }

Produces the following asm:

streaming_callee:                       // @streaming_callee
        stp     d15, d14, [sp, #-64]!           // 16-byte Folded Spill
        stp     d13, d12, [sp, #16]             // 16-byte Folded Spill
        stp     d11, d10, [sp, #32]             // 16-byte Folded Spill
        stp     d9, d8, [sp, #48]               // 16-byte Folded Spill
        smstart sm
        smstop  sm
        ldp     d9, d8, [sp, #48]               // 16-byte Folded Reload
        ldp     d11, d10, [sp, #32]             // 16-byte Folded Reload
        ldp     d13, d12, [sp, #16]             // 16-byte Folded Reload
        ldp     d15, d14, [sp], #64             // 16-byte Folded Reload
        ret
normal_caller:                          // @normal_caller
        str     x30, [sp, #-16]!                // 8-byte Folded Spill
        bl      streaming_callee
        ldr     x30, [sp], #16                  // 8-byte Folded Reload
        ret

Notice smstart sm / smstop sm are emitted inside the streaming function, rather than around the call as with the previous attribute.

The ACLE [6] provides more detail on this:

This choice is internal to the function definition. It is not visible to callers and so it can be changed without affecting the function’s binary interface. (In other words, it can be changed without requiring all callers to be recompiled.) ... This approach can be useful when implementing existing APIs, including when overriding virtual functions. It allows the use of SME to be an internal implementation detail.

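Because this attribute is internal and not visible to callers, a locally streaming function still has a normal (non-streaming) interface. The compiler must therefore disable streaming mode, if it is enabled, before calling a locally streaming function, and re-enable it afterwards.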

For example, the following IR:

declare void @locally_streaming_callee() #0

define void @locally_streaming_caller() #0 {
  call void @locally_streaming_callee()
  ret void
}

attributes #0 = { "aarch64_pstate_sm_body" }

Produces this asm:

locally_streaming_caller:               // @locally_streaming_caller
        stp     d15, d14, [sp, #-80]!           // 16-byte Folded Spill
        stp     d13, d12, [sp, #16]             // 16-byte Folded Spill
        stp     d11, d10, [sp, #32]             // 16-byte Folded Spill
        stp     d9, d8, [sp, #48]               // 16-byte Folded Spill
        str     x30, [sp, #64]                  // 8-byte Folded Spill
        smstart sm
        smstop  sm
        bl      locally_streaming_callee
        smstart sm
        smstop  sm
        ldp     d9, d8, [sp, #48]               // 16-byte Folded Reload
        ldp     d11, d10, [sp, #32]             // 16-byte Folded Reload
        ldp     d13, d12, [sp, #16]             // 16-byte Folded Reload
        ldr     x30, [sp, #64]                  // 8-byte Folded Reload
        ldp     d15, d14, [sp], #80             // 16-byte Folded Reload
        ret

The compiler has to emit redundant smstart/smstop pairs, which in turn cause expensive spills/fills: the CPU zeroes the FP/vector registers when changing PSTATE.SM, so the callee-saved d8-d15 registers have to be saved and restored around the switches.
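To make the call-site and prologue/epilogue rules above concrete, here is a toy Python model that reproduces the mode switches in the three listings so far. It is a sketch of the observable behaviour only, not LLVM's implementation: spills, fills and ret are omitted, aarch64_pstate_sm_compatible is not modelled, and the names Func and emit are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

SM_ENABLED = "aarch64_pstate_sm_enabled"
SM_BODY = "aarch64_pstate_sm_body"


@dataclass
class Func:
    name: str
    attr: Optional[str] = None  # SM_ENABLED, SM_BODY, or None for a normal function
    calls: List[str] = field(default_factory=list)  # callee names, in call order


def interface_is_streaming(f: Func) -> bool:
    # Only aarch64_pstate_sm_enabled makes streaming mode part of the interface.
    return f.attr == SM_ENABLED


def body_is_streaming(f: Func) -> bool:
    # Both attributes make the function body execute in streaming mode.
    return f.attr in (SM_ENABLED, SM_BODY)


def emit(f: Func, module: Dict[str, Func]) -> List[str]:
    """Mode switches and calls emitted for f; spills, fills and ret are omitted."""
    out = []
    if f.attr == SM_BODY:
        out.append("smstart sm")  # prologue of a locally streaming function
    for callee_name in f.calls:
        mismatch = body_is_streaming(f) != interface_is_streaming(module[callee_name])
        if mismatch:
            # Switch into the callee's interface mode before the call...
            out.append("smstop sm" if body_is_streaming(f) else "smstart sm")
        out.append(f"bl {callee_name}")
        if mismatch:
            # ...and back to the caller's body mode afterwards.
            out.append("smstart sm" if body_is_streaming(f) else "smstop sm")
    if f.attr == SM_BODY:
        out.append("smstop sm")  # epilogue of a locally streaming function
    return out


if __name__ == "__main__":
    module = {
        "locally_streaming_callee": Func("locally_streaming_callee", SM_BODY),
        "locally_streaming_caller": Func(
            "locally_streaming_caller", SM_BODY, ["locally_streaming_callee"]
        ),
    }
    print(emit(module["locally_streaming_caller"], module))
    # ['smstart sm', 'smstop sm', 'bl locally_streaming_callee', 'smstart sm', 'smstop sm']
```

The printed sequence matches the smstart/smstop/bl lines of locally_streaming_caller above; the back-to-back smstart sm / smstop sm pairs are exactly the redundant switches the model makes visible.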

aarch64_pstate_sm_compatible

Functions with this attribute stick to the streaming subset of SVE and are callable from both normal and streaming mode. In practice, this means the function is restricted to the set of instructions that are legal in both modes. This is useful for compatibility in things like vector versions of math.h routines such as tanh.
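The "legal in both modes" rule amounts to a set intersection. The sketch below uses made-up instruction categories rather than real instruction lists; the only architectural facts it leans on are that most Advanced SIMD instructions are illegal in streaming mode (absent FEAT_SME_FA64) and that SME outer products require streaming mode.

```python
# Hypothetical instruction categories for illustration only; real legality is
# specified per instruction in the Arm architecture (see [1]). "advsimd" stands
# for most Advanced SIMD instructions and assumes FEAT_SME_FA64 is absent.
LEGAL_NON_STREAMING = frozenset({"sve_common", "advsimd"})
LEGAL_STREAMING = frozenset({"sve_common", "sme_outer_product"})

# A streaming-compatible function must be legal whichever mode it is called in,
# so it may only use the intersection of the two sets.
LEGAL_COMPATIBLE = LEGAL_NON_STREAMING & LEGAL_STREAMING


def illegal_in_compatible(used_categories):
    """Categories that would make a function invalid as streaming-compatible."""
    return sorted(set(used_categories) - LEGAL_COMPATIBLE)


print(illegal_in_compatible(["sve_common", "advsimd"]))  # ['advsimd']
```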

Initial plan for SSVE support in MLIR/IREE

Streaming mode can be targeted from MLIR at the function boundary by specifying these attributes via the passthrough mechanism [7]. This works today with no changes.

Example MLIR:

// A streaming-mode function.
func.func private @streaming_callee() attributes {passthrough = ["aarch64_pstate_sm_enabled"]}
func.func private @normal_caller() {
  func.call @streaming_callee() : () -> ()
  return
}

Compile:

mlir-opt --convert-func-to-llvm streaming_sve.mlir | mlir-translate --mlir-to-llvmir | llc -mtriple=aarch64 -mattr=+sve,+sme

Produces the following asm:

normal_caller:
        stp     d15, d14, [sp, #-80]!           // 16-byte Folded Spill
        stp     d13, d12, [sp, #16]             // 16-byte Folded Spill
        stp     d11, d10, [sp, #32]             // 16-byte Folded Spill
        stp     d9, d8, [sp, #48]               // 16-byte Folded Spill
        str     x30, [sp, #64]                  // 8-byte Folded Spill
        smstart sm
        bl      streaming_callee
        smstop  sm
        ldp     d9, d8, [sp, #48]               // 16-byte Folded Reload
        ldp     d11, d10, [sp, #32]             // 16-byte Folded Reload
        ldp     d13, d12, [sp, #16]             // 16-byte Folded Reload
        ldr     x30, [sp, #64]                  // 8-byte Folded Reload
        ldp     d15, d14, [sp], #80             // 16-byte Folded Reload
        ret

By leveraging the existing attributes, the backend becomes responsible for ensuring that the code generator only produces instructions that are legal in streaming mode.

To enable streaming mode we need a mechanism to add these attributes to functions. In IREE the attributes could initially be applied to all dispatch functions, and later be applied based on a more fine-grained heuristic. This would be predicated on --iree-llvm-target-cpu-features=+sve,+sme and could be done during LLVM lowering [8].
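A minimal sketch of the "apply to all dispatch functions, predicated on +sve,+sme" idea. Dispatch functions are modelled as plain dicts with a passthrough list, standing in for the MLIR functions; the helper names are hypothetical, and defaulting to aarch64_pstate_sm_body anticipates the choice argued for later in this RFC.

```python
from typing import Dict, List

# Assumption: the locally streaming attribute, which this RFC later argues for.
DEFAULT_ATTR = "aarch64_pstate_sm_body"


def parse_cpu_features(features: str) -> set:
    """Parse an LLVM-style feature string such as "+sve,+sme"; later entries win."""
    enabled = set()
    for item in filter(None, (s.strip() for s in features.split(","))):
        name = item[1:]
        if item[0] == "+":
            enabled.add(name)
        elif item[0] == "-":
            enabled.discard(name)
    return enabled


def annotate_dispatches(
    funcs: List[Dict], cpu_features: str, attr: str = DEFAULT_ATTR
) -> List[Dict]:
    """Copies of funcs with attr added to each passthrough list if SVE and SME are on."""
    if not {"sve", "sme"} <= parse_cpu_features(cpu_features):
        return [dict(f) for f in funcs]
    annotated = []
    for f in funcs:
        passthrough = list(f.get("passthrough", []))
        if attr not in passthrough:
            passthrough.append(attr)
        annotated.append({**f, "passthrough": passthrough})
    return annotated
```

Returning copies rather than mutating in place keeps the sketch easy to test; a real pass would of course update the functions in the module directly.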

Another option is to implement this as a pass in either MLIR or IREE. In IREE the iree-llvmcpu-lower-executable-target pass adds a translation_info attribute to each dispatch function that describes the lowering pipeline to use. The pipelines are defined in [9] and are selected based on the ops in the dispatch. For example, a dispatch with a Linalg convolution op will use the DispatchLoweringPassPipeline::CPUConvTileAndDecomposeExpert pipeline. This heuristic could be leveraged to selectively enable SSVE by integrating the MLIR pass into the relevant lowering configuration pipelines [10].
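The pipeline-based heuristic could start as a simple allow-list keyed on the pipeline recorded in translation_info. In this sketch only CPUConvTileAndDecomposeExpert comes from the text above; the dict layout, the function name and the CPUDefault entry in the example are illustrative assumptions.

```python
from typing import Dict, List

# Hypothetical allow-list of lowering pipelines for which SSVE is enabled.
# CPUConvTileAndDecomposeExpert is the example pipeline named above; which
# pipelines belong here is exactly what the heuristic would decide.
SSVE_PIPELINES = frozenset({"CPUConvTileAndDecomposeExpert"})


def select_ssve_dispatches(dispatches: List[Dict], pipelines=SSVE_PIPELINES) -> List[str]:
    """Names of dispatch functions whose translation_info pipeline is allow-listed."""
    return [
        d["name"]
        for d in dispatches
        if d.get("translation_info", {}).get("pipeline") in pipelines
    ]


dispatches = [
    {"name": "conv_dispatch", "translation_info": {"pipeline": "CPUConvTileAndDecomposeExpert"}},
    {"name": "other_dispatch", "translation_info": {"pipeline": "CPUDefault"}},
]
print(select_ssve_dispatches(dispatches))  # ['conv_dispatch']
```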

One other thing to consider is the distribution of workgroups in dispatch regions, for SME it might be best for the region affinity and distribution to be on a single thread or serialized. The --iree-codegen-llvm-disable-distribution flag could be used for this and enabled by default for SSVE.

Which attribute to use?

The ABI [11] states:

It is the caller's responsibility to ensure that PSTATE.SM has a valid value on entry to a callee.

In the context of IREE, where the runtime calls dispatch functions, this suggests it is the responsibility of the runtime. However, the runtime doesn't know enough about individual dispatches to emit these instructions. A pass that adds the aarch64_pstate_sm_enabled attribute to dispatch functions would effectively change the functions' ABI, making the caller (the IREE runtime) responsible for managing PSTATE.SM on entry/exit. Given this limitation of the runtime, this attribute cannot currently be used.

Streaming mode must instead be managed on function entry/exit, and the aarch64_pstate_sm_body attribute can be used for this. There may be scope for adding backend intrinsics for smstart / smstop in the future that provide finer granularity (i.e. not restricted to the function boundary) for managing PSTATE.SM. However, it is worth emphasising that one of the big advantages of using the existing attributes is that the backend enforces the constraint to only use instructions that are legal in streaming mode. This will help to address the second key problem mentioned earlier.

Another consideration is that the runtime SVE vector length may differ in streaming mode. This is mentioned in the list of restrictions for these attributes [12]:

It is undefined behaviour to pass or return (pointers to) scalable vector objects to/from functions which may use a different SVE vector length. This includes functions with a non-streaming interface, but marked with aarch64_pstate_sm_body.

[13] contains further info on this:

However, it is unlikely for this to happen without user intervention, because arm_locally_streaming functions cannot take or return vector-length-dependent values.

The pass should enforce the same restriction, i.e. it should not mark functions that take or return scalable vectors.
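A sketch of that check on textual type signatures. It recognises LLVM IR scalable vectors (<vscale x 4 x i32>) and MLIR scalable vectors (vector<[4]xf32>, where brackets mark scalable dimensions); representing signatures as lists of type strings is an assumption made for illustration, and pointers to scalable vectors are not caught because opaque pointers hide the pointee type.

```python
import re
from typing import Sequence

# LLVM IR scalable vectors look like <vscale x 4 x i32>; MLIR scalable vectors
# mark scalable dimensions with brackets, e.g. vector<[4]xf32> or vector<4x[8]xi8>.
_LLVM_SCALABLE = re.compile(r"<\s*vscale\s+x\s+\d+\s+x\s+[^>]+>")
_MLIR_SCALABLE = re.compile(r"vector<[^>]*\[\d+\]")


def is_scalable_vector_type(type_str: str) -> bool:
    return bool(_LLVM_SCALABLE.search(type_str) or _MLIR_SCALABLE.search(type_str))


def can_mark_locally_streaming(
    param_types: Sequence[str], result_types: Sequence[str]
) -> bool:
    """True if no argument or result is a scalable vector.

    With opaque pointers the pointee type is not part of the signature, so
    pointers to scalable vectors are NOT detected by this check.
    """
    return not any(is_scalable_vector_type(t) for t in [*param_types, *result_types])
```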

The streaming-compatible interface doesn't seem particularly useful, given that streaming mode will be managed by the MLIR/IREE compiler rather than by the user in C code. We can make full use of the features in either mode.

Proposed next steps

I've shared a patch with a pass in IREE that adds the aarch64_pstate_sm_body attribute to functions. The pass is enabled for AArch64 when SVE(2) and SME are enabled for the following lowering configurations:

These configurations were chosen simply because they're used in one of our pipelines.

In the patch I shared I've added the pass to IREE rather than MLIR because it uses the internal (locally streaming) backend attribute to enable SSVE, which avoids having to add PSTATE.SM management to the IREE runtime.

Any questions or suggestions would be appreciated. Thanks for reading!

Links

[1] https://developer.arm.com/documentation/ddi0616/latest
[2] https://arm-software.github.io/acle/main/acle.html#introduction-to-streaming-and-non-streaming-mode
[3] https://llvm.org/docs/AArch64SME.html#handling-pstate-sm
[4] https://arm-software.github.io/acle/main/acle.html#sme-language-extensions-and-intrinsics
[5] https://llvm.org/docs/AArch64SME.html#introduction
[6] https://arm-software.github.io/acle/main/acle.html#changing-streaming-mode-locally
[7] https://mlir.llvm.org/docs/Dialects/LLVM/#attribute-pass-through
[8] https://github.com/openxla/iree/blob/4952c50057a78ef86fc92c17742b0e2674df5964/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp#L231-L245
[9] https://github.com/openxla/iree/blob/4952c50057a78ef86fc92c17742b0e2674df5964/compiler/src/iree/compiler/Codegen/Dialect/LoweringConfig.td
[10] https://github.com/openxla/iree/blob/4952c50057a78ef86fc92c17742b0e2674df5964/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerExecutableTarget.cpp#L187
[11] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
[12] https://llvm.org/docs/AArch64SME.html#restrictions-on-attributes
[13] https://llvm.org/docs/AArch64SME.html#functions-with-attribute-arm-locally-streaming

dcaballe commented 1 year ago

Thanks a lot for the RFC! I just wanted to post a quick comment to let you know that I'm processing the details and this hasn't fallen through the cracks :) I wonder, though, if this is something that we should move to the MLIR discourse. I have the impression that we want to have a single way to represent SSME in MLIR, beyond IREE, and then take that representation and implement it within IREE. WDYT?

c-rhodes commented 1 year ago

Initially that was the plan, but I was focused on enabling SSVE in IREE and that influenced the design; I wasn't sure other users of MLIR had the same constraints. To represent this in MLIR I would propose adding a pass that adds the aarch64_pstate_sm_enabled attribute to all functions, and later come up with a heuristic that selectively enables SSVE. I'd be happy to remove the IREE-specific parts from this RFC and post it on LLVM Discourse.

c-rhodes commented 1 year ago

Posted on LLVM Discourse: https://discourse.llvm.org/t/rfc-supporting-armv9-scalable-matrix-extension-sme-streaming-sve-ssve-mode-in-mlir/70678

banach-space commented 1 year ago

That's a very detailed and well-crafted RFC, thank you @c-rhodes!

To represent this in MLIR I would propose adding a pass that adds the aarch64_pstate_sm_enabled attribute to all functions

Wouldn't aarch64_pstate_sm_body work equally well in MLIR? I know that in IREE that's effectively the only option (aarch64_pstate_sm_enabled changes the ABI, and we want to avoid that). Basically, aarch64_pstate_sm_body feels like the "safe option" for now, regardless of whether we focus on IREE or MLIR. And we can always revisit later if this choice proves too limiting. So, unless I am missing something, we could land this in MLIR as is :)

dcaballe commented 1 year ago

Following up to my reply to the MLIR post…

Another option is to implement this as a pass in either MLIR or IREE

I think that at least the utility that annotates the functions with the attributes should be implemented and tested in MLIR. Also the lowering to LLVM. That would help the community align with the proposed SSME representation and build on top of that without reinventing the wheel. We don’t have to restrict it to IREE’s specific use case. The utility could take a function and a “streaming mode” and return the annotated function which would eventually be lowered to LLVM.
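One possible shape for such a utility, sketched in Python with functions modelled as dicts carrying a passthrough list. The StreamingMode enum and annotate_function are hypothetical names; the attribute strings are the three listed in [5].

```python
from enum import Enum
from typing import Dict


class StreamingMode(Enum):
    # Values are the LLVM function attributes from [5]; DISABLED adds none.
    DISABLED = None
    STREAMING = "aarch64_pstate_sm_enabled"
    LOCALLY_STREAMING = "aarch64_pstate_sm_body"
    STREAMING_COMPATIBLE = "aarch64_pstate_sm_compatible"


_SSVE_ATTRS = {m.value for m in StreamingMode if m.value is not None}


def annotate_function(func: Dict, mode: StreamingMode) -> Dict:
    """Return a copy of func whose passthrough list requests the given mode.

    Any SSVE attribute already present is dropped first, so the result
    carries at most one of them (a design choice of this sketch).
    """
    passthrough = [a for a in func.get("passthrough", []) if a not in _SSVE_ATTRS]
    if mode.value is not None:
        passthrough.append(mode.value)
    return {**func, "passthrough": passthrough}
```

Keeping the mode as an explicit parameter leaves the policy (which functions get which mode) to the caller, e.g. an IREE-specific heuristic, while the annotation and its lowering can be tested in MLIR.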

This heuristic could be leveraged to selectively enable SSVE by integrating the MLIR pass into the relevant lowering configuration pipelines

+1. If you initially don’t feel very confident about these heuristics, this is something we can incrementally build within IREE and once it reaches a certain level of maturity we could reconsider if it has value for MLIR.

One other thing to consider is the distribution of workgroups in dispatch regions, for SME it might be best for the region affinity and distribution to be on a single thread or serialized. The --iree-codegen-llvm-disable-distribution flag could be used for this and enabled by default for SSVE.

I think that flag is there mostly for testing/debugging purposes. However, we have target specific distribution (and vectorization/unrolling) configurations so it would be a matter of setting the tile sizes for distribution to zero for SSVE. Similar to what the flag does, actually.
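A sketch of "setting the tile sizes for distribution to zero", assuming tile sizes are a list of per-level lists whose first level is workgroup distribution (as in the LLVMCPU lowering configurations) and that a zero tile size means "do not tile this dimension". The function names are hypothetical.

```python
from typing import List

TileSizes = List[List[int]]


def disable_distribution(tile_sizes: TileSizes) -> TileSizes:
    """Copy of tile_sizes with the first (distribution) level set to zero."""
    if not tile_sizes:
        return []
    return [[0] * len(tile_sizes[0])] + [list(level) for level in tile_sizes[1:]]


def configure_tile_sizes(tile_sizes: TileSizes, ssve_enabled: bool) -> TileSizes:
    # Only the distribution level changes; vectorization levels are untouched.
    if ssve_enabled:
        return disable_distribution(tile_sizes)
    return [list(level) for level in tile_sizes]


print(configure_tile_sizes([[64, 64, 0], [8, 32, 0], [0, 0, 16]], ssve_enabled=True))
# [[0, 0, 0], [8, 32, 0], [0, 0, 16]]
```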

In general the approach makes sense to me! Hopefully that helps! Let us know if we can help with anything else!

c-rhodes commented 1 year ago

Thanks Andrzej. Yes, it would work in MLIR, but there are consequences to PSTATE.SM not being part of the interface, as highlighted above. The aarch64_pstate_sm_enabled attribute is more suitable for calls between streaming functions. Here is an example comparing the codegen of each attribute for this case:

aarch64_pstate_sm_enabled

define void @streaming_func1() #0 {
  ret void
}

define void @streaming_func2() #0 {
  call void @streaming_func1()
  ret void
}

attributes #0 = { "aarch64_pstate_sm_enabled" }

compile

; llc -mtriple=aarch64-linux-gnu -mattr=+sme

streaming_func1:                        // @streaming_func1
        ret
streaming_func2:                        // @streaming_func2
        str     x30, [sp, #-16]!                // 8-byte Folded Spill
        bl      streaming_func1
        ldr     x30, [sp], #16                  // 8-byte Folded Reload
        ret

aarch64_pstate_sm_body

define void @streaming_func1() #0 {
  ret void
}

define void @streaming_func2() #0 {
  call void @streaming_func1()
  ret void
}

attributes #0 = { "aarch64_pstate_sm_body" }

compile

; llc -mtriple=aarch64-linux-gnu -mattr=+sme

streaming_func1:                        // @streaming_func1
        stp     d15, d14, [sp, #-64]!           // 16-byte Folded Spill
        stp     d13, d12, [sp, #16]             // 16-byte Folded Spill
        stp     d11, d10, [sp, #32]             // 16-byte Folded Spill
        stp     d9, d8, [sp, #48]               // 16-byte Folded Spill
        smstart sm
        smstop  sm
        ldp     d9, d8, [sp, #48]               // 16-byte Folded Reload
        ldp     d11, d10, [sp, #32]             // 16-byte Folded Reload
        ldp     d13, d12, [sp, #16]             // 16-byte Folded Reload
        ldp     d15, d14, [sp], #64             // 16-byte Folded Reload
        ret
streaming_func2:                        // @streaming_func2
        stp     d15, d14, [sp, #-80]!           // 16-byte Folded Spill
        stp     d13, d12, [sp, #16]             // 16-byte Folded Spill
        stp     d11, d10, [sp, #32]             // 16-byte Folded Spill
        stp     d9, d8, [sp, #48]               // 16-byte Folded Spill
        str     x30, [sp, #64]                  // 8-byte Folded Spill
        smstart sm
        smstop  sm
        bl      streaming_func1
        smstart sm
        smstop  sm
        ldp     d9, d8, [sp, #48]               // 16-byte Folded Reload
        ldp     d11, d10, [sp, #32]             // 16-byte Folded Reload
        ldp     d13, d12, [sp, #16]             // 16-byte Folded Reload
        ldr     x30, [sp, #64]                  // 8-byte Folded Reload
        ldp     d15, d14, [sp], #80             // 16-byte Folded Reload
        ret

In IREE there's no option but to use the internal attribute but I don't know if that's the case for MLIR.

c-rhodes commented 1 year ago

Following up to my reply to the MLIR post…

Another option is to implement this as a pass in either MLIR or IREE

I think that at least the utility that annotates the functions with the attributes should be implemented and tested in MLIR. Also the lowering to LLVM. That would help the community align with the proposed SSME representation and build on top of that without reinventing the wheel. We don’t have to restrict it to IREE’s specific use case. The utility could take a function and a “streaming mode” and return the annotated function which would eventually be lowered to LLVM.

Thanks for the suggestion. I'll look at moving the pass to MLIR with an option to control which attribute is used. There's already a test covering the lowering to LLVM for the aarch64_pstate_sm_enabled attribute, which I added recently: https://github.com/llvm/llvm-project/commit/c8d1388e6c8bd57299d5801f170719218f735c4c

However, it has since been made clear that the passthrough mechanism is intended for prototyping, so this will need updating.

This heuristic could be leveraged to selectively enable SSVE by integrating the MLIR pass into the relevant lowering configuration pipelines

+1. If you initially don’t feel very confident about these heuristics, this is something we can incrementally build within IREE and once it reaches certain level of maturity we could reconsider if it has value for MLIR.

I agree, this is in line with our thinking.

One other thing to consider is the distribution of workgroups in dispatch regions, for SME it might be best for the region affinity and distribution to be on a single thread or serialized. The --iree-codegen-llvm-disable-distribution flag could be used for this and enabled by default for SSVE.

I think that flag is there mostly for testing/debugging purposes. However, we have target specific distribution (and vectorization/unrolling) configurations so it would be a matter of setting the tile sizes for distribution to zero for SSVE. Similar to what the flag does, actually.

Thanks for pointing that out. I don't think this is an issue for us yet, but it will be useful when it is.

In general the approach makes sense to me! Hopefully that helps! Let us know if we can help with anything else!

Thanks again Diego!

banach-space commented 1 year ago

Yes it would work in MLIR but there are consequences to PSTATE.SM not being part of the interface highlighted above. The aarch64_pstate_sm_enabled attribute is more suitable for calls between streaming functions, here an example comparing the codegen for each attribute for this:

Thank you for your thorough explanation and the CE link - cool to see this working upstream!

I'll look at moving the pass to MLIR with an option to control which attribute is used.

+1 to adding an option. In principle, it sounds like aarch64_pstate_sm_body is both the easier and pretty much always correct option. But for completeness, we will support both:

  • aarch64_pstate_sm_body
  • aarch64_pstate_sm_enabled

We are yet to understand the most efficient way of using the latter, but that's for later (Rome was not built in a day!).

Thanks for driving this!

c-rhodes commented 1 year ago

Closing this now that #13558 has been merged.