FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0

[Operator] fix: fix pytorch 2.4 tanh backward bug #263

Open yinfan98 opened 1 month ago

yinfan98 commented 1 month ago

PR Category

Operator

Type of Change

Bug Fix

Description

Fix the bug reported in issue #249.

Code that reproduces the problem:

import torch

import flag_gems

# Route supported aten ops to the FlagGems Triton implementations.
flag_gems.enable()

def f(x, y):
    a = torch.tanh(y)  # dispatched to the gems tanh wrapper
    b = x - y
    return flag_gems.fused.gelu_and_mul(a, b)

x = torch.randn(10, device="cuda")
y = torch.randn(10, device="cuda")

# Compiling triggers FakeTensor tracing, where the tanh backward path breaks on PyTorch 2.4.
F = torch.compile(f)

print(F(x, y))

I also suspect that other wrappers using torch.autograd.Function and AutogradCUDA may need the same fix.
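For context, here is a minimal sketch of the wrapper pattern in question, with hypothetical kernel names (tanh_forward_kernel, tanh_backward_kernel) standing in for the Triton kernels; this is not FlagGems' actual code:

import torch

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = tanh_forward_kernel(x)  # Triton kernel call
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, dy):
        (out,) = ctx.saved_tensors
        return tanh_backward_kernel(out, dy)

def tanh(x: torch.Tensor):
    return Tanh.apply(x)

# Register the wrapper on aten::tanh under the AutogradCUDA dispatch key,
# which is how torch.tanh gets routed to the Triton implementation.
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCUDA")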

Issue

https://github.com/FlagOpen/FlagGems/issues/249

Progress

Performance

StrongSpoon commented 1 month ago

Thanks for the contribution! The FakeTensor incompatibility matches our own findings; however, given the need to stay compatible with multiple PyTorch versions, the custom_op approach may not be convenient to merge for now.

yinfan98 commented 1 month ago

Thanks for the contribution! The FakeTensor incompatibility matches our own findings; however, given the need to stay compatible with multiple PyTorch versions, the custom_op approach may not be convenient to merge for now.

@StrongSpoon Understood. That's why the current version doesn't use custom_op; the custom_op code is commented out, which may make it a bit hard to spot in the diff...

The approach here is to register the Triton operators once more under the gems namespace, so that they go through torch.compile without problems. This actually makes sense: in native PyTorch, an operator like gelu, when captured down to the lowest level, also ends up as something like torch.ops.aten.gelu_backward.

# Real CUDA implementations, backed by the Triton kernels.
@torch.library.impl("gems::tanh_forward", "cuda")
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return tanh_forward_kernel(x)

@torch.library.impl("gems::tanh_backward", "cuda")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return tanh_backward_kernel(y, dy)

# FakeTensor/meta implementations used by torch.compile during tracing;
# only metadata (shape, dtype, device) matters here.
@torch.library.impl_abstract("gems::tanh_forward")
def fake_tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return x

@torch.library.impl_abstract("gems::tanh_backward")
def fake_tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    return dy
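Note that torch.library.impl presupposes that the gems::tanh_forward and gems::tanh_backward schemas have already been defined. Presumably that happens elsewhere in the PR with something along these lines (a sketch, not the PR's exact code):

# Schemas must exist before torch.library.impl can attach kernels to them.
torch.library.define("gems::tanh_forward", "(Tensor x) -> Tensor")
torch.library.define("gems::tanh_backward", "(Tensor y, Tensor dy) -> Tensor")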

The current approach is compatible from torch 2.2 through torch 2.4. But judging from recent PyTorch releases, custom operators still account for a large share of the changes, so I also left a more up-to-date solution in the comments.
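The newer solution left in the comments presumably builds on the torch.library.custom_op API introduced in PyTorch 2.4, roughly along these lines (a sketch, assuming torch >= 2.4 and the same hypothetical Triton kernels; not the PR's actual code):

@torch.library.custom_op("gems::tanh_forward", mutates_args=())
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return tanh_forward_kernel(x)

@tanh_forward.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake/meta implementation: only shape, dtype and device matter during tracing.
    return torch.empty_like(x)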

tongxin commented 1 month ago

Thanks so much. @StrongSpoon we need to dig into how torch.compile resolves the mapping from these decorated functions to its own lowering targets.

Bowen12992 commented 1 month ago

We need to figure out why the previous solution went wrong and whether the new solution will cause performance loss.

yinfan98 commented 1 month ago

We need to figure out why the previous solution went wrong and whether the new solution will cause performance loss.

Got it.

StrongSpoon commented 1 month ago

you could reformat the code according to CONTRIBUTING.md first :)

iclementine commented 1 month ago

This approach defines a new op, though: gems::tanh_forward is different from aten::tanh. So, done this way, calling torch.tanh directly would not dispatch to the function we defined, right?

yinfan98 commented 1 month ago

This approach defines a new op, though: gems::tanh_forward is different from aten::tanh. So, done this way, calling torch.tanh directly would not dispatch to the function we defined, right?

@iclementine My understanding is that __init__.py defines the torch.library.Library that overrides aten, and the impl there replaces the outermost tanh operator. The internally registered ops are only used internally and don't affect anything else, so torch.tanh still dispatches to our tanh. Here is a test:

def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)

Terminal:

using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')
iclementine commented 1 month ago

This approach defines a new op, though: gems::tanh_forward is different from aten::tanh. So, done this way, calling torch.tanh directly would not dispatch to the function we defined, right?

@iclementine My understanding is that __init__.py defines the torch.library.Library that overrides aten, and the impl there replaces the outermost tanh operator. The internally registered ops are only used internally and don't affect anything else, so torch.tanh still dispatches to our tanh. Here is a test:

def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)

Terminal:

using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')

Understood. So in effect a new op, gems::tanh_forward, is registered.

However, used this way, impl does not return a tanh_forward the way custom_op does, so the name tanh_forward ends up bound to None.

There is one level of indirection in use:

torch.ops.aten.tanh -> the flag_gems.ops.tanh function
torch.ops.gems.tanh_forward -> the functions defined above
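In code, the indirection presumably looks roughly like this (a sketch of the structure, not the PR's exact code): torch.tanh dispatches to the flag_gems wrapper, whose autograd.Function forwards to the newly registered gems op.

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # What torch.compile sees is torch.ops.gems.tanh_forward, which has a fake impl.
        out = torch.ops.gems.tanh_forward(x)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, dy):
        (out,) = ctx.saved_tensors
        return torch.ops.gems.tanh_backward(out, dy)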

yinfan98 commented 1 month ago

This approach defines a new op, though: gems::tanh_forward is different from aten::tanh. So, done this way, calling torch.tanh directly would not dispatch to the function we defined, right?

@iclementine My understanding is that __init__.py defines the torch.library.Library that overrides aten, and the impl there replaces the outermost tanh operator. The internally registered ops are only used internally and don't affect anything else, so torch.tanh still dispatches to our tanh. Here is a test:

def tanh(A: torch.Tensor):
    print("using gems tanh")
    return Tanh.apply(A)

Terminal:

using gems tanh
using gems tanh
using gems tanh
using gems tanh
tensor([ 0.1863, -1.5428, -0.3257, -0.1186, -1.1268, -0.1928, -0.0543,  0.0186,
        -0.2310, -0.0073], device='cuda:0')

Understood. So in effect a new op, gems::tanh_forward, is registered.

However, used this way, impl does not return a tanh_forward the way custom_op does, so the name tanh_forward ends up bound to None.

There is one level of indirection in use:

torch.ops.aten.tanh -> the flag_gems.ops.tanh function
torch.ops.gems.tanh_forward -> the functions defined above

Yes, that's right.

yinfan98 commented 1 month ago

you could reformat the code according to CONTRIBUTING.md first :)

Hi, I've reformatted the code according to the style in CONTRIBUTING.md.

yinfan98 commented 4 weeks ago

Hi, when I was looking at the flash-attn repository, I noticed that they handle these interfaces like this:

if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn
    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn
    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper

and they register both the fake impl and the real impl like this:

@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # some fake impl
    ...

They then use the custom op inside torch.autograd.Function. Maybe we could take this implementation as a reference. cc: @StrongSpoon @tongxin @Bowen12992 @iclementine
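Applied to the gems tanh op, that pattern might look roughly like this (a sketch reusing the wrapper names from the snippet above and the hypothetical tanh_forward_kernel; not tested against FlagGems):

@_torch_custom_op_wrapper("gems::tanh_forward", mutates_args=(), device_types="cuda")
def _tanh_forward(x: torch.Tensor) -> torch.Tensor:
    return tanh_forward_kernel(x)

@_torch_register_fake_wrapper("gems::tanh_forward")
def _tanh_forward_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = _tanh_forward(x)  # resolves to the custom op on torch >= 2.4, plain function otherwise
        ctx.save_for_backward(out)
        return out
    # backward omitted for brevity; it would call a gems::tanh_backward registered the same way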

tongxin commented 3 weeks ago

I think we were clearly focused on custom ops and overlooked the real cause of the error. It's not autograd.Function that impedes symbolic tracing: I was able to write a simple function using autograd.Function, register it to aten, and pass torch.compile. That proves that using the low-level API while staying compatible with torch.compile is feasible. Declaring functions under the gems:: namespace is not necessary, although it might be a good idea for other reasons.

The real issue is that _base.data_ptr() in StridedBuffer is not traceable with FakeTensor. I guess we need a symbolic wrapper for StridedBuffer instead.
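For illustration, the failure mode could presumably be reproduced in isolation with PyTorch's internal FakeTensorMode (exact error type and message vary by version; this is only a sketch):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(8, device="cuda")  # a FakeTensor: metadata only, no real storage
    try:
        x.data_ptr()  # pointer access is what breaks symbolic tracing
    except Exception as e:
        print(f"data_ptr() not traceable: {type(e).__name__}: {e}")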

yinfan98 commented 3 weeks ago

I think we were clearly focused on custom ops and overlooked the real cause of the error. It's not autograd.Function that impedes symbolic tracing: I was able to write a simple function using autograd.Function, register it to aten, and pass torch.compile. That proves that using the low-level API while staying compatible with torch.compile is feasible. Declaring functions under the gems:: namespace is not necessary, although it might be a good idea for other reasons.

The real issue is that _base.data_ptr() in StridedBuffer is not traceable with FakeTensor. I guess we need a symbolic wrapper for StridedBuffer instead.

Thank you for your response. I tried several ways to wrap StridedBuffer using torch.fx, but without success: despite implementing various wrappers, FakeTensor still ended up being caught in the Triton code. I also believe this is not an autograd.Function issue, but rather that the choice between AutogradCUDA and CUDA at registration time determines whether the problem occurs. I think that at the capture stage, different registration methods affect whether code generation can be invoked successfully (for example, the cos operator doesn't hit these issues). torch.compile ultimately relies on dynamo, and I don't fully understand dynamo's deeper details; moreover, dynamo has been changing continuously across recent versions. Since I'm not very familiar with FlagGems' underlying Triton code generation, for now I can only proceed with the registration approach.

I also looked through the PyTorch 2.5.0 release notes; this PR may be related to our problem: https://github.com/pytorch/pytorch/pull/133125

iclementine commented 3 weeks ago

Thanks for the additional information. I am also digging into the details of fake tensors, the meta device, and how they relate. We should have some results soon.

iclementine commented 3 weeks ago

We have found the reason. When registering a backend-specific autograd kernel:

Fake tensors on the CUDA device are passed into the AutogradCUDA kernel.

That is not the expected behavior.

In the ideal case, fake tensors on the Meta device would be passed into Meta kernels, and everything would work.

I don't know exactly why fake tensors on CUDA get involved when compiling a function that uses aten operators with backend-specific autograd kernels.

I have opened an issue on the PyTorch repository: https://github.com/pytorch/pytorch/issues/139707.

Adding another layer of indirection does not really solve this problem. Fake tensors on CUDA are still passed to the torch.autograd.Function; it just happens that the function can now handle them. If you add an x.data_ptr() call to the torch.autograd.Function's forward method, it fails again, and this is a common pattern in custom kernels, since accessing the data pointer is commonplace.
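A minimal illustration of that last point (hypothetical structure, not FlagGems' code): even with the gems:: indirection, touching the data pointer in forward reintroduces the failure.

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        _ = x.data_ptr()  # fails under tracing: x can still be a fake tensor on cuda
        out = torch.ops.gems.tanh_forward(x)
        ctx.save_for_backward(out)
        return out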

yinfan98 commented 3 weeks ago

We have found the reason. When registering a backend-specific autograd kernel:

Fake tensors on the CUDA device are passed into the AutogradCUDA kernel.

That is not the expected behavior.

In the ideal case, fake tensors on the Meta device would be passed into Meta kernels, and everything would work.

I don't know exactly why fake tensors on CUDA get involved when compiling a function that uses aten operators with backend-specific autograd kernels.

I have opened an issue on the PyTorch repository: pytorch/pytorch#139707.

Adding another layer of indirection does not really solve this problem. Fake tensors on CUDA are still passed to the torch.autograd.Function; it just happens that the function can now handle them. If you add an x.data_ptr() call to the torch.autograd.Function's forward method, it fails again, and this is a common pattern in custom kernels, since accessing the data pointer is commonplace.

cool!