thuml / depyf

depyf is a tool to help you understand and adapt to the PyTorch compiler, torch.compile.
https://depyf.readthedocs.io
MIT License

[Bug]: nn.module return value changed unexpectedly under depyf.prepare_debug #36

Closed · gameofdimension closed this issue 2 months ago

gameofdimension commented 2 months ago

Your current environment


Collecting environment information...
PyTorch version: 2.2.0a0+81ea7a4

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-144-generic-x86_64-with-glibc2.35

Versions of relevant libraries:
[pip3] depyf==0.15.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.2.0a0+81ea7a4
[pip3] torch-tensorrt==2.2.0a0
[pip3] torchdata==0.7.0a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.17.0a0
[conda] Could not collect

🐛 Describe the bug

Under depyf.prepare_debug, the return value of the forward function is unexpectedly wrapped in a list.

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.norm1 = nn.GroupNorm(32, 128)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.AvgPool2d(kernel_size=2)

        self.norm2 = nn.GroupNorm(32, 256)
        self.act2 = nn.SiLU()
        self.drop = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(256, 256, 3, padding=1)

    def forward(self, x):
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)
        x = self.pool(x)
        print("will trigger graph break, x.shape:", x.shape)
        x = self.norm2(x)
        x = self.act2(x)
        x = self.drop(x)
        x = self.conv2(x)
        out = (x*x).mean()
        return out

def main():
    device = 'cuda'
    model = Model().to(device)
    model = torch.compile(model=model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    enabled = True

    step = 0
    while True:
        data = torch.randn(64, 128, 64, 64, device=device)
        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enabled):  # type: ignore # noqa
            loss = model(data)
            print("loss type", type(loss))
            if isinstance(loss, list):
                loss = loss[0]
            print("Loss:", loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        if step >= 10:
            break

def will_return_loss_list():
    import depyf
    with depyf.prepare_debug("depyf_debug_dir"):
        main()

def will_return_loss_tensor():
    main()

if __name__ == '__main__':
    will_return_loss_list()
    # will_return_loss_tensor()
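
(For reference: will_return_loss_list() runs main() inside depyf.prepare_debug and hits the bug, so the script prints loss type <class 'list'> and the isinstance check above unwraps the loss; will_return_loss_tensor() runs main() directly and prints loss type <class 'torch.Tensor'> as expected.)
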
youkaichao commented 2 months ago

Thanks for your report! It is indeed a bug; #37 should fix it ❤️

youkaichao commented 2 months ago

@gameofdimension One suggestion for your code: you create the module object inside the main function, so the module is released when the function returns, which removes its compiled code cache entry. In that case you can only see the captured graph code, but not the compiled cache/guard code.

I recommend making the model object persist beyond the function call; then you will get much more information about torch.compile. For example:

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.norm1 = nn.GroupNorm(32, 128)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.AvgPool2d(kernel_size=2)

        self.norm2 = nn.GroupNorm(32, 256)
        self.act2 = nn.SiLU()
        self.drop = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(256, 256, 3, padding=1)

    def forward(self, x):
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)
        x = self.pool(x)
        print("will trigger graph break, x.shape:", x.shape)
        x = self.norm2(x)
        x = self.act2(x)
        x = self.drop(x)
        x = self.conv2(x)
        out = (x*x).mean()
        return out

def main(model):
    device = 'cpu'
    model.to(device)
    model = torch.compile(model=model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    enabled = True

    step = 0
    while True:
        data = torch.randn(64, 128, 64, 64, device=device)
        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=enabled):  # type: ignore # noqa
            loss = model(data)
            print("loss type", type(loss))
            if isinstance(loss, list):
                loss = loss[0]
            print("Loss:", loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        if step >= 10:
            break

if __name__ == '__main__':
    model = Model()
    import depyf
    with depyf.prepare_debug("depyf_debug_dir"):
        main(model)

Then you can get:


# Note: the following variables are used inside the guard function.
___check_global_state = '''<built-in method check of torch._C._dynamo.guards.GlobalStateGuard object at 0x16827e690>'''
___check_obj_id = '''<built-in function check_obj_id>'''
___check_tensors = '''<built-in method check of torch._C._dynamo.guards.TensorGuards object at 0x1680fca90>'''
tensor_check_names = '''["L['x']"]'''
utils_device = '''<module 'torch.utils._device' from '/Users/youkaichao/anaconda3/envs/py310/lib/python3.10/site-packages/torch/utils/_device.py'>'''
def __guard_0_for_torch_dynamo_resume_in_forward_at_23(L, G, **___kwargs_ignored):
    return (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (___check_obj_id(L['self'], 4396891824)) \
        and (___check_obj_id(L['self'].training, 4380193976)) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___check_obj_id(L['self'].act2, 4396887840)) \
        and (___check_obj_id(L['self'].act2.training, 4380193976)) \
        and (___check_obj_id(L['self'].drop, 4396887696)) \
        and (___check_obj_id(L['self'].drop.training, 4380193976)) \
        and (___check_obj_id(L['self'].conv2, 4396887600)) \
        and (___check_obj_id(L['self'].conv2.training, 4380193976)) \
        and (___check_obj_id(L['self'].norm2, 4396890048)) \
        and (___check_obj_id(L['self'].norm2.training, 4380193976)) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))

# Note: please refer to the graph code in __compiled_fn_5*.py.
# Captured Graph: Dynamo generated graph (debuggable when using eager backend).
# Joint graph: joint forward+backward graph from aot autograd.
# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).
# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).
# AFTER XXX: graph processed by inductor (not debuggable).
def __compiled_fn_5(*args, **kwargs):
    pass

def __transformed_code_0_for_torch_dynamo_resume_in_forward_at_23(___stack0, self, x):
    out = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    __temp_6, = __compiled_fn_5(x)
    return __temp_6

# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def __resume_at_50_3(___stack0, self, x):
    x = self.norm2(x)
    x = self.act2(x)
    x = self.drop(x)
    x = self.conv2(x)
    out = (x * x).mean()
    return out

def transformed___resume_at_50_3(___stack0, self, x):
    __local_dict = {"___stack0": ___stack0, "self": self, "x": x}
    __global_dict = globals()
    if __guard_0_for_torch_dynamo_resume_in_forward_at_23(__local_dict, __global_dict):
        return __transformed_code_0_for_torch_dynamo_resume_in_forward_at_23(___stack0, self, x)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return __resume_at_50_3(___stack0, self, x)

#============ end of __resume_at_50_3 ============#

# Note: the following variables are used inside the guard function.
___check_global_state = '''<built-in method check of torch._C._dynamo.guards.GlobalStateGuard object at 0x1669d2ab0>'''
___check_obj_id = '''<built-in function check_obj_id>'''
___check_tensors = '''<built-in method check of torch._C._dynamo.guards.TensorGuards object at 0x1680fc7b0>'''
tensor_check_names = '''["L['x']"]'''
utils_device = '''<module 'torch.utils._device' from '/Users/youkaichao/anaconda3/envs/py310/lib/python3.10/site-packages/torch/utils/_device.py'>'''
def __guard_0_for_forward(L, G, **___kwargs_ignored):
    return (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (___check_obj_id(L['self'], 4396891824)) \
        and (___check_obj_id(L['self'].training, 4380193976)) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___check_obj_id(G['__builtins_dict___1']['print'], 4381267088)) \
        and (___check_obj_id(L['self'].act1, 4396891632)) \
        and (___check_obj_id(L['self'].act1.training, 4380193976)) \
        and (___check_obj_id(L['self'].pool, 4396891728)) \
        and (___check_obj_id(L['self'].pool.training, 4380193976)) \
        and (___check_obj_id(L['self'].conv1, 4396891680)) \
        and (___check_obj_id(L['self'].conv1.training, 4380193976)) \
        and (___check_obj_id(L['self'].norm1, 4396891776)) \
        and (___check_obj_id(L['self'].norm1.training, 4380193976)) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))

# Note: please refer to the graph code in __compiled_fn_2*.py.
# Captured Graph: Dynamo generated graph (debuggable when using eager backend).
# Joint graph: joint forward+backward graph from aot autograd.
# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).
# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).
# AFTER XXX: graph processed by inductor (not debuggable).
def __compiled_fn_2(*args, **kwargs):
    pass

def __transformed_code_0_for_forward(self, x):
    out = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    graph_out_0 = __compiled_fn_2(x)
    x = graph_out_0[0]
    return __resume_at_50_3(__builtins_dict___1['print'](
        'will trigger graph break, x.shape:', __import_torch.Size((64, 256, 32,
        32))), self, x)

# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def forward(self, x):
    x = self.norm1(x)
    x = self.act1(x)
    x = self.conv1(x)
    x = self.pool(x)
    print('will trigger graph break, x.shape:', x.shape)
    x = self.norm2(x)
    x = self.act2(x)
    x = self.drop(x)
    x = self.conv2(x)
    out = (x * x).mean()
    return out

def transformed_forward(self, x):
    __local_dict = {"self": self, "x": x}
    __global_dict = globals()
    if __guard_0_for_forward(__local_dict, __global_dict):
        return __transformed_code_0_for_forward(self, x)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return forward(self, x)

#============ end of forward ============#
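
With the model kept alive outside main(), the dump above includes the guard functions (__guard_0_for_forward, __guard_0_for_torch_dynamo_resume_in_forward_at_23) and the transformed-code wrappers in addition to the captured graphs in __compiled_fn_*.py, i.e. exactly the cache/guard information that is lost when the module is garbage-collected inside main().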