ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Cherry-pick b2fdf9c from upstream Apex and resolve conflicts #68

Closed: jithunnair-amd closed this 2 years ago

jithunnair-amd commented 2 years ago

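This fixes Apex build failures with older PyTorch versions (1.8.1, 1.9, and 1.10), which do not ship the ATen/cuda/CUDAGeneratorImpl.h header. On ROCm, hipify rewrites that include to ATen/hip/HIPGeneratorImpl.h, which is the file the compiler cannot find in this build log: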

15:05:31      In file included from /tmp/pip-req-build-sttueip5/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_hip.hip:15:
15:05:31      /tmp/pip-req-build-sttueip5/apex/contrib/csrc/multihead_attn/dropout_hip.h:8:10: fatal error: 'ATen/hip/HIPGeneratorImpl.h' file not found
15:05:31      #include <ATen/hip/HIPGeneratorImpl.h>
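One common way to handle a header that only exists in newer PyTorch releases is to probe the installed include directory at build time and pass a preprocessor define to the extension sources. The sketch below shows that idea; it is not necessarily what this cherry-pick does. The function name `generator_header_flags` and the macro `NEW_GENERATOR_PATH` are illustrative assumptions, as is the legacy header location `ATen/CUDAGeneratorImpl.h` for older releases.

```python
import os


def generator_header_flags(torch_include_dir):
    """Return extra compiler flags based on where this PyTorch install
    keeps its CUDA generator header (hypothetical helper).

    Assumption: newer PyTorch ships ATen/cuda/CUDAGeneratorImpl.h, while
    older releases (1.8.1-1.10, per the PR description) do not.
    """
    new_header = os.path.join(
        torch_include_dir, "ATen", "cuda", "CUDAGeneratorImpl.h"
    )
    if os.path.exists(new_header):
        # The extension source can then choose its include, e.g.:
        #   #ifdef NEW_GENERATOR_PATH
        #   #include <ATen/cuda/CUDAGeneratorImpl.h>
        #   #else
        #   #include <ATen/CUDAGeneratorImpl.h>  // assumed legacy path
        #   #endif
        return ["-DNEW_GENERATOR_PATH"]
    # Header missing: add no define, so sources fall back to the legacy path.
    return []
```

With PyTorch installed, the include directory is typically `os.path.join(os.path.dirname(torch.__file__), "include")`, and the returned flags would be appended to the extension's compile arguments in `setup.py`. Because the check runs against the headers actually present, the same source tree builds on both old and new PyTorch versions.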
jithunnair-amd commented 2 years ago

retest it please