apache/tvm

Open deep learning compiler stack for CPU, GPU, and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] correctness issue of scaled_dot_product_attention #17099

Closed · yongwww closed this 2 weeks ago

yongwww commented 2 weeks ago

Currently, the Torch function F.scaled_dot_product_attention is mapped to R.nn.attention in both the Relax nn.module frontend and the FX converter. However, the inference results do not match those produced by PyTorch. The script below reproduces the issue.

import numpy as np
import torch
import torch.nn.functional as F

import tvm
from tvm import relax
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R

# TVM: commit cc7eb2faae3444ee02b142a5aea237dd1db6d29a
# Torch: 2.3.1
np.random.seed(11111)
shape = (2, 24, 4250, 64)
dtype = "float32"

qkv_pt = []
qkv_tvm = []
for _ in range(3):
    val = np.random.normal(loc=0, scale=1, size=shape).astype(dtype)
    qkv_pt.append(torch.tensor(val).to("cuda"))
    qkv_tvm.append(tvm.nd.array(val, tvm.cuda()))

@I.ir_module
class Model:
    @R.function
    def main(
        query: R.Tensor(shape, dtype=dtype),
        key: R.Tensor(shape, dtype=dtype),
        value: R.Tensor(shape, dtype=dtype),
    ) -> R.Tensor(shape, dtype=dtype):
        R.func_attr({"num_input": 3})
        with R.dataflow():
            gv: R.Tensor(shape, dtype=dtype) = R.nn.attention(query, key, value)
            R.output(gv)
        return gv

target = tvm.target.Target("cuda")  # "nvidia/nvidia-a100"
ex = relax.build(Model, target)
vm = relax.VirtualMachine(ex, tvm.cuda())
out_tvm = vm["main"](*qkv_tvm)

out_pt = F.scaled_dot_product_attention(*qkv_pt)
print(f"tvm output: {out_tvm}\ntorch output: {out_pt}")
np.testing.assert_allclose(out_tvm.numpy(), out_pt.cpu().numpy(), atol=1e-2)
# Mismatched elements: 12702413 / 13056000 (97.3%)
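For reference, PyTorch's F.scaled_dot_product_attention computes softmax(QK^T / sqrt(head_dim)) V over the last two axes, i.e. with inputs laid out as (batch, num_heads, seq_len, head_dim). Below is a minimal NumPy check of that definition against PyTorch; the sdpa_reference helper, the small shape, and the tolerances are only illustrative (a reduced size keeps the score matrix small), not part of the original report.

import numpy as np
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v):
    # softmax(q @ k^T / sqrt(head_dim)) @ v over the last two axes,
    # i.e. inputs in the (batch, num_heads, seq_len, head_dim) layout.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(2, 4, 16, 8)).astype("float32") for _ in range(3))
ref = sdpa_reference(q, k, v)
out = F.scaled_dot_product_attention(torch.tensor(q), torch.tensor(k), torch.tensor(v))
np.testing.assert_allclose(ref, out.numpy(), rtol=1e-4, atol=1e-5)
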

yongwww commented 2 weeks ago

An extra transpose works. R.nn.attention expects inputs in the (batch, seq_len, num_heads, head_dim) layout, whereas F.scaled_dot_product_attention takes (batch, num_heads, seq_len, head_dim), so the query/key/value need to be permuted before the call and the result permuted back:

        with R.dataflow():
            # (B, H, S, D) -> (B, S, H, D): R.nn.attention expects seq_len before num_heads.
            q = R.permute_dims(query, [0, 2, 1, 3])
            k = R.permute_dims(key, [0, 2, 1, 3])
            v = R.permute_dims(value, [0, 2, 1, 3])
            r = R.nn.attention(q, k, v)
            # Permute the result back to the original PyTorch layout.
            gv = R.permute_dims(r, [0, 2, 1, 3])
            R.output(gv)
        return gv
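If the layout mismatch is indeed the root cause, the frontend mapping could insert these permutes itself instead of requiring them in user code. A rough sketch of such a wrapper using the public relax.op API (the sdpa_to_relax helper and its placement are hypothetical, not the actual converter code):

from tvm import relax

def sdpa_to_relax(bb: relax.BlockBuilder, query, key, value):
    # F.scaled_dot_product_attention uses (batch, num_heads, seq_len, head_dim);
    # R.nn.attention expects (batch, seq_len, num_heads, head_dim).
    q = bb.emit(relax.op.permute_dims(query, [0, 2, 1, 3]))
    k = bb.emit(relax.op.permute_dims(key, [0, 2, 1, 3]))
    v = bb.emit(relax.op.permute_dims(value, [0, 2, 1, 3]))
    out = bb.emit(relax.op.nn.attention(q, k, v))
    # Permute back so the converted graph keeps PyTorch's output layout.
    return bb.emit(relax.op.permute_dims(out, [0, 2, 1, 3]))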