tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
8.42k stars 413 forks

DirectML support #1997

Open abflow opened 2 months ago

abflow commented 2 months ago

Feature description

DirectML (Direct Machine Learning) is a high-performance, hardware-accelerated DirectX 12 API for machine learning that can use modern GPUs and NPUs on Windows. Integrating DirectML would let Burn run efficiently on a wider range of hardware, particularly new Windows PCs that do not have NVIDIA GPUs.

DirectML Repository: https://github.com/microsoft/DirectML
Rust Module: https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/AI/MachineLearning/DirectML/index.html

Feature motivation

Many new Windows PCs (Copilot+ PCs) do not have an NVIDIA card. Instead, they may come with integrated or alternative GPUs and NPUs, such as those in devices built on Snapdragon processors. Currently, Burn leverages LibTorch, which primarily supports CPU, CUDA, and Metal. Integrating DirectML would make Burn more versatile and accessible to users on Windows, broadening its adoption and usability.

oleid commented 1 month ago

Regarding the feature motivation: you forgot the WGPU backend, which should use DirectX 12 even on Snapdragon devices. The NPU should be usable for inference, though.