import flag_gems
import torch
def cat_negative_dim_test():
x = torch.randn(3, 2, device="cuda")
y = torch.randn(3, 3, device="cuda")
return flag_gems.cat((x, y, x), -1)
# return torch.cat((x, y, x), -1)
y = cat_negative_dim_test()
print(y)
报错如下:
File "/work/FlagGems/src/flag_gems/ops/cat.py", line 41, in cat
raise RuntimeError(
RuntimeError: Sizes of tensors must match except in dimension -1. Expected size 2 but got size 3 for tensor number 1 in the list
concat (cat) 在运行时,dim 为负数的时候有概率会报错,下面是复现代码,使用注释掉的 torch 的 aten 算子运行不会报错
报错如下: