pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Following pattern in code seems to break dynamo graph tracing #1695

Closed ngoyal2707 closed 2 years ago

ngoyal2707 commented 2 years ago
import torch
import torchdynamo

from torch import nn

class TestModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linears = nn.Sequential(
            nn.Linear(100, 100),
            nn.Linear(100, 100),
            nn.Linear(100, 100),
            nn.Linear(100, 100),
        )

    def forward(self, x):
        all_but_last, last = self.linears[:-1], self.linears[-1]
        x = all_but_last(x)
        x = last(x)
        return x

# Uncomment the decorator to reproduce the error; the script runs cleanly without it.
# @torchdynamo.optimize('eager')
def toy_example(model, input):
    return model(input)

model = TestModel().cuda()
x = torch.rand((8, 100), device='cuda')

for _ in range(10):
    toy_example(model, x)

It's okay if the answer is that this is not supported; I just wanted to raise it as an FYI.

voznesenskym commented 2 years ago

What graph break do you see? Which pattern are you talking about? Can you run torchdynamo.explain() and add the break reasons to the issue?

ngoyal2707 commented 2 years ago

Sorry, I didn't mean a graph break. The code above gives the following error when run with dynamo:

    return model(input)
  File "examples/DALLE2/scripts/test_dynamo.py", line 18, in forward
    x = all_but_last(x)

Traceback (most recent call last):
  File "examples/DALLE2/scripts/test_dynamo.py", line 31, in <module>
    toy_example(model, x)
  File "/shared/home/namangoyal/src/torchdynamo/torchdynamo/eval_frame.py", line 163, in _fn
    return fn(*args, **kwargs)
  File "examples/DALLE2/scripts/test_dynamo.py", line 23, in toy_example
    @torchdynamo.optimize('eager')
  File "/shared/home/namangoyal/miniconda3/envs/ldm_ait/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1274, in _call_impl
    return forward_call(*input, **kwargs)
  File "examples/DALLE2/scripts/test_dynamo.py", line 16, in forward
    def forward(self, x):
TypeError: 'tuple' object is not callable

The same code works without dynamo.
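For context, a minimal sketch of why the pattern is valid in eager mode and what the error message suggests. In eager PyTorch, slicing an `nn.Sequential` returns another `nn.Sequential`, which is callable. The `TypeError` above suggests that under dynamo the slice ended up as a plain tuple instead. The `as_tuple` variable below is only a hypothetical stand-in for that state, not dynamo's actual internals, and the tensor is kept on the CPU so the check runs without a GPU.

```python
import torch
from torch import nn

# Same layer stack as in the report, built on the CPU.
linears = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)])
x = torch.rand(8, 100)

# Eager mode: a slice of nn.Sequential is itself an nn.Sequential.
all_but_last, last = linears[:-1], linears[-1]
assert isinstance(all_but_last, nn.Sequential)
out = last(all_but_last(x))

# Hypothetical stand-in for what the trace appears to have produced:
# the same modules held in a plain tuple, which cannot be called.
as_tuple = tuple(linears)[:-1]
try:
    as_tuple(x)
except TypeError as err:
    message = str(err)
    print(message)
```

Calling the tuple fails with the same message as in the traceback, which is consistent with the slice having lost its `nn.Sequential` type during tracing.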

voznesenskym commented 2 years ago

Gotcha :) Thanks. Can you port this issue to pytorch? I just realized it's in dynamo, and this repo is effectively closed.

ngoyal2707 commented 2 years ago

Sure: I'm closing this and have opened https://github.com/pytorch/pytorch/issues/87121

soumith commented 2 years ago

@voznesenskym we're still using the dynamo issue tracker

voznesenskym commented 2 years ago

> @voznesenskym we're still using the dynamo issue tracker

My mistake. Apologies @ngoyal2707

ngoyal2707 commented 2 years ago

lol alright, I closed the PT one

voznesenskym commented 2 years ago

https://github.com/pytorch/pytorch/pull/87156/

voznesenskym commented 2 years ago

Sorry it took so long; I did not get around to it until an hour or so ago.

voznesenskym commented 2 years ago

Forgot to close this; it is solved in pytorch: https://github.com/pytorch/pytorch/pull/87156