microsoft / DirectML

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.

torch-directml: Is autocast unavailable on DirectML devices? #454

Open lshqqytiger opened 1 year ago

lshqqytiger commented 1 year ago
Python 3.10.9 | packaged by conda-forge | (main, Jan 11 2023, 15:15:40) [MSC v.1916 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch, torch_directml
>>> a_float32 = torch.rand((8, 8), device="privateuseone:0")
>>> b_float32 = torch.rand((8, 8), device="privateuseone:0")
>>> with torch.autocast("privateuseone"):
...     e_float16 = torch.mm(a_float32, b_float32)
...
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\miniconda3\envs\webuidml\lib\site-packages\torch\amp\autocast_mode.py", line 201, in __init__
    raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
RuntimeError: User specified autocast device_type must be 'cuda' or 'cpu'
>>> torch.set_autocast_enabled(True)
>>> torch.set_autocast_gpu_dtype(torch.float16)
>>> torch.mm(a_float32, b_float32).dtype # should be torch.float16
torch.float32
NeedsMoar commented 11 months ago

It doesn't appear to be redefined, which means the stock autocast code runs, and its hard check for CUDA or CPU fails. I'd say you probably don't need it either, since torch_directml supposedly doesn't support the half type correctly anyway. I haven't noticed any real difference in speed between fp32 and fp16, when fp16 should be twice as fast and very noticeable.
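
For what it's worth, here's a rough way to check that fp32-vs-fp16 claim yourself; it's a sketch, not a rigorous benchmark. It assumes torch_directml.device() is available and that fp16 matmul works at all (which, per the above, is questionable), and it copies back to the CPU to synchronize, since torch_directml doesn't expose a CUDA-style synchronize call.

import time
import torch
import torch_directml

dev = torch_directml.device()

def bench(dtype, n=2048, iters=50):
    a = torch.rand((n, n), device=dev, dtype=dtype)
    b = torch.rand((n, n), device=dev, dtype=dtype)
    torch.mm(a, b)  # warm-up so allocation/compilation isn't timed
    start = time.perf_counter()
    for _ in range(iters):
        c = torch.mm(a, b)
    c.to("cpu")  # copy to host to force synchronization before stopping the clock
    return time.perf_counter() - start

print("fp32:", bench(torch.float32))
print("fp16:", bench(torch.float16))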

The biggest problem is when the autocast call sits in web-API code that isn't special-cased to skip it when DirectML is in use. If you're on AMD you have to fix it yourself, either by deleting the autocast block or by adding a condition that checks for DirectML (or checks for CUDA) in front of it; see the sketch below. Why they made that a hard failure instead of a soft warning that autocast would be disabled, I have no idea, since about five lines below that check they just emit a warning when there's a data-type mismatch on a couple of other supported backends.
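
A minimal sketch of that guard, assuming a successful torch_directml import is how you detect the backend (real model code swapped for a plain matmul):

import contextlib
import torch

try:
    import torch_directml
    device = torch_directml.device()
except ImportError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.autocast only accepts 'cuda' or 'cpu' here, so fall back to a no-op
# context manager on the DirectML ('privateuseone') device instead of crashing.
ctx = torch.autocast(device.type) if device.type in ("cuda", "cpu") else contextlib.nullcontext()
with ctx:
    out = torch.mm(torch.rand(8, 8, device=device), torch.rand(8, 8, device=device))
print(out.dtype)  # float16 under CUDA autocast, float32 otherwise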

It's hard to tell, since a lot of the examples and some of the code seem to assume you'll be running on a card that also has CUDA to fall back on; I blame both GPU manufacturers for that.

torch_directml isn't very clear about which parts of torch it overrides unless you read through the source code. It's looking like a big part of the reason it's so slow compared to other backends right now is that, thanks to the sparse examples, nobody is really using the graph capture / optimization features it injects into torch via the DLL (it's hard to even know they exist without reading the source), and it probably needs them to perform well. ONNX Runtime with the DirectML backend on AMD is very fast, so there's nothing particularly wrong with DirectML itself.

The MS homepage also mentions that resource lifetimes need to be fully managed by the client program, as with anything that uses DX12, since it's a low-level API. But it's difficult to find where any of that is exposed by torch_directml without digging through the C++ once again, and it still isn't documented, so you need experience with DX12 to have a clue what you're doing.
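One low-effort way to see what the Python layer actually exposes, without reading the C++, is to introspect the module. This is just a REPL probe, not an official API listing; anything deeper lives in the native DLL and won't show up here.

import torch_directml
# Dump the public surface of the Python wrapper; the graph capture and
# resource-lifetime machinery discussed above is in the native DLL.
print([name for name in dir(torch_directml) if not name.startswith("_")])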

Adele101 commented 10 months ago

Hello, thank you for submitting this issue. While I can't provide a timeline for resolution at the moment, please know that your feedback is valuable to us. We will follow up once we can review this issue.