pytorch / pytorch


Adding backend-specific autograd kernel for aten builtin operators breaks `torch.compile` #139707

Open iclementine opened 2 weeks ago

iclementine commented 2 weeks ago

🐛 Describe the bug

In the following code we add a backend-specific autograd kernel (e.g. AutogradCUDA) for the aten builtin operator "tanh". Originally the operator has a "CUDA" kernel but no "AutogradCUDA" kernel, and the derivative of tanh is defined in derivatives.yaml. We subclass torch.autograd.Function and register our implementation under the "AutogradCUDA" key, so that our self-defined backward is used simply by calling torch.tanh.

import torch

def tanh_forward(x):
    x.data_ptr() # simulate a custom kernel that typically accesses its input's data pointer
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0

def tanh_backward(y, dy):
    y.data_ptr() # simulate a custom kernel that typically accesses its input's data pointer
    return dy * (1.0 - y ** 2)

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        out = tanh_forward(A)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad

def tanh(A):
    return Tanh.apply(A)

aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCUDA")

Then, if we try to use it with torch.compile, we get an error.

def f(x):
    return torch.tanh(x)

x = torch.randn(10, device="cuda")
F = torch.compile(f)
out = F(x)
print(out)

The main issue is that during dynamo tracing with FakeTensors (which should be fake tensors wrapping meta tensors, so that only Meta kernels are called), the self-defined tanh above is called with fake tensors whose device is the cuda device:

FakeTensor(..., device='cuda:0', size=(10,))

The error message is

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method tanh of type object at 0x7f067d585a40>(*(FakeTensor(..., device='cuda:0', size=(10,)),), **{}):
Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

This is because we are calling .data_ptr() on a fake tensor.

When we register our self-defined tanh only under the "CUDA" key (replacing the last line with aten_lib.impl("tanh", tanh, "CUDA")), the above code runs successfully. Adding a breakpoint in tanh's Meta kernel (tanh in torch/_refs/__init__.py), we can see that the input argument is a fake tensor on the meta device:

FakeTensor(..., device='meta', size=(10,))

This is the main reason why the kernel registered for AutogradCUDA is called: the fake tensor is on the CUDA device instead of the Meta device.
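For illustration, a minimal sketch (assuming a CUDA build, as in the repro) showing that a fake tensor made from a cuda tensor still reports device cuda:0, which appears to be why dispatch lands on the AutogradCUDA key rather than a Meta path:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()
real = torch.randn(10, device="cuda")  # a real cuda tensor
fake = mode.from_tensor(real)          # its fake counterpart used during tracing

print(fake.device)  # cuda:0 -- the fake tensor keeps the observable cuda device
# calling fake.data_ptr() here is what produces the error shown above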

Is there something wrong with the way FakeTensors work with the aten library, or more specifically, with backend-specific autograd kernels? Thank you.

Versions

Collecting environment information...
PyTorch version: 2.5.0a0+git32f585d
Is debug build: True
CUDA used to build PyTorch: 12.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-97-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 545.23.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             80
On-line CPU(s) list:                0-79
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Silver 4316 CPU @ 2.30GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 2
Core(s) per socket:                 20
Socket(s):                          2
Stepping:                           6
CPU max MHz:                        3400.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4600.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          1.9 MiB (40 instances)
L1i cache:                          1.3 MiB (40 instances)
L2 cache:                           50 MiB (40 instances)
L3 cache:                           60 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-19,40-59
NUMA node1 CPU(s):                  20-39,60-79
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] optree==0.10.0
[pip3] pytorch-triton==3.1.0+5fe38ffd73
[pip3] torch==2.5.0a0+git32f585d
[conda] Could not collect

cc @ezyang @chauhang @penguinwu @zou3519 @bdhirsh @yf225

bdhirsh commented 2 weeks ago

Your backward implementation directly accesses tensor data pointers: you'll need to wrap that backward implementation in a custom op so we can trace it into a graph. The custom op manual has more details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
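For concreteness, a minimal sketch of what this could look like for the forward kernel in the repro (the op name "mylib::tanh_forward" is illustrative): wrapping the kernel in a custom op with a registered fake kernel lets torch.compile trace the opaque op instead of stepping into the data_ptr() access.

import torch

@torch.library.custom_op("mylib::tanh_forward", mutates_args=(), device_types="cuda")
def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    x.data_ptr()  # the real custom kernel would use the data pointer here
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0

@tanh_forward.register_fake
def _(x):
    # fake kernel: propagates shape/dtype only, never touches real data
    return torch.empty_like(x)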

iclementine commented 1 week ago

Your backward implementation directly accesses tensor data pointers: you'll need to wrap that backward implementation in a custom op so we can trace it into a graph. The custom op manual has more details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

Thanks for your reply, I get the point. Is it correct that the backward pass can only use operators that have Meta or fake kernels?

I then wrapped the function tanh_backward as a custom operator and registered a fake kernel for it, so the operator has a Meta/fake kernel and can be traced. But it still fails: tanh_forward is still traced into with fake tensors on cuda and fails with the same error message. We expect the Meta kernel for aten::tanh (tanh in torch/_refs/__init__.py), rather than the tanh_forward defined here, to be traced into.

The code is modified as below.

import numpy as np
import torch

def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    x.data_ptr() # simulate a custom kernel that typically accesses its input's data pointer
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0

@torch.library.custom_op("mylib::tanh_backward", mutates_args=(), device_types="cuda")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    y.data_ptr() # simulate a custom kernel that typically accesses its input's data pointer
    return dy * (1.0 - y ** 2)

@tanh_backward.register_fake
def tanh_backward_meta(y, dy):
    # the fake kernel must match the op's signature; it only propagates shape/dtype
    return torch.empty_like(y)

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        out = tanh_forward(A)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad

def tanh(A):
    return Tanh.apply(A)

aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCUDA")

def f(x):
    return torch.tanh(x)

x = torch.randn(10, device="cuda")
F = torch.compile(f)
out = F(x)
print(out)

If we wrap every operation that accesses a data pointer into a custom operator with a fake kernel registered, the problem can be avoided. But the real question that confuses me is:

When we register a kernel for an aten builtin operator under the AutogradCUDA key, why is the operator's Meta kernel not used during dynamo tracing, while the kernel registered for the AutogradCUDA key is used instead?

For a custom op, we have to register a fake kernel. But for a builtin operator that already has a Meta kernel, why is the Meta kernel no longer used once we register a kernel for the AutogradCUDA key?

iclementine commented 1 week ago

I also tried adding an AutogradCPU kernel and got a similar result: the AutogradCPU kernel is traced into with fake tensors on CPU and fails for the same reason.

import numpy as np
import torch

def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    x.data_ptr() # simulate a custom kernel that typically accesses its input's data pointer
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0

@torch.library.custom_op("mylib::tanh_backward", mutates_args=(), device_types="cpu")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    y.data_ptr() # simulate a custom kernel that typically accesses its input's data pointer
    return dy * (1.0 - y ** 2)

@tanh_backward.register_fake
def tanh_backward_meta(y, dy):
    # the fake kernel must match the op's signature; it only propagates shape/dtype
    return torch.empty_like(y)

class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        out = tanh_forward(A)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad

def tanh(A):
    return Tanh.apply(A)

aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCPU")

def f(x):
    return torch.tanh(x)

x = torch.randn(10, device="cpu")
F = torch.compile(f)
out = F(x)
print(out)

So I suspect that the way fake tensors work is what causes the problem.

By enabling the dispatch trace with TORCH_SHOW_DISPATCH_TRACE=1, we can see the trace:

[call] op=[aten::clone], key=[AutogradCPU]
  [redispatch] op=[aten::clone], key=[CPU]
   [call] op=[aten::empty_like], key=[CPU]
    [call] op=[aten::empty.memory_format], key=[BackendSelect]
     [redispatch] op=[aten::empty.memory_format], key=[CPU]
   [call] op=[aten::copy_], key=[CPU]

When registering only under the "CPU" key, the trace related to tanh is:

 [call] op=[aten::tanh], key=[PythonDispatcher]
  [redispatchBoxed] op=[aten::tanh], key=[PythonDispatcher]
   [redispatch] op=[aten::tanh], key=[PythonDispatcher]
    [callBoxed] op=[aten::tanh], key=[PythonDispatcher]
     [callBoxed] op=[prims::tanh], key=[PythonDispatcher]
      [call] op=[aten::empty_permuted], key=[PythonDispatcher]
       [redispatch] op=[aten::empty_permuted], key=[Meta]
        [call] op=[aten::empty.memory_format], key=[PythonDispatcher]
         [redispatch] op=[aten::empty.memory_format], key=[Meta]
        [call] op=[aten::as_strided], key=[PythonDispatcher]
    [call] op=[aten::detach], key=[Meta]
    [call] op=[aten::empty_strided], key=[PythonDispatcher]
     [redispatch] op=[aten::empty_strided], key=[Meta]
    [call] op=[aten::detach], key=[Meta]

I think that AutogradCPU or AutogradCUDA has higher priority, and when the corresponding kernel is used, the kernel for Meta is not used.
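For reference, a minimal sketch for inspecting the operator's registrations and computed dispatch table, assuming the private helpers torch._C._dispatch_dump and torch._C._dispatch_dump_table are available in this build (they are internal, not a stable API); the output should show the AutogradCPU/AutogradCUDA entry taking part in dispatch:

import torch

# per-key registrations recorded for the operator
print(torch._C._dispatch_dump("aten::tanh"))

# the computed dispatch table: which kernel each dispatch key would actually run
print(torch._C._dispatch_dump_table("aten::tanh"))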

zou3519 commented 1 week ago

@iclementine taking a step back, what are you trying to do? This (overriding the autograd kernel for aten builtins) isn't a typical thing we expect users to do.

iclementine commented 1 week ago

@iclementine taking a step back, what are you trying to do? This (overriding the autograd kernel for aten builtins) isn't a typical thing we expect users to do.

Yes, I know. We are trying to add Triton kernels to aten operators, so that these kernels can be called simply by using torch.op_name; the kernels thus "replace" the CUDA kernels at runtime.

While doing this, we found that there are backend-specific autograd keys, which can bundle the forward & backward of an operator while being restricted to a specific backend. This provides a pragmatic way to override the backward of some operators on some backends.

For example, when an operator a has its backward (derivative) defined in derivatives.yaml, the formula is backend-agnostic. Say the backward of a is composed of some other operators, dx = b(x) + c(x)*d(x). If we have a fused operator for the backward of a and want to use it instead, we have to make backward_a an operator (i.e. add an entry for backward_a in native_functions.yaml) and change the formula for the derivative of a to dx = backward_a(x).

I appreciate the flexibility that torch offers by this kind of back-and-forth to make autograd work with operators, and there are actually a few backward operators (e.g. batch_norm_backward) in native_functions.yaml.

But if we want to use that fused operator as the backward of operator a without changing the formula (which would affect other backends, too), we can use a backend-specific autograd key to register an implementation with the forward and backward kernels bundled together.

I know this is kind of unusual and there is no such example in the torch repository. It may be designed for exceptional uses (AutogradXLA, for example).

I can change the approach so it does not rely on backend-specific autograd keys. But I want to know whether fake tensors could handle this case better by also creating fake tensors on the meta device instead of a specific backend (or, to be clear: is this a bug or a feature)?

Thank you.

zou3519 commented 1 week ago

You basically want to use register_autograd (https://pytorch.org/docs/stable/library.html#torch.library.register_autograd). But it might disallow you from registering a backward for an aten op, so try to comment that check out and let me know how it goes
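A minimal sketch of that suggestion for tanh, assuming the check inside torch.library.register_autograd that rejects builtin aten operators is commented out locally as suggested above (without that, the call below is expected to raise):

import torch

def setup_context(ctx, inputs, output):
    # tanh's backward only needs the forward output
    ctx.save_for_backward(output)

def backward(ctx, grad_out):
    (out,) = ctx.saved_tensors
    return grad_out * (1.0 - out ** 2)  # d tanh(x)/dx = 1 - tanh(x)^2

# register_autograd registers under the backend-agnostic Autograd key
torch.library.register_autograd("aten::tanh", backward, setup_context=setup_context)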

iclementine commented 1 week ago

@zou3519 thank you. I see: register_autograd basically registers a kernel under the Autograd key, which applies to all backends, since the key part of it is:

lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)

It is similar to adding an entry in derivatives.yaml. So I think backend-specific autograd kernels are rarely used, and the custom_op related APIs do not encourage users to register them, even for custom operators.