ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Add rocblas_alt_impl flag for bwd rocblas calls in MHA #70

Closed: athitten closed this pull request 2 years ago

athitten commented 2 years ago

Updated all the MHA files to pass the rocblas_alt_impl flag in the backward (bwd) rocBLAS calls. Ran all the unit tests, and all of them passed.
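The idea behind the change (set the flag on backward GEMMs only, leave forward GEMMs on the default path) can be sketched in Python. This is a minimal, hypothetical illustration that does not call rocBLAS: the names GemmFlags, gemm_flags, mha_gemm_plan and the GEMM labels are assumptions made here, not apex or rocBLAS APIs, and GemmFlags.FP16_ALT_IMPL is only a stand-in for the rocBLAS gemm flag this PR presumably refers to (rocblas_gemm_flags_fp16_alt_impl).

```python
from enum import IntFlag


class GemmFlags(IntFlag):
    # Hypothetical stand-ins for rocBLAS gemm flags; the numeric values
    # are illustrative only and are NOT copied from rocblas.h.
    NONE = 0
    FP16_ALT_IMPL = 1  # plays the role of rocblas_gemm_flags_fp16_alt_impl


def gemm_flags(is_backward: bool) -> GemmFlags:
    """Hypothetical helper: request the alternate implementation only for bwd GEMMs."""
    return GemmFlags.FP16_ALT_IMPL if is_backward else GemmFlags.NONE


def mha_gemm_plan():
    """Return (gemm_name, flags) pairs for a hypothetical MHA layer.

    The GEMM names are illustrative. Forward GEMMs keep the default flags,
    while every backward GEMM (dgrad/wgrad) carries the alt-impl flag.
    """
    forward = ["input_proj", "q_k_scores", "attn_v", "output_proj"]
    backward = [
        "output_proj_dgrad", "output_proj_wgrad",
        "attn_v_dgrad", "q_k_scores_dgrad",
        "input_proj_dgrad", "input_proj_wgrad",
    ]
    plan = [(name, gemm_flags(False)) for name in forward]
    plan += [(name, gemm_flags(True)) for name in backward]
    return plan


if __name__ == "__main__":
    # Print which flag each (hypothetical) GEMM would be issued with.
    for name, flags in mha_gemm_plan():
        print(f"{name:20s} {flags.name}")
```

Keeping the decision in one helper means every backward call site picks up the flag the same way, which mirrors the PR's goal of updating all bwd rocBLAS calls in MHA consistently.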

hubertlu-tw commented 2 years ago

jenkins: retest this please

hubertlu-tw commented 2 years ago

jenkins: retest this please

hubertlu-tw commented 2 years ago

The failing unit test is one of the distributed unit tests. It is run with the following commands:

$ python -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py
$ python compare.py

The error was not introduced by the changes in this PR.

The test failed in apex-rocm-pytorch-master (PyTorch built from the tip of the ROCm master branch, 4a1785) but passed in apex-rocm-pytorch-release (rocm/pytorch:latest = rocm/pytorch:rocm5.0.1_ubuntu18.04_py3.7_pytorch_staging).