pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Using autograd.Functions defined in torch/ causes graph breaks #118334

Open davidberard98 opened 5 months ago

davidberard98 commented 5 months ago

🐛 Describe the bug

If an autograd.Function is defined in torch/, it will cause a graph break due to skipfiles.

import torch
from torch.nested._internal.nested_tensor import ViewNestedFromBuffer

def fn(values, offsets):
    # ViewNestedFromBuffer is an autograd.Function defined inside torch/
    return ViewNestedFromBuffer.apply(values.cos(), offsets).sin()

fn_c = torch.compile(fn, backend="aot_eager", dynamic=True)
values = torch.rand((12, 8), requires_grad=True)
offsets = torch.tensor([0, 1, 2, 5, 8, 9, 12])
lengths = torch.tensor([1, 1, 3, 3, 1, 3])  # constructed but unused in this repro

fn_c(values, offsets)

log:

[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting autograd.Function, we were unable to trace function `forward` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] 'inline in skipfiles: ViewNestedFromBuffer.forward | forward /data/users/dberard/pytorch/torch/nested/_internal/nested_tensor.py, skipped according skipfiles.SKIP_DIRS'
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 379, in speculate_subgraph
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     output = f.call_function(tx, args, sub_kwargs)
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/variables/functions.py", line 276, in call_function
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return super().call_function(tx, args, kwargs)
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/variables/functions.py", line 84, in call_function
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return tx.inline_user_function_return(
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/symbolic_convert.py", line 657, in inline_user_function_return
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     return cls.inline_call_(parent, func, args, kwargs)
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     result = InliningInstructionTranslator.check_inlineable(func)
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in check_inlineable
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     unimplemented(
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/data/users/dberard/pytorch/torch/_dynamo/exc.py", line 190, in unimplemented
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[2024-01-25 16:05:28,336] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: 'inline in skipfiles: ViewNestedFromBuffer.forward | forward /data/users/dberard/pytorch/torch/nested/_internal/nested_tensor.py, skipped according skipfiles.SKIP_DIRS'
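The Unsupported error above comes from Dynamo's path-based skipfile check: functions whose source file lives under a skipped directory (including torch/ itself) are refused for inlining. A minimal, illustrative approximation of that decision follows; the real logic (skipfiles.SKIP_DIRS and friends) is considerably more involved, and the names here are simplified, not the actual API:

```python
# Illustrative approximation of Dynamo's skipfile check, not the real code.
TORCH_ROOT = "/data/users/dberard/pytorch/torch/"

# Directories whose files Dynamo refuses to inline into the traced graph.
SKIP_DIRS = [TORCH_ROOT]

def check_inlineable(filename: str) -> bool:
    """Return True if a function defined in `filename` may be inlined."""
    return not any(filename.startswith(d) for d in SKIP_DIRS)

# ViewNestedFromBuffer.forward lives under torch/, so inlining is refused,
# Dynamo raises Unsupported, and the frame graph-breaks.
print(check_inlineable(TORCH_ROOT + "nested/_internal/nested_tensor.py"))  # False
print(check_inlineable("/home/user/my_model.py"))  # True
```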

note: ViewNestedFromBuffer may be patched; in that case, a different autograd.Function defined under torch/ will probably be needed to reproduce this behavior.
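For contrast, an equivalent autograd.Function defined in user code (outside torch/) does not hit the skipfile check, so Dynamo can introspect its forward/backward via the higher-order-op machinery. A sketch, where ScaledSin is a made-up example rather than anything from this issue:

```python
import torch

class ScaledSin(torch.autograd.Function):
    """User-defined autograd.Function: defined outside torch/, so Dynamo
    is allowed to attempt tracing forward/backward instead of graph-breaking."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return 2 * torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * torch.cos(x)

def fn(x):
    return ScaledSin.apply(x.cos())

fn_c = torch.compile(fn, backend="aot_eager")
x = torch.ones(3, requires_grad=True)
out = fn_c(x)        # traces without the skipfile graph break
out.sum().backward()
```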

Versions

main branch, ~Jan 24, 2024.

cc @ezyang @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @msaroufim @bdhirsh @zou3519 @aakhundov

anijain2305 commented 5 months ago

Cc @yanboliang for the new tracing rule system

yanboliang commented 5 months ago

I have talked with @davidberard98 about this. I think we should allow Dynamo to trace all autograd.Functions, since the trace is only used to check whether the function is sound. I'll work on a PR to fix this.
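A hypothetical sketch of that direction: treat autograd.Function methods as always eligible for tracing, regardless of where they are defined, and keep path-based skipping only for everything else. The names below are illustrative, not the actual trace-rules API:

```python
# Hypothetical sketch of the proposed rule change, not real Dynamo code.
SKIP_DIRS = ["/data/users/dberard/pytorch/torch/"]

def check_inlineable(filename: str, is_autograd_function_method: bool = False) -> bool:
    if is_autograd_function_method:
        # Always attempt to trace: Dynamo's introspection only verifies
        # soundness and falls back to eager mode if tracing fails.
        return True
    return not any(filename.startswith(d) for d in SKIP_DIRS)

path = "/data/users/dberard/pytorch/torch/nested/_internal/nested_tensor.py"
print(check_inlineable(path, is_autograd_function_method=True))   # True
print(check_inlineable(path, is_autograd_function_method=False))  # False
```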

anijain2305 commented 1 month ago

Any update on this one?

davidberard98 commented 1 month ago

I haven't done this myself, but I think it may have been fixed some other way - maybe @yanboliang did it? I've seen Dynamo try to inline ViewNestedFromBuffer (and fail for other reasons), but I haven't looked too closely at it.