Closed Aurelius84 closed 2 months ago
out, indices = topk(x)。下游基于输出out得到的是bool类型的mask,以及indices flatten后的int,都是无法计算梯度的,导致在topk_grad调用时发现out_grad为未初始化的Tensor
torch也会报错,实验脚本如下:
import torch x = torch.randn([4,10]) x.requires_grad=True out, index = torch.topk(x, k=1, dim=0) # y1 = out > 1 y2 = index.flatten() loss = y2.sum() loss.backward() print(x.grad)
原因
out, indices = topk(x)。下游基于输出out得到的是bool类型的mask,以及indices flatten后的int,都是无法计算梯度的,导致在topk_grad调用时发现out_grad为未初始化的Tensor
竞品
torch也会报错,实验脚本如下: