xl0 / lovely-tensors

Tensors, for human consumption
https://xl0.github.io/lovely-tensors
MIT License
1.11k stars, 16 forks

Bug: many many graph breaks when using `torch.compile` #24

Closed: baldassarreFe closed this 4 months ago

baldassarreFe commented 4 months ago

Hello, and thanks a lot for the library. I use it a lot and I love it!

After updating to a recent PyTorch nightly, I found that lovely-tensors introduces a large number of graph breaks when using `torch.compile`. Below is a minimal reproduction:

import torch
import torchvision.models
import lovely_tensors

def main():
    print(torch.__version__)
    print(lovely_tensors.__version__)

    torch._logging.set_logs(graph_breaks=True)
    lovely_tensors.monkey_patch()

    model = torchvision.models.vit_b_16()
    model.cuda()

    x = torch.randn(256, 3, 224, 224, device="cuda")
    y = torch.compile(model)(x)
    print(y.shape)

if __name__ == "__main__":
    main()
python tmp.py 2>&1 | tee tmp.log

With the `lovely_tensors.monkey_patch()` line commented out, the output is simply:

2.5.0.dev20240625
0.1.15
torch.Size([256, 1000])

But with monkey patching enabled, I get graph-break warnings that clearly point to `StrProxy`:

V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] WON'T CONVERT forward /.../python3.11/site-packages/torchvision/models/vision_transformer.py line 289
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] ========== TorchDynamo Stack Trace ==========
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] Traceback (most recent call last):
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_dynamo/output_graph.py", line 1396, in call_user_compiler
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     compiled_fn = compiler_fn(gm, self.example_inputs())
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     compiled_gm = compiler_fn(gm, example_inputs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/__init__.py", line 2141, in __call__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return compile_fx(model_, inputs_, config_patches=self.config)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/contextlib.py", line 81, in inner
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return func(*args, **kwds)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_inductor/compile_fx.py", line 1536, in compile_fx
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return aot_autograd(
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_functorch/aot_autograd.py", line 974, in aot_module_simplified
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     compiled_fn = dispatch_and_compile()
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                   ^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_functorch/aot_autograd.py", line 963, in dispatch_and_compile
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     compiled_fn, _ = create_aot_dispatcher_function(
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     r = func(*args, **kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]         ^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_functorch/aot_autograd.py", line 695, in create_aot_dispatcher_function
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     compiled_fn, fw_metadata = compiler_fn(
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                                ^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 266, in aot_dispatch_autograd
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 273, in aot_dispatch_autograd_graph
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     str(fw_metadata),
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     ^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/dataclasses.py", line 240, in wrapper
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     result = user_function(self)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]              ^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "<string>", line 3, in __repr__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/lovely_tensors/patch.py", line 33, in __repr__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return str(StrProxy(self))
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/lovely_tensors/repr_str.py", line 190, in __repr__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return to_str(self.t, plain=self.plain, verbose=self.verbose,
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return func(*args, **kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/lovely_tensors/repr_str.py", line 135, in to_str
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     if is_nasty(t) or not t.is_floating_point():
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]        ^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/lovely_tensors/repr_str.py", line 63, in is_nasty
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return (t_min.isnan() or t_min.isinf() or t_max.isinf()).item()
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return fn(*args, **kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return self.dispatch(func, types, args, kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1153, in _cached_dispatch_impl
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     output = self._dispatch_impl(func, types, args, kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1730, in _dispatch_impl
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_subclasses/fake_impls.py", line 150, in dispatch_to_op_implementations_dict
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]   File "/.../python3.11/site-packages/torch/_subclasses/fake_impls.py", line 375, in local_scalar_dense
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks]     raise DataDependentOutputException(func)
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] torch._subclasses.fake_tensor.DataDependentOutputException: aten._local_scalar_dense.default
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] 
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] ========== The above exception occurred while processing the following code ==========
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] 
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] 
V0627 torch/_dynamo/exc.py:209] [0/0] [__graph_breaks] ==========
W0627 torch/_dynamo/exc.py:210] [0/0] Backend compiler failed with a fake tensor exception at 
W0627 torch/_dynamo/exc.py:210] [0/0]   File "/.../python3.11/site-packages/torchvision/models/vision_transformer.py", line 305, in forward
W0627 torch/_dynamo/exc.py:210] [0/0]     return x
W0627 torch/_dynamo/exc.py:210] [0/0] Adding a graph break.

The output above shows only the first of many graph breaks; the full log is attached: tmp.log
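The bottom of the traceback shows the mechanism: while AOTAutograd builds the graph it calls `str(fw_metadata)`, the dataclass `__repr__` formats the tensors it holds, and the monkey-patched `Tensor.__repr__` runs lovely-tensors' `is_nasty()`, which ends in `.item()`. During tracing those tensors are fake tensors with shape and dtype but no values, so `aten._local_scalar_dense` raises `DataDependentOutputException` and Dynamo falls back to a graph break. A minimal sketch that reproduces just the failing call outside `torch.compile`, assuming a bare `FakeTensorMode()` (no `ShapeEnv`) is a fair stand-in for the fake tensors used during compilation; `item_is_data_dependent` is a hypothetical helper name:

```python
import torch
from torch._subclasses.fake_tensor import DataDependentOutputException, FakeTensorMode


def item_is_data_dependent() -> bool:
    """Return True if .item() on a fake tensor raises DataDependentOutputException."""
    # Assumption: a FakeTensorMode without a ShapeEnv roughly mimics the fake
    # tensors seen while torch.compile traces: shape/dtype known, values unknown.
    with FakeTensorMode():
        t = torch.randn(8)  # created as a FakeTensor inside the mode
        try:
            # Reading a Python scalar needs real data; this is the same kind of
            # .item() call that fails inside is_nasty() in the traceback above.
            t.min().isnan().item()
        except DataDependentOutputException:
            return True
    return False


print(item_is_data_dependent())
```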

Can this library remain compatible with `torch.compile`?

xl0 commented 4 months ago

Looking into it. This shouldn't happen; `torch.compile` should just work.

xl0 commented 4 months ago


Hmm, something must be broken in PyTorch 2.5.0. @baldassarreFe, how do you install it? I've never used the PyTorch dev builds.

xl0 commented 4 months ago

I missed that you had upgraded to PyTorch 2.5.0. Yes, I'm seeing the same issue.

xl0 commented 4 months ago

@baldassarreFe Please give the latest git version a try:

pip install git+https://github.com/xl0/lovely-tensors
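The thread does not say what the git fix changes, so the following is only one plausible approach and not lovely-tensors' actual code: have the patched repr check for fake or meta tensors before computing any statistics, and print only shape and dtype for them. `data_free_repr` is a hypothetical stand-in for lovely-tensors' `to_str()`, and the `isinstance(t, FakeTensor)` check relies on the private `torch._subclasses` module, which may change between PyTorch releases.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode


def data_free_repr(t: torch.Tensor) -> str:
    """Hypothetical summary repr that never reads values from fake or meta tensors."""
    header = f"tensor{list(t.shape)} {t.dtype}"
    # Fake tensors (used while torch.compile traces) and meta tensors carry
    # shape and dtype but no data, so any .item() call on them would fail.
    if isinstance(t, FakeTensor) or t.is_meta:
        return header + " (no data)"
    if t.numel() == 0 or not t.is_floating_point():
        return header
    # Real tensors: safe to compute statistics and read them as Python floats.
    return f"{header} x∈[{t.min().item():.3g}, {t.max().item():.3g}]"


with FakeTensorMode():
    print(data_free_repr(torch.randn(2, 3)))
print(data_free_repr(torch.empty(4, device="meta")))
print(data_free_repr(torch.tensor([1.0, 2.0])))
```

Skipping statistics for data-less tensors keeps the repr safe to call anywhere inside the compiler stack, at the cost of less informative output during tracing.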

baldassarreFe commented 4 months ago

~Thanks for the quick fix. It still gives graph breaks in my environment, with the same messages as before. I'll try again on Monday with the newest nightly.~

Thanks for the quick fix, it works!