Adding extra transposes around the call makes the results match:
```python
with R.dataflow():
    # Convert BNSH (PyTorch SDPA layout) -> BSNH for R.nn.attention.
    q = R.permute_dims(query, [0, 2, 1, 3])
    k = R.permute_dims(key, [0, 2, 1, 3])
    v = R.permute_dims(value, [0, 2, 1, 3])
    r = R.nn.attention(q, k, v)
    # Convert the result back to the original BNSH layout.
    gv = R.permute_dims(r, [0, 2, 1, 3])
    R.output(gv)
```
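
The transposes likely help because of a layout mismatch: PyTorch's `F.scaled_dot_product_attention` takes tensors as `(batch, num_heads, seq_len, head_dim)` (BNSH), while `R.nn.attention` documents a `(batch, seq_len, num_heads, head_dim)` (BSNH) layout. The standalone PyTorch sketch below (not the original repro; the shapes and the `bsnh_attention` helper are illustrative) shows that an attention op following the BSNH convention disagrees with the PyTorch reference when fed BNSH tensors directly, and agrees once the inputs and output are permuted with `[0, 2, 1, 3]`:

```python
import torch
import torch.nn.functional as F

b, h, s, d = 2, 4, 8, 16
# Tensors in the BNSH layout that torch's SDPA expects.
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v)

def bsnh_attention(q, k, v):
    # Attention under the BSNH convention that R.nn.attention documents:
    # seq_len on axis 1, num_heads on axis 2.
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))  # BSNH -> BNSH
    scores = qh @ kh.transpose(-2, -1) / d ** 0.5
    out = scores.softmax(dim=-1) @ vh
    return out.transpose(1, 2)  # BNSH -> BSNH

# Feeding the BNSH tensors straight in (what the converter effectively
# does) disagrees with the PyTorch reference ...
naive = bsnh_attention(q, k, v)
print(torch.allclose(naive, ref))  # False

# ... while permuting to BSNH first and permuting the result back
# matches it, mirroring the extra-transpose workaround above.
def swap_heads_seq(t):
    return t.permute(0, 2, 1, 3)

fixed = bsnh_attention(swap_heads_seq(q), swap_heads_seq(k), swap_heads_seq(v))
print(torch.allclose(swap_heads_seq(fixed), ref, atol=1e-5))  # True
```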
Currently, the Torch function `F.scaled_dot_product_attention` is mapped to `R.nn.attention` in both the Relax `nn.Module` frontend and the FX converter. However, the inference results do not match those obtained with PyTorch. A script that reproduces the issue is below.