pytorch / pytorch


Minifier doesn't transfer execution states like @torch.no_grad to repro #111414

Open HDCharles opened 1 year ago

HDCharles commented 1 year ago

🐛 Describe the bug

If you use something like torch.no_grad or inference_mode in a script you are trying to minify, the minifier produces a repro that does not include that context.

As an example of why this is a problem: in any situation involving pattern matching, without torch.no_grad the graph will record gradients, which causes almost all intermediate node values to be passed to the output. If you have a pattern like (node1, node2) -> fused_node1and2 and node1's output is needed for a gradient, the pattern won't match, because fusing would destroy the output of node1.
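
To make the failure mode concrete, here is a minimal sketch (my illustration, not from the report): under grad mode, autograd keeps intermediate outputs alive for the backward pass, so a fusion pass cannot eliminate them.

import torch

def f(x):
    a = torch.exp(x)  # node1: exp's backward saves its output, so `a` must survive
    return a * 3      # node2

x = torch.randn(4, requires_grad=True)
print(f(x).grad_fn)       # MulBackward0: node1's output is kept for backward

with torch.no_grad():
    print(f(x).grad_fn)   # None: node1's output is only used by node2 and can be fused away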

Error logs

n/a

Minified repro

If you run the minifier on:

import torch

# The whole workload runs under no_grad, including compilation.
with torch.no_grad():
    def do_an_int_mm(x):
        x = torch._int_mm(x, x)
        return x

    x = torch.randint(-128, 127, (16, 16), dtype=torch.int8)

    comp_fn = torch.compile(do_an_int_mm, mode='max-autotune')
    print(comp_fn(x).sum())

you will get a minified repro that is missing torch.no_grad:

from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        x = torch._int_mm(l_x_, l_x_);  l_x_ = None
        return (x,)

mod = Repro()

def load_args(reader):
    buf0 = reader.storage('dc3464694021bc768a42e2357907046d8d41f683', 256, dtype_hint=torch.int8)
    reader.tensor(buf0, (16, 16), dtype=torch.int8, is_leaf=True)  # L_x_
load_args._version = 0

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    run_repro(mod, load_args, accuracy=False, command='minify',
        save_dir='/home/USERNAME/local/torch_compile_debug/run_2023_10_16_20_00_56_432404-pid_2505787/minifier/checkpoints', autocast=False, backend='inductor')
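
A manual workaround (my sketch; the minifier does not emit this) is to restore the dropped context around the generated entry point by hand, reusing the mod and load_args defined above and eliding the machine-specific save_dir:

import torch

if __name__ == '__main__':
    from torch._dynamo.repro.after_dynamo import run_repro
    with torch.no_grad():  # restored by hand; this is what the minifier dropped
        run_repro(mod, load_args, accuracy=False, command='minify',
            save_dir='...', autocast=False, backend='inductor')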

Versions

Collecting environment information...
PyTorch version: 2.2.0a0+git52e1478
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20230605 (Red Hat 11.4.1-2)
Clang version: 16.0.6 (Red Hat 16.0.6-1.el9)
CMake version: version 3.27.0
Libc version: glibc-2.34

Python version: 3.10.12 (main, Jul 5 2023, 18:54:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.12.0-0_fbk16_zion_7661_geb00762ce6d2-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA PG509-210
GPU 1: NVIDIA PG509-210
GPU 2: NVIDIA PG509-210
GPU 3: NVIDIA PG509-210
GPU 4: NVIDIA PG509-210
GPU 5: NVIDIA PG509-210
GPU 6: NVIDIA PG509-210
GPU 7: NVIDIA PG509-210

Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.8.1
/usr/lib64/libcudnn_adv_infer.so.8.8.1
/usr/lib64/libcudnn_adv_train.so.8.8.1
/usr/lib64/libcudnn_cnn_infer.so.8.8.1
/usr/lib64/libcudnn_cnn_train.so.8.8.1
/usr/lib64/libcudnn_ops_infer.so.8.8.1
/usr/lib64/libcudnn_ops_train.so.8.8.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 4
Stepping: 11
Frequency boost: enabled
CPU max MHz: 1801.0000
CPU min MHz: 800.0000
BogoMIPS: 3600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 96 MiB (96 instances)
L3 cache: 132 MiB (4 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-23,96-119
NUMA node1 CPU(s): 24-47,120-143
NUMA node2 CPU(s): 48-71,144-167
NUMA node3 CPU(s): 72-95,168-191
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.4.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] pytorch-lightning==2.0.6
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0a0+gita063238
[pip3] torchmetrics==1.0.2
[pip3] torchvision==0.17.0a0+ace9221
[pip3] triton-nightly==2.1.0.dev20230726014945
[conda] blas 1.0 mkl
[conda] magma-cuda117 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 h06a4308_46342
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.6 py310h1128e8f_1
[conda] mkl_random 1.2.2 py310h1128e8f_1
[conda] numpy 1.24.3 py310h5f9d8c6_1
[conda] numpy-base 1.24.3 py310hb5e798b_1
[conda] pytorch-lightning 2.0.6 pypi_0 pypi
[conda] pytorch-triton 2.1.0+6e4932cda8 pypi_0 pypi
[conda] torch 2.2.0a0+gita063238 dev_0
[conda] torchmetrics 1.0.2 pypi_0 pypi
[conda] torchvision 0.17.0a0+ace9221 dev_0
[conda] triton-nightly 2.1.0.dev20230726014945 pypi_0 pypi

cc @ezyang @anijain2305 @chauhang @penguinwu @msaroufim @bdhirsh @zou3519 @wconstab

jon-chuang commented 1 year ago

@ezyang Maybe we also want to serialize relevant global config like these to dynamo config?

We can use the ConfigModule infra, e.g. torch.config 🤔
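
For reference, a rough sketch of what that could look like with the existing ConfigModule infra, assuming its codegen_config() helper (which dumps non-default entries as executable lines) behaves as described:

import torch._dynamo.config
import torch._inductor.config

# Each ConfigModule can emit its non-default settings as executable Python,
# which the minifier could prepend to the generated repro script.
print(torch._dynamo.config.codegen_config())
print(torch._inductor.config.codegen_config())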

RE: https://github.com/pytorch/pytorch/issues/110682

ezyang commented 1 year ago

no_grad doesn't work this way; for one, it's thread-local, not global. It would be nice if we had a centralized place to look at all of torch's TLS, though.
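
A quick illustration (mine, not from the discussion): a newly spawned thread does not inherit the main thread's no_grad.

import threading
import torch

def worker():
    # Grad mode is thread-local; a new thread starts with the default (enabled).
    print('worker:', torch.is_grad_enabled())  # True

with torch.no_grad():
    print('main:', torch.is_grad_enabled())    # False
    t = threading.Thread(target=worker)
    t.start()
    t.join()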

jon-chuang commented 1 year ago

no_grad doesn't work this way; for one, it's thread-local, not global

Hmm, right, it's TLS in C++: https://github.com/pytorch/pytorch/blob/fa995626a8e181e3666b27fdb4edbe6116b22ee3/c10/core/AutogradState.h#L13

Plus, I guess we capture it if we set it within a compile.

Still not sure what to do about repros.

ezyang commented 1 year ago

I think a good start is to enumerate the set of states we might want to capture, and then successively add setters for the state around the repro in question.
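
One possible shape for those setters (a sketch; the helper name and the set of captured states are hypothetical):

import contextlib
import torch

def run_with_captured_tls(repro_fn, *, grad_enabled=True, inference_mode=False):
    # Re-apply the TLS captured at compile time around the generated repro.
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.set_grad_enabled(grad_enabled))
        if inference_mode:
            stack.enter_context(torch.inference_mode())
        return repro_fn()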

jon-chuang commented 1 year ago

Oh, I see, we do have a guard: ___check_global_state().

import torch

def fn(x):
    x = x.add(x)
    return x

comp_fn = None
with torch.no_grad():
    x = torch.randn((16, 16), requires_grad=True)

    comp_fn = torch.compile(fn, dynamic=False)
    assert comp_fn(x).requires_grad == False  # False

with torch.enable_grad():
    x = torch.randn((16, 16), requires_grad=True)
    assert comp_fn(x).requires_grad == True  # False, as we use the old artifact?

    x = torch.randn((8, 8), requires_grad=True)  # shape change forces a recompile
    assert comp_fn(x).requires_grad == True  # True, as we recompile under the new config?

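One way to watch this guard get installed (a sketch, assuming the 'guards' logging artifact available in recent builds, equivalent to TORCH_LOGS="guards"):

import torch

torch._logging.set_logs(guards=True)  # log guards as frames are compiled

def g(x):
    return x.add(x)

torch.compile(g)(torch.randn(4))  # output includes ___check_global_state()
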
jon-chuang commented 1 year ago

Turns out this guard isn't thread-safe after all: https://github.com/pytorch/pytorch/issues/111569 🥲