ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Make rocblas_gemm_flags_fp16_alt_impl in MHA and MLP backward compatible with old PyTorch versions #74

Closed: hubertlu-tw closed this pull request 2 years ago

hubertlu-tw commented 2 years ago

@hubertlu-tw Why do the MHA files not use the ROCM_BACKWARD_PASS_GUARD define?

The MHA files should use the ROCM_BACKWARD_PASS_GUARD define too. Thanks for pointing that out.

hubertlu-tw commented 2 years ago

Both of the following scenarios have been tested locally:

  1. PyTorch with the commit that adds rocblas_gemm_flags_fp16_alt_impl (https://github.com/pytorch/pytorch/pull/71881)
  2. PyTorch without that commit, at commit 8796aad30f314a11cfaeead80b90ce6c6fc934f8 (https://github.com/ROCmSoftwarePlatform/pytorch/tree/IFU-master-2022-03-30)
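The guard only helps if the build passes -DROCM_BACKWARD_PASS_GUARD to the compiler exactly when the installed PyTorch provides the feature behind rocblas_gemm_flags_fp16_alt_impl, so that the guarded code compiles against both the old and the new PyTorch listed above. The following is a minimal sketch of such a build-time check, not the actual apex setup.py logic. It assumes that support can be detected by searching a PyTorch header for a marker string. The helper names (header_contains, mha_hipcc_args), the probed header path, and the marker string are all hypothetical.

```python
import os


def header_contains(include_dir, header_rel_path, marker):
    """Return True if the header exists under include_dir and mentions marker."""
    path = os.path.join(include_dir, header_rel_path)
    if not os.path.isfile(path):
        return False
    with open(path, encoding="utf-8", errors="ignore") as f:
        return marker in f.read()


def mha_hipcc_args(include_dir,
                   header_rel_path="ATen/Context.h",  # hypothetical: header to probe
                   marker="BackwardPassGuard"):       # hypothetical: symbol to look for
    """Build extra hipcc flags for the MHA/MLP extensions (illustrative only)."""
    args = ["-O3"]
    # Define the guard only when the installed PyTorch appears to support the
    # feature; otherwise the guarded code is compiled out and the build still
    # works against older PyTorch versions.
    if header_contains(include_dir, header_rel_path, marker):
        args.append("-DROCM_BACKWARD_PASS_GUARD")
    return args
```

In a real setup script, include_dir could be the include directory that ships with the installed torch package (for example, os.path.join(torch.__path__[0], "include")). The two scenarios above correspond to the probed header containing the marker and not containing it.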