iclementine opened this issue 2 weeks ago
Your backward implementation directly accesses tensor data pointers: you'll need to wrap that backward implementation in a custom op so we can trace it into a graph. The custom op manual has more details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
Thanks for your reply, I get the point. Is it correct that the backward pass can only use operators that have Meta/fake kernels?
I then wrapped the function tanh_backward as a custom operator and registered a fake kernel for it, so the operator has a Meta/fake kernel and can be traced. But it still fails: tanh_forward is still traced into with fake tensors on cuda and fails with the same error message. We would expect the Meta kernel for aten::tanh (tanh in torch/_refs/__init__.py) to be traced instead of the tanh_forward defined here.
The modified code is below.
import torch


def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    x.data_ptr()  # simulate a custom kernel that typically accesses its input's data pointer
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0


@torch.library.custom_op("mylib::tanh_backward", mutates_args=(), device_types="cuda")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    y.data_ptr()  # simulate a custom kernel that typically accesses its input's data pointer
    return dy * (1.0 - y ** 2)


@tanh_backward.register_fake
def tanh_backward_meta(y, dy):
    # fake kernel: only describe the output's metadata (shape/dtype/device)
    return torch.empty_like(y)


class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        out = tanh_forward(A)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad


def tanh(A):
    return Tanh.apply(A)


# register the bundled forward/backward as the AutogradCUDA kernel of aten::tanh
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCUDA")


def f(x):
    return torch.tanh(x)


x = torch.randn(10, device="cuda")
F = torch.compile(f)
out = F(x)
print(out)
If we wrap every operation that accesses a data pointer in a custom operator with a fake kernel registered, the problem can be avoided.
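For example, here is a minimal sketch of wrapping the forward the same way (assuming the same imports and setup as the code above; mylib::tanh_forward is just an illustrative name):

@torch.library.custom_op("mylib::tanh_forward", mutates_args=(), device_types="cuda")
def tanh_forward_op(x: torch.Tensor) -> torch.Tensor:
    x.data_ptr()  # stand-in for a real kernel that reads the data pointer
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0


@tanh_forward_op.register_fake
def _(x):
    # fake kernel never touches real data, so tracing with fake tensors works
    return torch.empty_like(x)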
But what really confuses me is this: when we register a kernel for an aten builtin operator under the AutogradCUDA key, why isn't the operator's Meta kernel used during dynamo tracing? Why is the AutogradCUDA kernel used instead? For a custom op we have to register a fake kernel, but for a builtin operator that already has a Meta kernel, why is the Meta kernel no longer used once we register a kernel for the AutogradCUDA key?
I also tried adding an AutogradCPU kernel and got a similar result: the AutogradCPU kernel is traced into with fake tensors on CPU and fails for the same reason.
import torch


def tanh_forward(x: torch.Tensor) -> torch.Tensor:
    x.data_ptr()  # simulate a custom kernel that typically accesses its input's data pointer
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0


@torch.library.custom_op("mylib::tanh_backward", mutates_args=(), device_types="cpu")
def tanh_backward(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    y.data_ptr()  # simulate a custom kernel that typically accesses its input's data pointer
    return dy * (1.0 - y ** 2)


@tanh_backward.register_fake
def tanh_backward_meta(y, dy):
    # fake kernel: only describe the output's metadata (shape/dtype/device)
    return torch.empty_like(y)


class Tanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        out = tanh_forward(A)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        (out,) = ctx.saved_tensors
        in_grad = tanh_backward(out, out_grad)
        return in_grad


def tanh(A):
    return Tanh.apply(A)


# register the bundled forward/backward as the AutogradCPU kernel of aten::tanh
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", tanh, "AutogradCPU")


def f(x):
    return torch.tanh(x)


x = torch.randn(10, device="cpu")
F = torch.compile(f)
out = F(x)
print(out)
So I suspect that the problem comes from how fake tensors work. Enabling the dispatch trace with TORCH_SHOW_DISPATCH_TRACE=1 shows the following trace:
[call] op=[aten::clone], key=[AutogradCPU]
[redispatch] op=[aten::clone], key=[CPU]
[call] op=[aten::empty_like], key=[CPU]
[call] op=[aten::empty.memory_format], key=[BackendSelect]
[redispatch] op=[aten::empty.memory_format], key=[CPU]
[call] op=[aten::copy_], key=[CPU]
When registering with only the "CPU" key, the trace related to tanh is:
[call] op=[aten::tanh], key=[PythonDispatcher]
[redispatchBoxed] op=[aten::tanh], key=[PythonDispatcher]
[redispatch] op=[aten::tanh], key=[PythonDispatcher]
[callBoxed] op=[aten::tanh], key=[PythonDispatcher]
[callBoxed] op=[prims::tanh], key=[PythonDispatcher]
[call] op=[aten::empty_permuted], key=[PythonDispatcher]
[redispatch] op=[aten::empty_permuted], key=[Meta]
[call] op=[aten::empty.memory_format], key=[PythonDispatcher]
[redispatch] op=[aten::empty.memory_format], key=[Meta]
[call] op=[aten::as_strided], key=[PythonDispatcher]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::empty_strided], key=[PythonDispatcher]
[redispatch] op=[aten::empty_strided], key=[Meta]
[call] op=[aten::detach], key=[Meta]
I think that AutogradCPU or AutogradCUDA has a higher priority, so when the corresponding kernel is registered, the Meta kernel is not used.
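One way to check this is to inspect the dispatcher's registration table for the operator. The helpers below are private torch._C functions (present in recent PyTorch builds but not a stable API), so treat this as a rough sketch:

import torch

# whether a kernel is registered for a given dispatch key
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::tanh", "AutogradCUDA"))
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::tanh", "Meta"))

# dump all kernels registered for aten::tanh, one entry per dispatch key
print(torch._C._dispatch_dump("aten::tanh"))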
@iclementine taking a step back, what are you trying to do? This (overriding the autograd kernel for aten builtins) isn't a typical thing we expect users to do.
Yes, I know. We are trying to add Triton kernels for aten operators so that they can be called simply via torch.op_name; these kernels thus "replace" the CUDA kernels at runtime.
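A minimal sketch of that replacement pattern (the triton_tanh body here is a placeholder, not an actual Triton launch):

import torch

def triton_tanh(x: torch.Tensor) -> torch.Tensor:
    # placeholder for launching a Triton kernel that computes tanh
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0

# register for the CUDA key only; the Autograd and Meta kernels of aten::tanh stay untouched,
# so torch.tanh on CUDA tensors dispatches to triton_tanh at runtime
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("tanh", triton_tanh, "CUDA")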
While doing this, we found that there are backend-specific autograd keys, which can bundle the forward and backward of an operator while being restricted to a specific backend. This provides a pragmatic way to override the backward of some operators for some backends.
For example, when an operator a has its backward (derivative) defined in derivatives.yaml, the formula is backend-agnostic. Say the backward of a is composed of some other operators, dx = b(x) + c(x)*d(x). If we have a fused operator for the backward of a and want to use that instead, we have to make backward_a an operator (i.e. add an entry for backward_a in native_functions.yaml) and change the formula for the derivative of a into dx = backward_a(x), as in the sketch below.
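For reference, the existing entry for tanh in derivatives.yaml already follows this pattern, since tanh_backward is itself an operator declared in native_functions.yaml (paraphrased; the exact fields vary across versions):

- name: tanh(Tensor self) -> Tensor
  self: tanh_backward(grad, result)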
I appreciate the flexibility that torch offers through this kind of back-and-forth to make autograd work with operators, and there are in fact a few backward operators (e.g. batch_norm_backward) in native_functions.yaml.
But if we want to use that fused operator as the backward of operator a without changing the formula (which would affect other backends, too), we can use a backend-specific autograd key to register an implementation with the forward and backward kernels bundled together.
I know this is kind of unusual and there is no such example in the torch repository; it may be designed for exceptional use cases (AutogradXLA, for example).
I can change my approach and avoid backend-specific autograd keys. But I want to know whether fake tensor could handle this case better by creating the fake tensors on the meta device instead of a specific backend (or, to be clear, is this a bug or a feature)?
Thank you.
You basically want to use register_autograd (https://pytorch.org/docs/stable/library.html#torch.library.register_autograd). But it might disallow you from registering a backward for an aten op, so try to comment that check out and let me know how it goes
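For a custom op, that looks roughly like the sketch below (mylib::my_tanh and the helper names are illustrative, not from this thread):

import torch

@torch.library.custom_op("mylib::my_tanh", mutates_args=())
def my_tanh(x: torch.Tensor) -> torch.Tensor:
    return 2.0 * torch.sigmoid(2.0 * x) - 1.0

@my_tanh.register_fake
def _(x):
    return torch.empty_like(x)

def setup_context(ctx, inputs, output):
    # save the output so the backward can reuse it
    ctx.save_for_backward(output)

def backward(ctx, grad_out):
    (y,) = ctx.saved_tensors
    return grad_out * (1.0 - y * y)

# attaches an autograd formula under the Autograd key (applies to all backends)
torch.library.register_autograd("mylib::my_tanh", backward, setup_context=setup_context)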
@zou3519 thank you.
I see, register_autograd basically registers a kernel for the Autograd key, which applies to all backends, since the key part of it is:
lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)
It is similar to adding an entry in derivatives.yaml. So I think a backend-specific autograd kernel is something rarely used, and the custom_op-related APIs do not encourage users to do so, even for custom operators.
🐛 Describe the bug
In the following code we add a backend-specific autograd kernel (e.g. AutogradCUDA) for the aten builtin operator "tanh". Originally it does not have an "AutogradCUDA" kernel, only a "CUDA" kernel, and the derivative of tanh is defined in derivatives.yaml. We subclass torch.autograd.Function and register our implementation under the "AutogradCUDA" key to make sure that our self-defined backward is used simply by calling torch.tanh. Then, if we try using it with torch.compile, we get an error.
The main issue is that during dynamo tracing with FakeTensors (which should be fake tensors wrapping meta tensors, so that only the Meta kernels are called), the self-defined tanh above is called with fake tensors of CUDA tensors as inputs:
FakeTensor(..., device='cuda:0', size=(10,))
The error comes from calling .data_ptr() on such a fake tensor.
When we register our self-defined tanh with only the "CUDA" key (replacing the last line with aten_lib.impl("tanh", tanh, "CUDA")), the above code runs successfully. Adding a breakpoint in tanh's Meta kernel (tanh in torch/_refs/__init__.py), we can see that the input argument is a fake tensor on the meta device:
FakeTensor(..., device='meta', size=(10,))
So the main reason the AutogradCUDA kernel is called in the failing case is that the fake tensor is on the CUDA device instead of the Meta device. Is there something wrong with the way FakeTensors work with the aten library, or more specifically, with backend-specific autograd kernels? Thank you.
Versions
cc @ezyang @chauhang @penguinwu @zou3519 @bdhirsh @yf225