Closed: gameofdimension closed this issue 2 months ago.
Thanks for your report! It is indeed a bug. #37 should fix it ❤️
@gameofdimension One suggestion for your code: you create the module object inside the `main` function, so it is released when the function returns, which removes the compiled-code cache entry. In that case you can only see the captured graph code, not the compiled cache/guard code.
I recommend making the model object persist after the function call; then you will get much more information about `torch.compile`. For example:
import torch
from torch import nn
class Model(nn.Module):
    """Toy conv network whose ``forward`` contains a deliberate graph break.

    Two stages of layers are separated by a ``print`` call. The dump further
    below shows torch.compile (Dynamo) splitting ``forward`` into one compiled
    graph before the print and a resume function for the rest.

    NOTE(review): the depyf dump in this report was generated from this exact
    statement sequence in ``forward``; restructuring it would change the
    generated resume functions, so keep the layout as is.
    """

    def __init__(self) -> None:
        super().__init__()
        # Stage 1: runs before the graph break in forward().
        self.norm1 = nn.GroupNorm(32, 128)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        # Halves H and W (64x64 input in main -> 32x32, matching the
        # torch.Size((64, 256, 32, 32)) constant in the dump).
        self.pool = nn.AvgPool2d(kernel_size=2)
        # Stage 2: runs after the graph break (the resume function).
        self.norm2 = nn.GroupNorm(32, 256)
        self.act2 = nn.SiLU()
        self.drop = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(256, 256, 3, padding=1)

    def forward(self, x):
        """Run both stages and return the mean of the squared activations.

        Returns a 0-dim tensor (``mean()`` over all elements).
        """
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv1(x)
        x = self.pool(x)
        # Intentional side effect: Dynamo breaks the graph here.
        print("will trigger graph break, x.shape:", x.shape)
        x = self.norm2(x)
        x = self.act2(x)
        x = self.drop(x)
        x = self.conv2(x)
        out = (x*x).mean()
        return out
def main(
    model: nn.Module,
    *,
    device: str = "cpu",
    steps: int = 10,
    input_shape: tuple = (64, 128, 64, 64),
    backend: str = "inductor",
    use_autocast: bool = True,
):
    """Compile ``model`` and train it for a few steps on random data.

    The caller owns ``model``. Keeping it alive after this function returns
    keeps torch.compile's cache entry for it alive, so tools such as depyf
    can still dump the compiled/guard code.

    Args:
        model: module to train; moved to ``device`` in place. Its forward
            must return a scalar loss tensor.
        device: device for the parameters and the random input batches.
        steps: number of optimizer steps to run.
        input_shape: shape of each random input batch (must fit ``model``).
        backend: ``torch.compile`` backend. ``"eager"`` / ``"aot_eager"``
            produce debuggable graphs.
        use_autocast: run the forward pass under bfloat16 autocast.

    Returns:
        The Python float value of the last step's loss, or ``None`` when
        ``steps`` is 0.
    """
    model.to(device)
    compiled = torch.compile(model=model, backend=backend)
    optimizer = torch.optim.AdamW(compiled.parameters(), lr=1e-4)
    # Autocast must target the device type the tensors actually live on: a
    # CUDA autocast context has no effect on CPU tensors (the old code used
    # torch.cuda.amp.autocast with device='cpu').
    amp_device_type = torch.device(device).type
    loss_value = None
    for _ in range(steps):
        data = torch.randn(input_shape, device=device)
        with torch.autocast(
            device_type=amp_device_type,
            dtype=torch.bfloat16,
            enabled=use_autocast,
        ):
            loss = compiled(data)
        print("loss type", type(loss))
        # Workaround for the depyf bug in this report, where the compiled
        # forward's return value came back wrapped in a list.
        if isinstance(loss, list):
            loss = loss[0]
        loss_value = loss.item()
        print("Loss:", loss_value)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss_value
if __name__ == "__main__":
    # Build the module at script scope rather than inside main() so it
    # outlives the call: the compiled-code cache entry is released together
    # with the module object, and depyf could then only show the captured
    # graph code.
    model = Model()
    import depyf

    debug_dir = "depyf_debug_dir"
    with depyf.prepare_debug(debug_dir):
        main(model)
