FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

3 framework-integration bugs hit while getting "LLaVA single-GPU + gems" running #301

Open 0x45f opened 1 week ago

0x45f commented 1 week ago

Describe the bug

1. The registered AutogradCUDA operators conflict with torch.compile and raise an error; silu, native_dropout, and tanh all have this problem. For now the registration code has to be commented out to skip them. (screenshot)
2. Using flag_gems.enable() causes a dataloader problem, but with flag_gems.use_gems() works fine, which is odd (see the sketch after this list). (screenshot)
3. Enabling gems during the optimizer step causes a shape-mismatch problem. (screenshot)
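
For context, a minimal sketch of the two ways the ops were enabled in the experiments above, following the FlagGems README; the matmul is a stand-in workload, not the actual LLaVA code:

import torch
import flag_gems

# Option 1: global replacement, patches aten ops for the whole process
# (the mode that hit the dataloader problem in bug 2).
# flag_gems.enable()

# Option 2: scoped replacement, ops are only patched inside the context
# manager (the mode that worked in the same experiment).
with flag_gems.use_gems():
    a = torch.randn(4, 4, device="cuda")
    b = torch.randn(4, 4, device="cuda")
    c = torch.mm(a, b)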
tongxin commented 6 days ago

Could you provide the diff that shows where you added flag_gems.enable()?

0x45f commented 6 days ago

> Could you provide the diff that shows where you added flag_gems.enable()?

Using flag_gems.enable() or with flag_gems.use_gems(): here. flag_gems.enable() raises the error. (screenshot)

0x45f commented 1 day ago

Bug 3 has been traced to torch.count_nonzero() going wrong after being replaced by gems. Under aten this API is implemented as a composition of ne_scalar and sum_dim, and the root cause turned out to be the sum_dim implementation in gems. The sum_dim problem can be reproduced with the following code:

import torch
import flag_gems

torch.set_default_device('cuda')
flag_gems.enable()

x = torch.tensor([1, 2, 3, 4])
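# An empty dim list should mean a full reduction over all dimensions.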
out = torch.sum(x, dim=[], keepdim=False)

print(out)

The results:

# result with aten:
tensor(10, device='cuda:0')

# result with gems enabled:
tensor([1, 2, 3, 4], device='cuda:0')
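
This connects directly back to bug 3: since aten decomposes count_nonzero into ne_scalar plus a full-reduction sum_dim, the empty-dim bug above breaks it. A small sanity check of that equivalence (illustrative, with placeholder data):

import torch

x = torch.tensor([1, 0, 3, 0], device='cuda')

# count_nonzero is ne_scalar followed by sum_dim with an empty dim
# list, i.e. a reduction over all dimensions.
ref = torch.count_nonzero(x)
decomposed = torch.sum(x != 0, dim=[], keepdim=False)

print(ref, decomposed)  # both should print tensor(2, device='cuda:0')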
iclementine commented 1 day ago

It looks like the dimlist case was not handled; sum is different from ops like max.dim.

- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor

Its dim can be empty, or an int[] with a declared size of 1.

max.dim, on the other hand, cannot:

- func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
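
A minimal sketch (a hypothetical helper, not FlagGems' actual code) of how a sum_dim wrapper could normalize this schema before reducing:

def normalize_sum_dims(ndim, dim):
    # Per the sum.dim_IntList schema, dim may be None or an empty
    # list; both mean "reduce over all dimensions".
    if dim is None or len(dim) == 0:
        return list(range(ndim))
    # Otherwise wrap any negative dims.
    return [d % ndim for d in dim]

assert normalize_sum_dims(1, []) == [0]    # the failing case from the repro
assert normalize_sum_dims(2, [-1]) == [1]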