FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0
347 stars 49 forks source link

Code Contribution: 【Lv1】【Bug Fix】- Fixed the issue where the cat operator would throw an error when the dim is a negative number. #248

Closed Bowen12992 closed 1 month ago

Bowen12992 commented 1 month ago

concat (cat) 在运行时,dim 为负数的时候有概率会报错,下面是复现代码,使用注释掉的 torch 的 aten 算子运行不会报错

import flag_gems
import torch

def cat_negative_dim_test():
    x = torch.randn(3, 2, device="cuda")
    y = torch.randn(3, 3, device="cuda")
    return flag_gems.cat((x, y, x), -1)
    # return torch.cat((x, y, x), -1)

y = cat_negative_dim_test()
print(y)

报错如下:

  File "/work/FlagGems/src/flag_gems/ops/cat.py", line 41, in cat
    raise RuntimeError(
RuntimeError: Sizes of tensors must match except in dimension -1. Expected size 2 but got size 3 for tensor number 1 in the list
Tango2018cc commented 1 month ago

认领方式:

  1. Issue中留言评论:「认领人名称+认领」
  2. 在贡献者微信群内发送:认领的Issue号例如:「认领+[https://github.com/FlagOpen/FlagGems/issues/xxx」]
2niuhe commented 1 month ago

唐康 认领

Tango2018cc commented 1 month ago

领取后,请于下周1前完成,谢谢

2niuhe commented 1 month ago

PR: https://github.com/FlagOpen/FlagGems/pull/261