ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License
18 stars, 14 forks

Make rocblas_gemm_flags_fp16_alt_impl backward-compat for new naming #79

Closed: hubertlu-tw closed this pull request 2 years ago

hubertlu-tw commented 2 years ago

An upstream PyTorch PR for rocblas_gemm_flags_fp16_alt_impl (which the Apex MHA and MLP extensions use) renamed BackwardPassGuard to ROCmBackwardPassGuard. The changes in this PR keep Apex compatible with the new name, which prevents the rename from breaking the build.
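One common way to stay compatible across such a rename is to detect at build time which name the installed PyTorch headers declare, then pass a macro to the C++ extension. The sketch below illustrates that idea only and is not this PR's actual code. The function name `backward_pass_guard_flags` and the macros `ROCM_BACKWARD_PASS_GUARD` / `BACKWARD_PASS_GUARD` are hypothetical, and the caller is assumed to have already read the relevant header into a string.

```python
import re

# Hypothetical macro names for illustration; the real Apex build may differ.
NEW_NAME_MACRO = "-DROCM_BACKWARD_PASS_GUARD"
OLD_NAME_MACRO = "-DBACKWARD_PASS_GUARD"


def backward_pass_guard_flags(header_text: str) -> list[str]:
    """Return compiler flags that tell the C++ extension which guard name exists.

    The new name is preferred when present. Word-boundary matching keeps the
    old-name check from matching inside "ROCmBackwardPassGuard", so the result
    stays correct even if the order of the two checks is changed.
    """
    if re.search(r"\bROCmBackwardPassGuard\b", header_text):
        return [NEW_NAME_MACRO]
    if re.search(r"\bBackwardPassGuard\b", header_text):
        return [OLD_NAME_MACRO]
    # Neither name found: the extension would build without the fp16 alt-impl path.
    return []


print(backward_pass_guard_flags("struct ROCmBackwardPassGuard {};"))
print(backward_pass_guard_flags("struct BackwardPassGuard {};"))
```

The C++ sources would then use `#ifdef` on whichever macro is defined to decide which guard class name to reference. This lets a single Apex checkout build against PyTorch versions from both before and after the rename.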