Then depyf dumps output like the following:
# Note: the following variables are used inside the guard function.
# NOTE(review): depyf writes them as string placeholders holding the repr of
# the runtime objects, so this dump is for reading only. The guard below
# cannot actually be called here, because these names are bound to str.
___check_global_state = '''<built-in method check of torch._C._dynamo.guards.GlobalStateGuard object at 0x16827e690>'''
___check_obj_id = '''<built-in function check_obj_id>'''
___check_tensors = '''<built-in method check of torch._C._dynamo.guards.TensorGuards object at 0x1680fca90>'''
tensor_check_names = '''["L['x']"]'''
utils_device = '''<module 'torch.utils._device' from '/Users/youkaichao/anaconda3/envs/py310/lib/python3.10/site-packages/torch/utils/_device.py'>'''
def __guard_0_for_torch_dynamo_resume_in_forward_at_23(L, G, **___kwargs_ignored):
    """Guard for cache entry 0 of the resume function after the graph break.

    L maps local names to values and G is the globals dict; the dispatcher
    ``transformed___resume_at_50_3`` passes exactly these. Returns True only
    when every check holds, in which case the cached transformed code is used.

    The checks:
    - Global state matches (``___check_global_state``).
    - ``x`` has no ``_dynamo_dynamic_indices`` attribute (presumably: it was
      not marked as having dynamic dimensions).
    - ``self``, its stage-2 submodules and their ``.training`` attributes are
      the same objects (by id) as at compile time. All ``.training`` checks
      use one id, presumably that of the bool singleton seen then.
    - ``utils_device.CURRENT_DEVICE`` is None.
    - The tensor guard on ``x`` passes (presumably dtype/device/shape-style
      checks — TODO confirm).
    """
    return (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (___check_obj_id(L['self'], 4396891824)) \
        and (___check_obj_id(L['self'].training, 4380193976)) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___check_obj_id(L['self'].act2, 4396887840)) \
        and (___check_obj_id(L['self'].act2.training, 4380193976)) \
        and (___check_obj_id(L['self'].drop, 4396887696)) \
        and (___check_obj_id(L['self'].drop.training, 4380193976)) \
        and (___check_obj_id(L['self'].conv2, 4396887600)) \
        and (___check_obj_id(L['self'].conv2.training, 4380193976)) \
        and (___check_obj_id(L['self'].norm2, 4396890048)) \
        and (___check_obj_id(L['self'].norm2.training, 4380193976)) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))
# Note: please refer to the graph code in __compiled_fn_5*.py.
# Captured Graph: Dynamo generated graph (debuggable when using eager backend).
# Joint graph: joint forward+backward graph from aot autograd.
# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).
# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).
# AFTER XXX: graph processed by inductor (not debuggable).
def __compiled_fn_5(*args, **kwargs):
    # Placeholder only: the real compiled graph lives in the
    # __compiled_fn_5*.py files referenced above.
    pass
def __transformed_code_0_for_torch_dynamo_resume_in_forward_at_23(___stack0, self, x):
    """Code run for the resume function when its guard 0 passes.

    The whole body of ``__resume_at_50_3`` (norm2 ... mean) is replaced by a
    single call to the compiled graph ``__compiled_fn_5``. That call returns
    a one-element sequence, which is unpacked into the result.
    ``___stack0`` and ``self`` are not used here.
    """
    out = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    __temp_6, = __compiled_fn_5(x)
    return __temp_6
# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def __resume_at_50_3(___stack0, self, x):
    """Decompiled resume function: the part of ``forward`` after the break.

    ``___stack0`` receives the value left on the stack at the graph break. In
    ``__transformed_code_0_for_forward`` that is the return value of the
    ``print`` call. It is unused here.
    """
    x = self.norm2(x)
    x = self.act2(x)
    x = self.drop(x)
    x = self.conv2(x)
    out = (x * x).mean()
    return out
def transformed___resume_at_50_3(___stack0, self, x):
    """Readable view of the resume function's cache dispatch.

    Tries guard 0 and runs its transformed code on success; otherwise falls
    back to the original resume function.
    """
    __local_dict = {"___stack0": ___stack0, "self": self, "x": x}
    __global_dict = globals()
    if __guard_0_for_torch_dynamo_resume_in_forward_at_23(__local_dict, __global_dict):
        return __transformed_code_0_for_torch_dynamo_resume_in_forward_at_23(___stack0, self, x)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return __resume_at_50_3(___stack0, self, x)
