microsoft / DirectML

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.
MIT License

Torch-DirectML Layer Norm Produces Incorrect Result with Non-contiguous Input #588

Closed NullSenseStudio closed 2 months ago

NullSenseStudio commented 3 months ago

torch-directml: 0.2.1.dev240521
python: 3.11.7

```python
import torch
import torch_directml
from torch.nn.functional import layer_norm

device = torch_directml.device()

input = torch.randn(4, 4, 4)
weight = torch.randn(4)
bias = torch.randn(4)

cpu_output = layer_norm(input.permute(0, 2, 1), [4], weight, bias)

dml_output = layer_norm(input.to(device).permute(0, 2, 1), [4], weight.to(device), bias.to(device)).cpu()
print((cpu_output - dml_output).std().item(), (cpu_output - dml_output).abs().max().item())
```

```
0.6974620223045349 3.6198389530181885
```

The result isn't close to the output on CPU or on other devices.
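For reference, `layer_norm` over the last dimension normalizes each vector to zero mean and unit variance (using the biased variance), then applies the elementwise `weight` and `bias`. A minimal pure-Python sketch of that computation (the helper name `layer_norm_ref` is made up for illustration):

```python
import math

def layer_norm_ref(vec, weight, bias, eps=1e-5):
    """Pure-Python sketch of layer_norm over one vector (the last dim)."""
    n = len(vec)
    mean = sum(vec) / n
    var = sum((x - mean) ** 2 for x in vec) / n  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    # Normalize, then apply the elementwise affine transform.
    return [(x - mean) * inv_std * w + b for x, w, b in zip(vec, weight, bias)]

vec = [1.0, 2.0, 3.0, 4.0]
out = layer_norm_ref(vec, [1.0] * 4, [0.0] * 4)
# With unit weight and zero bias, `out` has mean ~0 and variance ~1.
```

Whichever elements a correct kernel gathers for each normalized vector, they must come from the same last-dimension slice of the (possibly strided) view; the mismatch above suggests the DirectML path was not doing that for non-contiguous inputs.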

```python
dml_output = layer_norm(input.to(device).permute(0, 2, 1).contiguous(), [4], weight.to(device), bias.to(device)).cpu()
print((cpu_output - dml_output).std().item(), (cpu_output - dml_output).abs().max().item())
```

```
7.381072464340832e-08 2.384185791015625e-07
```

But it works as expected as long as the input is made contiguous.
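A plausible reason contiguity matters (a sketch of strided indexing in general, not the actual DirectML kernel): `permute` changes a tensor's strides without moving any data, so a kernel that indexes the underlying buffer as if it were contiguous reads the wrong elements. A minimal pure-Python illustration:

```python
# A 2x3 "tensor" stored row-major in a flat buffer, like a contiguous torch tensor.
buffer = [0, 1, 2, 3, 4, 5]
strides = (3, 1)  # contiguous row-major strides for shape (2, 3)

def read(buffer, strides, i, j):
    """Strided element lookup: offset = i * strides[0] + j * strides[1]."""
    return buffer[i * strides[0] + j * strides[1]]

# permute(1, 0) swaps shape and strides but leaves the buffer untouched.
t_shape = (3, 2)
t_strides = (1, 3)  # non-contiguous: the last dim's stride is no longer 1

# Correct strided read of row 0 of the transposed view:
row0 = [read(buffer, t_strides, 0, j) for j in range(t_shape[1])]  # [0, 3]

# A kernel that wrongly assumes contiguity recomputes strides from the shape:
wrong_strides = (t_shape[1], 1)  # (2, 1)
row0_wrong = [read(buffer, wrong_strides, 0, j) for j in range(t_shape[1])]  # [0, 1]
```

Calling `.contiguous()` materializes the permuted view into a fresh row-major buffer, after which the naive indexing happens to be correct, which would explain why the workaround above fixes the output.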

```python
dml_output = layer_norm(input.to(device)[::2], [4], weight.to(device), bias.to(device)).cpu()
```

This non-contiguous input raises an error instead:

```
File "...\torch\nn\functional.py", line 2546, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: m_device->CreateOperator(&opDesc, IID_PPV_ARGS(&op))
```
joshjkim commented 3 months ago

Hi @NullSenseStudio, thanks for your feedback. We'll include a fix for non-contiguous inputs in `layer_norm` in our upcoming torch-directml build, releasing soon.

joshjkim commented 2 months ago

@NullSenseStudio We just released a new build that addresses the `layer_norm` issue. Please run `pip install torch-directml --upgrade` to update to torch-directml 0.2.2.dev240614.