pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.compile(...) does not pass models containing RNN layers to custom compilers #116858

Open mergian opened 9 months ago

mergian commented 9 months ago

🐛 Describe the bug

torch.compile(...) does not pass models containing RNN layers to custom compiler backends, as the following example shows:

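import torch


class Model1(torch.nn.Module):
    """Model containing a 3-layer RNN (input_size=3, hidden_size=3)."""

    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.RNN(input_size=3, hidden_size=3, num_layers=3)

    def forward(self, x):
        # Returns the tuple (output, h_n).
        return self.rnn(x)


class Model2(torch.nn.Module):
    """Trivial elementwise model used as a control."""

    def forward(self, x):
        return x * x


# Shape (seq_len=1, batch=2, features=3); nn.RNN defaults to batch_first=False.
x = torch.rand(1, 2, 3)
model1 = Model1()
model2 = Model2()


def compiler(gm, example_inputs):
    # Custom backend: print the captured FX graph and run it unchanged.
    print("## --- compiler called -- ##")
    print(gm)
    return gm


cmodel1 = torch.compile(model1, backend=compiler)
cmodel2 = torch.compile(model2, backend=compiler)

with torch.no_grad():
    print("# ---- Running Model1 ---- #")
    cmodel1(x)

    print("")
    print("# ---- Running Model2 ---- #")
    cmodel2(x)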

This generates the following output:

# ---- Running Model1 ---- #

# ---- Running Model2 ---- #
## --- compiler called -- ##
GraphModule()

def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    mul = l_x_ * l_x_;  l_x_ = None
    return (mul,)

# To see more debug info, please use `graph_module.print_readable()`

The output shows that the custom compiler is never called for the RNN model, while it is called as expected for the simple `x*x` model.
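Passing `fullgraph=True` to `torch.compile` makes Dynamo raise an error at a graph break instead of silently falling back to eager execution, which turns the silent behavior above into a visible failure. This is a minimal sketch: the helper names `passthrough_backend` and `captures_full_graph` are hypothetical, the two models are redefined so the block is self-contained, and because the exact exception type raised on a graph break varies across PyTorch versions, the helper catches `Exception` broadly.

```python
import torch


class Model1(torch.nn.Module):
    # Same RNN configuration as Model1 in the report.
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.RNN(input_size=3, hidden_size=3, num_layers=3)

    def forward(self, x):
        return self.rnn(x)


class Model2(torch.nn.Module):
    def forward(self, x):
        return x * x


def passthrough_backend(gm, example_inputs):
    # Hypothetical backend: run the captured FX graph unchanged.
    return gm


def captures_full_graph(model, x):
    """Return True if Dynamo captures model.forward as a single graph.

    With fullgraph=True, Dynamo raises at a graph break instead of
    falling back to eager. Any exception is treated as a graph break
    here, which is a simplification for this sketch.
    """
    compiled = torch.compile(model, backend=passthrough_backend, fullgraph=True)
    try:
        with torch.no_grad():
            compiled(x)
        return True
    except Exception as exc:
        print(f"{type(model).__name__}: {type(exc).__name__}")
        return False


x = torch.rand(1, 2, 3)
print("Model1 single graph:", captures_full_graph(Model1(), x))
print("Model2 single graph:", captures_full_graph(Model2(), x))
```

Going by this report, on 2.1.2 one would expect the `Model1` call to fail with a graph-break error while `Model2` compiles into a single graph; other versions may behave differently.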

Versions

Collecting environment information...
PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 10.5.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.17

Python version: 3.8.18 (default, Nov 21 2023, 09:31:57) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)] (64-bit runtime)
Python platform: Linux-3.10.0-1160.102.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 12.3.103
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro P4000
Nvidia driver version: 535.146.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 24
On-line CPU(s) list: 0-23
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6126 CPU @ 2.60GHz
Stepping: 4
CPU MHz: 3233.972
CPU max MHz: 3700.0000
CPU min MHz: 1000.0000
BogoMIPS: 5200.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 19712K
NUMA node0 CPU(s): 0-23
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba rsb_ctxsw ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke md_clear spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.1.2
[pip3] torchmetrics==0.11.4
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] Could not collect

cc @zou3519 @mikaylagawarecki @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @wconstab

ezyang commented 9 months ago

That's because we don't support capturing RNNs right now; they cause a graph break.
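To illustrate what a graph break means here, the following sketch wraps the RNN call in ordinary tensor ops. The `SandwichModel` class and `recording_backend` function are hypothetical, for illustration only. Where the RNN call breaks the graph, Dynamo would be expected to hand the backend separate graphs for the code before and after it, and to run the RNN itself eagerly. `Model1` in the report contains no ops outside the RNN, so a break there leaves nothing to compile, which is presumably why the custom compiler was never called.

```python
import torch


class SandwichModel(torch.nn.Module):
    """Hypothetical model: tensor ops before and after an RNN call."""

    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.RNN(input_size=3, hidden_size=3, num_layers=3)

    def forward(self, x):
        y = x * 2                # op before the RNN call
        out, _h_n = self.rnn(y)  # runs eagerly if this call graph-breaks
        return out + 1           # op after the RNN call


captured_graphs = []


def recording_backend(gm, example_inputs):
    # Record every FX graph Dynamo hands to the backend, then run it unchanged.
    captured_graphs.append(gm)
    return gm


model = SandwichModel()
compiled = torch.compile(model, backend=recording_backend)

x = torch.rand(1, 2, 3)
with torch.no_grad():
    expected = model(x)     # plain eager reference
    actual = compiled(x)    # goes through Dynamo and the recording backend

print(f"backend called {len(captured_graphs)} time(s)")
for gm in captured_graphs:
    print(gm.code)
```

The number of graphs printed depends on the PyTorch version. If a release captures the RNN without breaking, a single graph would be expected instead. In either case the compiled result should match eager execution.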