PaddlePaddle / PaddleTest

PaddlePaddle TestSuite

[CINN]Fix topk indices no grad problem #2935

Closed Aurelius84 closed 2 months ago

Aurelius84 commented 2 months ago

Cause

With out, indices = topk(x), the downstream computation derives a bool-typed mask from the output out, and an int tensor from flattening indices. Neither a bool mask nor integer indices can carry gradients, so when topk_grad is invoked it finds that out_grad is an uninitialized Tensor.
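The non-differentiability can be checked directly in torch (a minimal illustration, not from the original report): comparison results are bool and indices are int64, and neither kind of tensor has requires_grad set, so no gradient can flow back through them to topk's input.

```python
import torch

x = torch.randn(4, 10, requires_grad=True)
out, indices = torch.topk(x, k=1, dim=0)

mask = indices.new_empty(0)  # placeholder, replaced below
mask = out > 1               # bool tensor: comparison is non-differentiable
flat = indices.flatten()     # int64 tensor: indices carry no gradient

print(mask.dtype, mask.requires_grad)   # torch.bool False
print(flat.dtype, flat.requires_grad)   # torch.int64 False
print(out.requires_grad)                # True: only the values carry grad
```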

Comparison with other frameworks

torch raises an error in the same situation; the experiment script is as follows:

import torch

x = torch.randn([4, 10])
x.requires_grad = True
out, index = torch.topk(x, k=1, dim=0)
# y1 = out > 1           # a bool mask would be equally non-differentiable
y2 = index.flatten()     # int64 indices have no grad_fn
loss = y2.sum()
loss.backward()          # raises RuntimeError: the loss has no grad_fn

print(x.grad)
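For contrast, the same script succeeds when the loss is built from the float values rather than the integer indices (a sketch added here for illustration, not part of the original report): backward through index fails with a RuntimeError, while backward through out propagates gradients to x at the selected positions.

```python
import torch

x = torch.randn(4, 10, requires_grad=True)
out, index = torch.topk(x, k=1, dim=0)

# Backward through the integer indices fails: they have no grad_fn.
try:
    index.flatten().sum().backward()
except RuntimeError as e:
    print("backward through indices failed:", e)

# Backward through the float values works: each of the k*10 selected
# entries of x receives gradient 1, everything else stays 0.
out.sum().backward()
print(x.grad.sum())  # tensor(10.) for k=1 over 10 columns
```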