#============ end of __resume_at_50_3 ============#
# Note: the following variables are used inside the guard function.
# NOTE(review): as above, these are string placeholders (reprs of runtime
# objects). The guard is shown for reading and cannot be called from here.
___check_global_state = '''<built-in method check of torch._C._dynamo.guards.GlobalStateGuard object at 0x1669d2ab0>'''
___check_obj_id = '''<built-in function check_obj_id>'''
___check_tensors = '''<built-in method check of torch._C._dynamo.guards.TensorGuards object at 0x1680fc7b0>'''
tensor_check_names = '''["L['x']"]'''
utils_device = '''<module 'torch.utils._device' from '/Users/youkaichao/anaconda3/envs/py310/lib/python3.10/site-packages/torch/utils/_device.py'>'''
def __guard_0_for_forward(L, G, **___kwargs_ignored):
    """Guard for cache entry 0 of ``Model.forward``.

    L maps local names to values and G is the globals dict. Returns True only
    when every check holds:
    - Global state matches.
    - ``x`` has no ``_dynamo_dynamic_indices`` attribute.
    - ``self``, its stage-1 submodules and their ``.training`` attributes are
      the same objects (by id) as at compile time.
    - ``utils_device.CURRENT_DEVICE`` is None.
    - The builtin ``print`` (looked up via ``G['__builtins_dict___1']``) is the
      same object. The transformed code calls it at the graph break.
    - The tensor guard on ``x`` passes.
    """
    return (___check_global_state()) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (___check_obj_id(L['self'], 4396891824)) \
        and (___check_obj_id(L['self'].training, 4380193976)) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___check_obj_id(G['__builtins_dict___1']['print'], 4381267088)) \
        and (___check_obj_id(L['self'].act1, 4396891632)) \
        and (___check_obj_id(L['self'].act1.training, 4380193976)) \
        and (___check_obj_id(L['self'].pool, 4396891728)) \
        and (___check_obj_id(L['self'].pool.training, 4380193976)) \
        and (___check_obj_id(L['self'].conv1, 4396891680)) \
        and (___check_obj_id(L['self'].conv1.training, 4380193976)) \
        and (___check_obj_id(L['self'].norm1, 4396891776)) \
        and (___check_obj_id(L['self'].norm1.training, 4380193976)) \
        and (___check_tensors(L['x'], tensor_check_names=tensor_check_names))
# Note: please refer to the graph code in __compiled_fn_2*.py.
# Captured Graph: Dynamo generated graph (debuggable when using eager backend).
# Joint graph: joint forward+backward graph from aot autograd.
# Forward graph: forward graph from aot autograd (debuggable when using aot_eager backend).
# Backward graph: backward graph from aot autograd (debuggable when using aot_eager backend).
# AFTER XXX: graph processed by inductor (not debuggable).
def __compiled_fn_2(*args, **kwargs):
    # Placeholder only: the real compiled graph lives in the
    # __compiled_fn_2*.py files referenced above.
    pass
def __transformed_code_0_for_forward(self, x):
    """Code run for ``forward`` when ``__guard_0_for_forward`` passes.

    Steps:
    1. Run the compiled graph for the part before the break (norm1 ... pool)
       and take element 0 of its output as ``x``.
    2. Perform the ``print`` eagerly; this is the graph break.
    3. Pass print's return value, ``self`` and ``x`` to ``__resume_at_50_3``.

    The printed shape is baked in as the constant
    ``torch.Size((64, 256, 32, 32))``.
    NOTE(review): ``__builtins_dict___1`` and ``__import_torch`` are not
    defined in this dump; presumably Dynamo provides them in the globals at
    runtime.
    """
    out = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    graph_out_0 = __compiled_fn_2(x)
    x = graph_out_0[0]
    return __resume_at_50_3(__builtins_dict___1['print'](
        'will trigger graph break, x.shape:', __import_torch.Size((64, 256, 32,
        32))), self, x)
# Note: if there is a transformed version below, this function might well not be executed directly. Please check the transformed version if possible.
def forward(self, x):
    """Decompiled original ``Model.forward``, as seen by Dynamo."""
    x = self.norm1(x)
    x = self.act1(x)
    x = self.conv1(x)
    x = self.pool(x)
    print('will trigger graph break, x.shape:', x.shape)
    x = self.norm2(x)
    x = self.act2(x)
    x = self.drop(x)
    x = self.conv2(x)
    out = (x * x).mean()
    return out
def transformed_forward(self, x):
    """Readable view of ``forward``'s cache dispatch.

    Tries guard 0 and runs its transformed code on success; otherwise falls
    back to the original ``forward``.
    """
    __local_dict = {"self": self, "x": x}
    __global_dict = globals()
    if __guard_0_for_forward(__local_dict, __global_dict):
        return __transformed_code_0_for_forward(self, x)
    # Note: this function might well not be executed directly. It might well be transformed again, i.e. adding one more guards and transformed code.
    return forward(self, x)
#============ end of forward ============#
Your current environment
🐛 Describe the bug
The return value of the `forward` function gets wrapped in a list, which is unexpected.