Closed xgbj closed 9 months ago
Hello, I found that the `TORCH_COMPILE_DEBUG` parameter can also generate graph code and the corresponding Triton code for debugging. What are the differences between `depyf` and `TORCH_COMPILE_DEBUG`? In what situations should I prefer using `depyf`?
That's a very good question. Thanks for your interest, and thanks for asking!
Let's take this simple example from the README:
```python
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
    main()
```
When you run `TORCH_COMPILE_DEBUG=1 python test.py`, you will get a directory named `torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520`. Inside the directory:
```
.
├── torchdynamo
│   └── debug.log
└── torchinductor
    ├── aot_model___0_debug.log
    ├── aot_model___10_debug.log
    ├── aot_model___11_debug.log
    ├── model__4_inference_10.1
    │   ├── fx_graph_readable.py
    │   ├── fx_graph_runnable.py
    │   ├── fx_graph_transformed.py
    │   ├── ir_post_fusion.txt
    │   ├── ir_pre_fusion.txt
    │   └── output_code.py
    ├── model__5_inference_11.2
    │   ├── fx_graph_readable.py
    │   ├── fx_graph_runnable.py
    │   ├── fx_graph_transformed.py
    │   ├── ir_post_fusion.txt
    │   ├── ir_pre_fusion.txt
    │   └── output_code.py
    └── model___9.0
        ├── fx_graph_readable.py
        ├── fx_graph_runnable.py
        ├── fx_graph_transformed.py
        ├── ir_post_fusion.txt
        ├── ir_pre_fusion.txt
        └── output_code.py
```
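As a side note, if you prefer enabling this from Python rather than through the environment variable, recent PyTorch builds expose the same switch through inductor's trace config. A minimal sketch, assuming your build has `torch._inductor.config.trace.enabled` (check `torch/_inductor/config.py` for your version; this drives the `torchinductor/` part of the dump, while `torchdynamo/debug.log` comes from the logging configuration):

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: trace.enabled is the config flag behind the inductor half of
# TORCH_COMPILE_DEBUG=1 in your PyTorch build; it turns on the
# torch_compile_debug/ artifact dump.
inductor_config.trace.enabled = True

@torch.compile
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8))  # compiling now writes the debug artifacts to disk
```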
When you use `depyf` with the following code:
```python
# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
    import depyf
    with depyf.prepare_debug("depyf_debug_dir"):
        main()
```
After running `python test.py`, you get a directory `depyf_debug_dir` containing these files:
```
.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py
```
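The reason you see three compiled functions, by the way, is the data-dependent branch `if b.sum() < 0:` in `toy_example`, which forces a graph break: Dynamo compiles one graph for the code before the branch and one resume function per branch target. The hand-written sketch below illustrates the split; the names `graph_before_branch`, `resume_negative`, and `resume_nonnegative` are made up, and their real counterparts are `__compiled_fn_0`, `__compiled_fn_3`, and `__compiled_fn_4` in the dumped files (which branch maps to which number depends on which branch runs first):

```python
import torch

def graph_before_branch(a, b):
    # everything Dynamo can trace before the data-dependent `if`
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0

def resume_negative(b, x):
    # resume function for the b.sum() < 0 branch
    b = b * -1
    return x * b

def resume_nonnegative(b, x):
    # resume function for the other branch
    return x * b

def toy_example_split(a, b):
    x, cond = graph_before_branch(a, b)
    return resume_negative(b, x) if cond else resume_nonnegative(b, x)

# sanity check: the split behaves exactly like the original toy_example
a, b = torch.randn(10), torch.randn(10)
x = a / (torch.abs(a) + 1)
expected = x * (b * -1) if b.sum() < 0 else x * b
assert torch.allclose(toy_example_split(a, b), expected)
```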
So what's the difference?

- With `TORCH_COMPILE_DEBUG`, the `torchdynamo/debug.log` is long and difficult to understand. `depyf` decompiles the bytecode into readable source code in the `__transformed_code_xx.py` files.
- With `TORCH_COMPILE_DEBUG`, names like `torchinductor/model__5_inference_11.2` are very difficult to understand, and users have to manually figure out which function inside `torchdynamo/debug.log` corresponds to which directory. In `depyf`, `__compiled_fn_0` and the other functions have exactly the same names as they appear in `torchdynamo/debug.log`.

In summary, `depyf` provides much-improved debugging information for `torch.compile`.
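If you also want to step through the compiled code in a debugger, `depyf` supports a two-phase workflow per its README: first dump the source files with `prepare_debug`, then rerun under `depyf.debug()` so you can set breakpoints in the dumped `.py` files. A minimal sketch, reusing the example above (the directory name is just the one from that example):

```python
import torch
import depyf

@torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
    with depyf.prepare_debug("depyf_debug_dir"):
        main()  # first run: dumps the decompiled source files
    with depyf.debug():  # pauses so you can set breakpoints in depyf_debug_dir/*.py
        main()  # second run: execution stops at your breakpoints
```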
Hope it helps :)
Thank you for your patient explanation. It seems that depyf provides a more convenient solution. I can't wait to use it in upcoming projects.
Feel free to provide any feedback as you use it. I would be happy to help improve your experience with `torch.compile` 😄