pytorch / torchchat

Run PyTorch LLMs locally on servers, desktop and mobile
BSD 3-Clause "New" or "Revised" License

int4_weight_only in Cuda compile := RuntimeError: _apply(): Couldn't swap Linear.weight #1125

Open Jack-Khuu opened 1 month ago

Jack-Khuu commented 1 month ago

🐛 Describe the bug

When generating multiple samples from a compiled int4 model on CUDA, a RuntimeError is raised while swapping Linear.weight:

```
Traceback (most recent call last):
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/utils/__init__.py", line 51, in swap_tensors
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jackkhuu/oss/torchchat/torchchat.py", line 83, in <module>
    generate_main(args)
  File "/home/jackkhuu/oss/torchchat/torchchat/generate.py", line 934, in main
    for _ in gen.chat(generator_args):
  File "/home/jackkhuu/oss/torchchat/torchchat/generate.py", line 826, in chat
    for token_tensor, metrics in generator_func:
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 36, in generator_context
    response = gen.send(None)
  File "/home/jackkhuu/oss/torchchat/torchchat/generate.py", line 518, in generate
    model = model.to(device=device)
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
    return self._apply(convert)
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/home/jackkhuu/.conda/envs/99/lib/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight
```

This only occurs for int4 quantization on CUDA (which narrows it to int4_weight_only) combined with torch.compile.
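A simplified plain-Python model of the failing path may make the chained traceback easier to read. My reading of `torch/nn/modules/module.py` in recent nightlies (an assumption, not confirmed in this thread) is that `Module._apply` switches from assigning `param.data` to calling `torch.utils.swap_tensors` when the converted parameter is a traceable tensor subclass, which would be consistent with only the int4 weight-only path failing. `swap_tensors` refuses to swap a tensor that still has weak references, and `_apply` re-raises that error with the parameter name. `FakeParam`, `swap_tensors_model`, and `apply_to_param` are hypothetical stand-ins, not PyTorch APIs.

```python
import weakref


class FakeParam:
    """Hypothetical stand-in for a weight tensor; plain instances allow weakrefs."""

    def __init__(self, device):
        self.device = device


def swap_tensors_model(t1, t2):
    # Simplified model of the guard at the top of torch.utils.swap_tensors:
    # an object that still has live weak references cannot be swapped.
    if weakref.getweakrefs(t1):
        raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
    if weakref.getweakrefs(t2):
        raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
    # Exchange contents but keep identities, so existing strong references
    # (e.g. the module attribute) now see the converted data.
    t1.__dict__, t2.__dict__ = t2.__dict__, t1.__dict__


def apply_to_param(module_name, key, param, device):
    # Simplified model of the swap branch of nn.Module._apply: the inner
    # error is re-raised with the parameter name, giving the chained traceback.
    param_applied = FakeParam(device)
    try:
        swap_tensors_model(param, param_applied)
    except Exception as e:
        raise RuntimeError(f"_apply(): Couldn't swap {module_name}.{key}") from e
    return param


weight = FakeParam("cuda")
holder = weakref.ref(weight)  # some other object keeps a weak reference
try:
    apply_to_param("Linear", "weight", weight, "cuda")
except RuntimeError as err:
    print(err)            # _apply(): Couldn't swap Linear.weight
    print(err.__cause__)  # Cannot swap t1 because it has weakref associated with it
```

In this model, dropping the weak reference lets the same call succeed, so the open question is which object holds a weakref to the quantized weight.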

Example Command:

```
python3 torchchat.py generate llama3.1 --quantize '{"linear:int4": {"groupsize": 256}, "executor":{"accelerator":"cuda"}}' --compile --num-samples 2
```

Commit: https://github.com/pytorch/torchchat/commit/16dbdd782ae9f0ec2ba53c764ded0b80030172a9

Versions

```
PyTorch version: 2.5.0.dev20240814+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: 18.1.6 (CentOS 18.1.6-3.el9)
CMake version: version 3.30.3
Libc version: glibc-2.34

Python version: 3.10.0 (default, Mar 3 2022, 09:58:08) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.19.0-0_fbk12_hardened_11583_g0bef9520ca2b-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA PG509-210
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.9.4
/usr/lib64/libcudnn_adv_infer.so.8.9.4
/usr/lib64/libcudnn_adv_train.so.8.9.4
/usr/lib64/libcudnn_cnn_infer.so.8.9.4
/usr/lib64/libcudnn_cnn_train.so.8.9.4
/usr/lib64/libcudnn_ops_infer.so.8.9.4
/usr/lib64/libcudnn_ops_train.so.8.9.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 22
Socket(s): 1
Stepping: 11
BogoMIPS: 3591.57
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat umip pku ospke avx512_vnni md_clear arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 704 KiB (22 instances)
L1i cache: 704 KiB (22 instances)
L2 cache: 88 MiB (22 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-21
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable: eIBRS with unprivileged eBPF
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240814+cu121
[pip3] torchao==0.4.0+git477ddb6
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0.dev20240814+cu121 pypi_0 pypi
[conda] torchao 0.4.0+git477ddb6 pypi_0 pypi
```

Jack-Khuu commented 1 month ago

@jerryzh168 Have you seen this before?

Jack-Khuu commented 1 month ago

It seems that model = model.to(device=device) fails the second time it is called. Maybe the cache being populated makes a difference with this quantization?

https://github.com/pytorch/torchchat/blob/main/torchchat/generate.py#L518
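A plain-Python sketch of the hypothesis above, assuming (unconfirmed) that something created during the first compiled run keeps a weak reference to the quantized weight. Under that assumption the first `model.to(device=device)` succeeds and the identical second call is refused by the weakref check. `Weight`, `swap_in_place`, `to_device`, and `to_device_if_needed` are hypothetical names; the last one illustrates one possible mitigation (skip the move when the weight is already on the target device) and is not taken from torchchat.

```python
import weakref


class Weight:
    """Hypothetical stand-in for a quantized (tensor-subclass) weight."""

    def __init__(self, device):
        self.device = device


def swap_in_place(old, new):
    # Mirrors the rule enforced by torch.utils.swap_tensors: live weak
    # references on either object block the swap.
    if weakref.getweakrefs(old) or weakref.getweakrefs(new):
        raise RuntimeError("swap refused: object has a live weakref")
    old.__dict__, new.__dict__ = new.__dict__, old.__dict__


def to_device(weight, device):
    # In this model, the subclass path of Module.to() always swaps in a
    # fresh object, even when the device does not change.
    swap_in_place(weight, Weight(device))


def to_device_if_needed(weight, device):
    # Hypothetical mitigation: skip the swap when nothing would change.
    if weight.device != device:
        to_device(weight, device)


w = Weight("cpu")
to_device(w, "cuda")            # sample 1: no weakrefs yet, so the swap succeeds
guard_ref = weakref.ref(w)      # assumption: the compiled run keeps a weakref to w
try:
    to_device(w, "cuda")        # sample 2: the same move is now refused
except RuntimeError as e:
    print("second move failed:", e)
to_device_if_needed(w, "cuda")  # already on "cuda", so no swap is attempted
print(w.device)                 # cuda
```

Skipping redundant moves only sidesteps the same-device case; if the weakref theory holds, a real device change after compiling would still fail.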

Jack-Khuu commented 1 month ago

This looks like something that may have been picked up in the PyTorch (PT) or torchao (AO) pin bumps: https://github.com/pytorch/torchchat/commit/147c292dd2994be664ea415ea7ff580dcc1fdb3a

jerryzh168 commented 1 month ago

Sorry, I just saw this issue. I haven't seen this error before, and I think we also have a test for to(device="cuda") with int4_weight_only: https://github.com/pytorch/ao/blob/ceec750d7f6f8aefabf6c31e83f139be79ac03b4/test/dtypes/test_affine_quantized.py#L78

Is this still an issue?