0x45f opened this issue 1 week ago
Could you provide the diff that shows where you added `flag_gems.enable()`?
Using `flag_gems.enable()` or `with flag_gems.use_gems():` here; `flag_gems.enable()` raises an error.
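For context, FlagGems exposes both entry points mentioned above; a minimal sketch of how each is typically invoked (the tensor and variable names are illustrative, and this assumes a working CUDA build of `flag_gems`):

```python
import torch
import flag_gems

x = torch.randn(4, device="cuda")

# Entry point 1: globally replace the registered aten ops with gems kernels.
flag_gems.enable()
y = torch.sum(x)

# Entry point 2: scope the replacement to a block with the context manager.
with flag_gems.use_gems():
    z = torch.sum(x)
```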
Issue 3 was traced to `torch.count_nonzero()` producing wrong results once replaced by gems. Under ATen this API is implemented as a composition of `ne_scalar` and `sum_dim`, and the root cause turned out to be the gems implementation of `sum_dim`. The `sum_dim` problem can be reproduced with the following code:
```python
import torch
import flag_gems

torch.set_default_device('cuda')
flag_gems.enable()

x = torch.tensor([1, 2, 3, 4])
# An empty dim list should reduce over all dimensions.
out = torch.sum(x, dim=[], keepdim=False)
print(out)
```
The results are as follows:
```
# Result with aten:
tensor(10, device='cuda:0')
# Result with gems enabled:
tensor([1, 2, 3, 4], device='cuda:0')
```
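This also explains the `count_nonzero` failure: as noted above, ATen lowers `torch.count_nonzero()` to `ne_scalar` followed by `sum_dim`, so a rough Python equivalent (the helper name is hypothetical, not ATen's) looks like:

```python
import torch

def count_nonzero_decomposed(x: torch.Tensor, dim=None):
    mask = torch.ne(x, 0)  # ne_scalar: elementwise x != 0
    # sum_dim: with no dim given, dim=[] must mean "reduce over all
    # dimensions" so that the result is a scalar count.
    return torch.sum(mask, dim=dim if dim is not None else [], dtype=torch.int64)
```

With the buggy `sum_dim`, the final reduction returns its input unchanged instead of a scalar count, which matches the wrong output shown above.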
It looks like the dimlist case was not handled; `sum` differs from ops like `max.dim` here.
```yaml
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
```
Its `dim` may be empty, or an `int[]` of length 1, whereas `max.dim`'s cannot be empty:
```yaml
- func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
```
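A fix would need to normalize the dimlist before dispatching to the kernel. A minimal sketch of such a normalization step (the function name and placement are hypothetical, not FlagGems' actual code):

```python
import torch

def normalize_reduce_dims(x: torch.Tensor, dim):
    # Per the sum.dim_IntList schema, dim may be None or an empty int[];
    # both mean "reduce over every dimension".
    if dim is None or (isinstance(dim, (list, tuple)) and len(dim) == 0):
        return list(range(x.ndim))
    if isinstance(dim, int):
        dim = [dim]  # a bare int is treated as a length-1 list
    # Canonicalize negative dims for downstream kernels.
    return [d % x.ndim for d in dim]

x = torch.tensor([1, 2, 3, 4])
assert normalize_reduce_dims(x, []) == [0]    # empty dimlist -> all dims
assert normalize_reduce_dims(x, None) == [0]
assert normalize_reduce_dims(x, -1) == [0]
```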