microsoft / DirectML

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.
MIT License

How to use DirectML with transformers TrainingArguments #652

Open osadchi opened 1 month ago

osadchi commented 1 month ago

Thank you for your work!! I have a repo for training Llama models, but it targets CUDA. How would I change it to use DirectML with multiple GPUs instead of CUDA? (I have both AMD ROCm and NVIDIA CUDA cards.)

https://github.com/RuslanPeresy/gptchain/blob/main/train.py
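A minimal sketch of one way to approach the question above: route every device decision through a single helper, so the rest of `train.py` never calls `torch.cuda` directly. It assumes DirectML is reached through the separate `torch-directml` package (imported as `torch_directml`), whose `device()` returns a PyTorch device for the default DirectML adapter. The helper names `detect_backend` and `get_device` are hypothetical. This covers single-device selection only; multi-GPU training, and whether Hugging Face's `Trainer` will place the model on a DirectML device, are not addressed here.

```python
import importlib.util


def detect_backend() -> str:
    """Pick an accelerator backend: "dml", "cuda", or "cpu".

    Assumption: DirectML comes from the torch-directml package; it is
    preferred here only because the question asks for DirectML.
    """
    # Probe for packages without importing them, so this runs anywhere.
    if importlib.util.find_spec("torch_directml") is not None:
        return "dml"
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            return "cuda"
    return "cpu"


def get_device(backend: str):
    """Return a torch device for the backend (or "cpu" if torch is absent)."""
    if backend == "dml":
        import torch_directml  # assumption: torch-directml is installed

        # Default DirectML adapter; torch_directml.device(i) picks adapter i.
        return torch_directml.device()
    if backend == "cuda":
        import torch

        return torch.device("cuda")
    if importlib.util.find_spec("torch") is not None:
        import torch

        return torch.device("cpu")
    return "cpu"


if __name__ == "__main__":
    backend = detect_backend()
    print(backend, get_device(backend))
```

The model and each batch would then be moved with `.to(device)`. Because the packages are probed with `importlib.util.find_spec`, the helper also runs on machines that have neither `torch` nor `torch_directml` installed.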

There's CUDA everywhere :C

```python
fp16=not torch.cuda.is_bf16_supported(),
bf16=torch.cuda.is_bf16_supported(),
```
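The two `TrainingArguments` lines above tie the precision choice to `torch.cuda`. Below is a hedged sketch that keeps the original CUDA behavior and falls back to full fp32 on any other backend. The helper `precision_flags` is hypothetical, and the fp32 fallback is a conservative assumption, not a documented DirectML requirement.

```python
def precision_flags(backend: str, bf16_supported: bool = False) -> dict:
    """Return fp16/bf16 keyword arguments for TrainingArguments.

    backend: "cuda", "dml", or "cpu" (hypothetical backend labels).
    bf16_supported: the result of torch.cuda.is_bf16_supported(); only
    meaningful, and only worth computing, when backend == "cuda".
    """
    if backend == "cuda":
        # Same rule as the original train.py: prefer bf16, else fp16.
        return {"fp16": not bf16_supported, "bf16": bf16_supported}
    # Assumption: the Trainer's mixed-precision paths are not known to
    # target DirectML, so plain fp32 is the conservative choice off CUDA.
    return {"fp16": False, "bf16": False}
```

It would be used as `TrainingArguments(output_dir="out", **precision_flags(backend, bf16_supported))`, evaluating `torch.cuda.is_bf16_supported()` only when `backend` is `"cuda"`.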

Thank you again!