microsoft / DirectML

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.
MIT License

DirectMLX doesn't support 2 and 3 dimension GEMM #272

Open: tom-huntington opened this issue 2 years ago

tom-huntington commented 2 years ago

Starting with DML_FEATURE_LEVEL_4_0, the GEMM operator also supports 2- and 3-dimensional tensors. However, DirectMLX hasn't been updated and still only supports 4-dimensional tensors.
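Until DirectMLX accepts lower-rank GEMM inputs, one possible workaround is to describe a 2D or 3D tensor as 4D by prepending size-1 dimensions. The plain C++ sketch below (no DirectML dependency) only computes those padded sizes; `PadTo4D` is a hypothetical helper name, not part of DirectMLX. Padding goes at the front because GEMM treats the last two dimensions as the matrix rows and columns, so the leading dimensions act as batch dimensions.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: left-pads a 1D-4D size list with 1s so it can be
// described as a 4D tensor, e.g. {m, k} -> {1, 1, m, k} and
// {batch, m, k} -> {1, batch, m, k}. The last two sizes are left untouched.
std::vector<uint32_t> PadTo4D(const std::vector<uint32_t>& sizes)
{
    if (sizes.empty() || sizes.size() > 4)
    {
        throw std::invalid_argument("expected 1 to 4 dimensions");
    }
    std::vector<uint32_t> padded(4 - sizes.size(), 1u); // leading size-1 dims
    padded.insert(padded.end(), sizes.begin(), sizes.end());
    return padded;
}
```

For example, a `{ 3, 5 }` matrix becomes `{ 1, 1, 3, 5 }`, and a batch of two such matrices, `{ 2, 3, 5 }`, becomes `{ 1, 2, 3, 5 }`.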

https://github.com/microsoft/DirectML/blob/master/Libraries/DirectMLX.h#L1946-L1950

Also, could you tell me if there is any way to squeeze and unsqueeze tensors in DirectMLX? Or should I just always work with { 1, 1, height, width } tensors?

adtsai commented 2 years ago

Hi, thanks for the report; we'll look into fixing this in DMLX. In the meantime, you can add or remove empty dimensions (dimensions of size 1, an operation also known as squeeze/unsqueeze) by using the dml::Reinterpret function. For example:

```cpp
dml::Expression x = /*...*/; // Assume x has sizes { height, width }

// Prepend two size-1 dimensions (unsqueeze); std::nullopt means no custom strides
auto y = dml::Reinterpret(x, { 1, 1, height, width }, std::nullopt);
// y now has sizes { 1, 1, height, width }
```
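Squeezing works the same way: you call dml::Reinterpret with the size-1 dimensions removed from the size list. The plain C++ sketch below computes such size lists for either direction; `Squeeze` and `Unsqueeze` are hypothetical helper names, not DirectMLX functions, and their results are meant to be passed as the sizes argument of dml::Reinterpret.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: drops every size-1 dimension,
// e.g. {1, 1, height, width} -> {height, width}.
std::vector<uint32_t> Squeeze(const std::vector<uint32_t>& sizes)
{
    std::vector<uint32_t> result;
    for (uint32_t s : sizes)
    {
        if (s != 1) result.push_back(s);
    }
    if (result.empty()) result.push_back(1); // keep at least one dimension
    return result;
}

// Hypothetical helper: inserts a size-1 dimension before position `axis`,
// e.g. Unsqueeze({height, width}, 0) -> {1, height, width}.
std::vector<uint32_t> Unsqueeze(std::vector<uint32_t> sizes, std::size_t axis)
{
    if (axis > sizes.size())
    {
        throw std::out_of_range("axis out of range");
    }
    sizes.insert(sizes.begin() + static_cast<std::ptrdiff_t>(axis), 1u);
    return sizes;
}
```

Reinterpreting `y` with `Squeeze({ 1, 1, height, width })` would, for instance, give back a `{ height, width }` tensor. Neither helper changes the element count, which is what lets a plain reinterpret stand in for squeeze/unsqueeze.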

In general, you can use dml::Reinterpret to perform any kind of reshape, or even to reinterpret memory as a different data type. For example, you could write to a FLOAT32 tensor and then, if you really wanted, read out the raw bits by reinterpreting the tensor as UINT32.
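To illustrate what a FLOAT32-to-UINT32 reinterpretation produces, here is a CPU-side analogue in plain C++; it does not use DirectML. In DirectMLX the equivalent would be a dml::Reinterpret overload that also takes the new data type (check DirectMLX.h for its exact signature). `FloatBits` is a hypothetical helper name.

```cpp
#include <cstdint>
#include <cstring>

// CPU analogue of reinterpreting a FLOAT32 tensor as UINT32: the bytes are
// unchanged, only the type used to read them differs.
uint32_t FloatBits(float value)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits");
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof bits); // copy raw bytes, no numeric conversion
    return bits;
}
```

With IEEE 754 single precision, `FloatBits(1.0f)` yields `0x3F800000` and `FloatBits(-2.0f)` yields `0xC0000000`, the raw bit patterns rather than the converted values 1 and -2.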