intel / ideep

Intel® Optimization for Chainer*, a Chainer module providing numpy like API and DNN acceleration using MKL-DNN.
MIT License

cpu: aarch64: Enable matmul bf16f32 format desc #343

Open aditew01 opened 3 weeks ago

aditew01 commented 3 weeks ago

This allows generating a descriptor for matmul with the format src:bf16; wei:bf16; dst:f32, instead of reordering dst to bf16 and back to f32. These kernels are used directly by cpu_flash_attention (SDPA).

aditew01 commented 3 weeks ago

cc: @yanbing-j, could you please help with the review?

yanbing-j commented 2 weeks ago

Request @jgong5 for review.

aditew01 commented 2 weeks ago

> Can we override things from the caller side?

Maybe we can, but if it's only for this specific case, that implies two sets of logic: one for bf16-f32 and one for f16 (or the other cases that remain).
Unless we want to refactor the code (we can open an issue and look at the problem more broadly), isn't it better to keep the specific logic in the do_prepare call?

jgong5 commented 1 week ago

I believe I have already raised my main concern: with this change, the data type semantics become inconsistent between aarch64 and x86. I've listed the differences in the tables below.

aarch64:

| Source Type (src) | Destination Type (dst) | Destination Data Type (dst_data_type) |
| --- | --- | --- |
| bf16 | fp32 | fp32 |
| bf16 | not fp32 | bf16 |
| fp16 | any | fp16 |
| other | any | fp32 |

x86:

| Source Type (src) | Destination Type (dst) | Destination Data Type (dst_data_type) |
| --- | --- | --- |
| bf16 | any | bf16 |
| fp16 | any | fp16 |
| other | any | fp32 |
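
To make the divergence concrete, the two tables can be encoded as a pair of small functions. This is only a sketch: `DType`, `dst_data_type_aarch64`, `dst_data_type_x86`, `mappings_agree`, and `check_rows` are hypothetical names rather than ideep identifiers, and the code assumes C++14 or later for the `constexpr` functions.

```cpp
#include <cassert>

// Hypothetical stand-in for the data type enum; not the ideep type.
enum class DType { bf16, fp16, fp32, other };

// aarch64 mapping as listed in the first table (the behavior proposed in this PR).
constexpr DType dst_data_type_aarch64(DType src, DType dst) {
  if (src == DType::bf16) return dst == DType::fp32 ? DType::fp32 : DType::bf16;
  if (src == DType::fp16) return DType::fp16;
  return DType::fp32;
}

// x86 mapping as listed in the second table: the destination type is ignored.
constexpr DType dst_data_type_x86(DType src, DType /*dst*/) {
  if (src == DType::bf16) return DType::bf16;
  if (src == DType::fp16) return DType::fp16;
  return DType::fp32;
}

// True when both architectures choose the same dst_data_type.
constexpr bool mappings_agree(DType src, DType dst) {
  return dst_data_type_aarch64(src, dst) == dst_data_type_x86(src, dst);
}

// The tables differ only for a bf16 source with an fp32 destination.
static_assert(!mappings_agree(DType::bf16, DType::fp32), "bf16 -> fp32 diverges");
static_assert(mappings_agree(DType::bf16, DType::bf16), "bf16 -> bf16 agrees");

// Runtime spot check of individual table rows.
inline void check_rows() {
  assert(dst_data_type_aarch64(DType::bf16, DType::fp32) == DType::fp32);
  assert(dst_data_type_x86(DType::bf16, DType::fp32) == DType::bf16);
  assert(dst_data_type_aarch64(DType::other, DType::bf16) == DType::fp32);
}
```

Under these assumptions, the single row where the two mappings disagree is the bf16 source with an fp32 destination, which is exactly the case this PR adds.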

How can we make them aligned?

aditew01 commented 1 week ago

Apologies if this is repetitive. That will be tricky to test as well, right? This is a very specific change, and if we were to align x86 in a similar fashion, we may lose performance if no specific kernel for that format is available. If the goal is to align the semantics, I'm not sure where we can do that. Even if we push this code higher (into the caller), e.g. https://github.com/intel/ideep/blob/c54a3ede72ae05b483ee81353e94227c3057424d/include/ideep/operators/matmul.hpp#L289, the logic will be similar, right? Or do you think that's a better place for the change design-wise? cc: @jgong5

jgong5 commented 1 week ago

Frankly speaking, I don't quite understand why the data type semantics have to differ across CPU architectures. This is not just a question of runtime efficiency but also of the precision and accuracy that are visible to users. Can you explain why we have to make things different here?

aditew01 commented 1 week ago

The scaled-dot-product-attention op implemented in PyTorch calls cpublas::gemm. For dtype::bf16, the gemm operator takes bf16 input matrices and returns the result in fp32. PyTorch code: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L750

On x86, it calls the underlying MKL kernel mkl_gemm_bf16bf16f32: https://github.com/pytorch/pytorch/blob/f0f61443819ce19a16c8eef3a45a92e51dcfc17e/aten/src/ATen/native/CPUBlas.cpp#L420

The logic implemented here enables oneDNN to pick the ACL kernel. I hope this makes sense.
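
As a rough illustration of why the destination type matters for precision, the sketch below emulates bf16 rounding in plain C++ and compares a dot product kept in fp32 with one that takes a detour through bf16. All function names are hypothetical, the rounding helper ignores NaN and infinity, and none of this is the oneDNN, ACL, or MKL implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Round an fp32 value to the nearest bf16-representable value
// (round to nearest, ties to even) and return it as fp32.
// NaN and infinity handling is omitted for brevity.
inline float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);  // rounding bias, ties to even
  bits &= 0xFFFF0000u;                    // keep only the upper 16 bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

// bf16 inputs, fp32 accumulation, fp32 output: the bf16bf16f32 pattern.
inline float dot_bf16_in_f32_out(const std::vector<float>& a,
                                 const std::vector<float>& b) {
  assert(a.size() == b.size());
  float acc = 0.0f;
  for (std::size_t i = 0; i < a.size(); ++i) {
    acc += round_to_bf16(a[i]) * round_to_bf16(b[i]);
  }
  return acc;
}

// Same computation, but the result is rounded to bf16 and widened back to fp32,
// which models reordering dst to bf16 and back.
inline float dot_bf16_roundtrip(const std::vector<float>& a,
                                const std::vector<float>& b) {
  return round_to_bf16(dot_bf16_in_f32_out(a, b));
}
```

With inputs {1, 1} and {1, 2^-9}, the fp32 output is 1.001953125, while the round trip through bf16 yields 1.0, because bf16 carries only 8 bits of significand precision.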

jgong5 commented 1 week ago

Yes, that makes sense. It means the precision stays the same between aarch64 and x86 for SDPA. I also understand that we need the ideep API to handle two scenarios: one that returns the same data type as the input and one that returns fp32. We can extend the semantics of the API to support fp32, but I want the ideep API to behave the same on aarch64 and x86, i.e., the same data type mapping between input and output. I see two options (please comment if you have other ideas):

  1. Let the caller specify that it wants fp32 output instead of following the input data type. This keeps existing callers unchanged and is not BC-breaking. To me, it is a good choice. How are you going to invoke oneDNN from SDPA? Can you specify from the caller that you want fp32 output?
  2. Change the behavior of the API to return fp32 by default under some reasonable conditions. This is BC-breaking. We need to double-check the existing callers to see whether the change would break them.
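
Option 1 could look roughly like the sketch below: an optional argument that defaults to the existing mapping, so current callers are unaffected. The names `data_type`, `default_dst_type`, and `resolve_dst_type` are hypothetical and do not correspond to the actual ideep API; the code assumes C++17 for `std::optional`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the data type enum; not the ideep type.
enum class data_type { bf16, f16, f32 };

// Existing mapping per the x86 table: bf16 and f16 sources keep their type,
// everything else produces f32.
constexpr data_type default_dst_type(data_type src) {
  return (src == data_type::bf16 || src == data_type::f16) ? src
                                                           : data_type::f32;
}

// Option 1: the caller may request an output type explicitly.
// Callers that pass nothing get the existing mapping, so nothing breaks.
inline data_type resolve_dst_type(
    data_type src, std::optional<data_type> requested = std::nullopt) {
  if (!requested) {
    return default_dst_type(src);
  }
  // In this sketch the only override allowed is widening to f32.
  assert(*requested == data_type::f32 ||
         *requested == default_dst_type(src));
  return *requested;
}
```

An SDPA-style caller would pass `data_type::f32` explicitly, while every existing call site keeps the single-argument form and its current result on both architectures.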

aditew01 commented 1 week ago

Thanks for the input. Alternatively, would it not be better to enable this for both aarch64 and x86? I believe that would be the best way to handle this. It would align with the existing MKL kernel that is already being called, and it would ensure the same precision for specific operators like SDPA. https://github.com/intel/ideep/pull/343#issuecomment-2482833306

The mechanism suggested above works, but in this case the caller is the descriptors that ideep generates, right? And by setting the src_ and dst_ data_type here, we enable the respective oneDNN kernels to make dispatch decisions based on them? Reference: https://github.com/oneapi-src/oneDNN/blob/d94cc8d4fbed06867ed3bebb04ac91573175ebfa/src/cpu/aarch64/matmul/acl_matmul.cpp#L79
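
To illustrate what I mean by dispatch decisions, here is a simplified model of a kernel list that is walked in order until an entry accepts the data types in the descriptor. The kernel names and all identifiers are hypothetical, and this is not oneDNN's actual dispatch code; it only shows why the dst data type in the descriptor decides which kernel can be picked. It assumes C++17.

```cpp
#include <array>
#include <cassert>
#include <string_view>

// Hypothetical stand-in for the data type enum.
enum class dt { bf16, f16, f32 };

// The data types a matmul descriptor carries for source, weights, destination.
struct matmul_types {
  dt src;
  dt wei;
  dt dst;
};

using supports_fn = bool (*)(const matmul_types&);

// One entry of the (hypothetical) kernel list.
struct kernel_entry {
  std::string_view name;
  supports_fn supports;
};

constexpr bool is_bf16bf16f32(const matmul_types& t) {
  return t.src == dt::bf16 && t.wei == dt::bf16 && t.dst == dt::f32;
}

constexpr bool is_bf16bf16bf16(const matmul_types& t) {
  return t.src == dt::bf16 && t.wei == dt::bf16 && t.dst == dt::bf16;
}

constexpr bool accepts_anything(const matmul_types&) { return true; }

// Ordered from most specific to the catch-all reference fallback.
constexpr std::array<kernel_entry, 3> kernel_list{{
    {"optimized_bf16bf16f32", is_bf16bf16f32},
    {"optimized_bf16bf16bf16", is_bf16bf16bf16},
    {"reference", accepts_anything},
}};

// Return the name of the first kernel that accepts the descriptor's types.
inline std::string_view select_kernel(const matmul_types& desc) {
  for (const auto& k : kernel_list) {
    if (k.supports(desc)) return k.name;
  }
  assert(false && "unreachable: the reference entry accepts every descriptor");
  return {};
}
```

In this model, a descriptor with dst set to f32 reaches the bf16bf16f32 entry directly, while a combination with no dedicated entry falls through to the slower reference kernel.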

jgong5 commented 1 week ago

> Thanks for the input. Alternatively, would it not be better to enable this for both aarch64 and x86? I believe that would be the best way to handle this.

That sounds good to me. We'd better not break existing code though.