Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

Empty autocast regions are preserved by Dynamo and given to compilers to optimize #1342

Open IvanYashchuk opened 1 week ago

IvanYashchuk commented 1 week ago

🐛 Bug

Dynamo may create empty graphs, and the normal compilation pipeline then does redundant work on them. Here's an example: an empty function with nothing but an autocast region applied to an empty block of code. I think we should add a step before the splitter that removes empty autocast regions (a sketch of such a pass follows the graph dumps below). @kshitij12345, what are your thoughts on this?

@tfogal discovered a similar empty graph in his NeMo NeVA investigations.

In [1]: from thunder.dynamo import ThunderCompiler

In [2]: import torch

In [3]: def f():
   ...:     with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
   ...:         pass
   ...:     return
   ...: 

In [4]: backend = ThunderCompiler()
/home/iyashchuk/dev/lightning-thunder/thunder/dynamo/compiler.py:19: UserWarning: The ThunderCompiler is in active development and may not work as expected. Please report any issues you encounter to the Lightning Thunder team.
  warnings.warn(

In [5]: jf = torch.compile(backend=backend)(f)

In [6]: jf()

In [7]: backend.subgraph_infos[0].original_graph_module.print_readable()
class GraphModule(torch.nn.Module):
    def forward(self):
        # No stacktrace found for following nodes
        _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = None
        return ()

In [10]: print(backend.subgraph_infos[0].split_graph_module.graph)
graph():
    %inductor_1 : [num_users=0] = call_module[target=inductor_1](args = (), kwargs = {})
    return ()

In [11]: print(backend.subgraph_infos[0].split_graph_module.inductor_1.graph)
graph():
    %_enter_autocast : [num_users=1] = call_function[target=torch.amp.autocast_mode._enter_autocast](args = (cuda, torch.bfloat16, True, None), kwargs = {})
    %_exit_autocast : [num_users=0] = call_function[target=torch.amp.autocast_mode._exit_autocast](args = (%_enter_autocast,), kwargs = {})
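
For concreteness, a minimal sketch of what such a pre-splitter cleanup could look like. The helper name remove_empty_autocast_regions is hypothetical, and it assumes the empty region always shows up as the back-to-back _enter_autocast/_exit_autocast pair printed above:

import torch

def remove_empty_autocast_regions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Erase _enter_autocast/_exit_autocast pairs that wrap no computation,
    # so neither the splitter nor the executors see an empty region.
    for node in list(gm.graph.nodes):
        if node.target is not torch.amp.autocast_mode._exit_autocast:
            continue
        enter_node = node.args[0]
        # The region is empty when the exit immediately follows the enter
        # and nothing else uses the enter node's output.
        if enter_node.next is node and len(enter_node.users) == 1:
            gm.graph.erase_node(node)        # exit has no users, erase it first
            gm.graph.erase_node(enter_node)  # the enter is now unused too
    gm.recompile()
    return gm

Running this over original_graph_module above would leave only the empty return (), so no submodule would be handed to an executor just to enter and exit autocast.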
kshitij12345 commented 2 days ago

> I think we should add a step before the splitter that removes empty autocast regions. @kshitij12345, what are your thoughts on this?

Yes, we should do that. Even Inductor optimizes this away.
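
One way to check that observation (a sketch, assuming a CUDA-capable setup like the repro above) is to compile the same empty function with the stock inductor backend and inspect the generated code via TORCH_LOGS artifact logging:

import torch

def f():
    # Same empty autocast region as in the repro above.
    with torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
        pass

# Run as: TORCH_LOGS="output_code" python check_inductor.py
# (the filename is illustrative). If Inductor indeed optimizes the empty
# region away, the dumped code should contain no autocast enter/exit calls.
torch.compile(f, backend="inductor")()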