pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[OOM] Unable to convert 30B model to ONNX using 4x A100s #103089

Open aleph65 opened 1 year ago

aleph65 commented 1 year ago

🐛 Describe the bug

I am unable to convert a 30B model to ONNX. Even with 4x A100s, 500 GB of RAM, and 2.5 TB of storage, the conversion still runs out of memory.

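For context, a rough estimate of how much memory the weights alone need. This assumes exactly 30 × 10^9 parameters (checkpoints labelled "30B" can have somewhat more) and ignores activations, tracing overhead, and any extra copies the exporter may hold. Even in fp16 the weights are far above the 2 GiB protobuf limit of a single .onnx file, so the export has to use ONNX's external-data format for the tensors.

```python
# Back-of-the-envelope weight-memory estimate for a nominal 30B-parameter model.
# Assumptions: exactly 30e9 parameters; activations and exporter overhead ignored.

BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}
ONNX_PROTOBUF_LIMIT_BYTES = 2 * 1024**3  # a single ONNX protobuf file cannot exceed 2 GiB


def weight_bytes(num_params: float, dtype: str) -> float:
    """Bytes needed just to hold the weights in the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype]


num_params = 30e9
for dtype in ("fp16", "fp32"):
    gb = weight_bytes(num_params, dtype) / 1e9
    print(f"{dtype}: ~{gb:.0f} GB of weights")  # fp16: ~60 GB, fp32: ~120 GB

# Weights larger than the protobuf limit must be stored as ONNX external data.
exceeds_limit = weight_bytes(num_params, "fp16") > ONNX_PROTOBUF_LIMIT_BYTES
print("needs ONNX external data:", exceeds_limit)
```

If the exporter materializes the model in fp32 and also builds the ONNX graph in memory, the peak can be a multiple of these figures; the estimate is only a lower bound.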

Here's the repro:

I believe this is reproducible in any container, but here are the container setup steps:

1) Create a container template on RunPod from the winglian/axolotl-runpod:main-py3.9-cu118-2.0.0 image.

Then deploy 4x A100s in Secure Cloud, searching for the template you just created:

[Screenshot not included]

2) Once the pod has started, open a terminal and run:

mkdir tmp && ln -s /workspace/tmp /tmp
pip install optimum && pip install onnx && pip install onnxruntime-gpu
git lfs install
git clone https://huggingface.co/ehartford/WizardLM-30B-Uncensored
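The mkdir tmp && ln -s /workspace/tmp /tmp line appears intended to put temporary files on the larger /workspace volume. Note that when /tmp already exists as a directory (as it normally does), ln -s creates a link at /tmp/tmp rather than replacing /tmp, so temporary files would still land on the container's root disk. Assuming the export writes its intermediate files through Python's tempfile module, a minimal sketch of redirecting it from inside the script follows; redirect_tempdir is a hypothetical helper, and /workspace/tmp is an assumed path. Exporting TMPDIR=/workspace/tmp in the shell before starting Python should have a similar effect.

```python
import os
import tempfile


def redirect_tempdir(path: str) -> str:
    """Point Python's tempfile module (and later child processes) at `path`.

    tempfile caches its chosen directory after the first lookup, so call this
    before anything creates temporary files.
    """
    os.makedirs(path, exist_ok=True)
    os.environ["TMPDIR"] = path  # inherited by subprocesses started afterwards
    tempfile.tempdir = path      # overrides the cached value in this process
    return tempfile.gettempdir()


# Usage at the top of the conversion script (hypothetical path on the volume):
# redirect_tempdir("/workspace/tmp")
```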

3) Create the following conversion script with vim:

touch fp16_to_onnx.py
vim fp16_to_onnx.py

Paste this:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from optimum.onnxruntime import ORTModelForCausalLM
import argparse

parser = argparse.ArgumentParser(description="Convert fp16 model to onnx")
parser.add_argument("model_dir", type=str, help="fp16 model folder")
parser.add_argument("--device", type=str, default="cuda:0", help="device")

args = parser.parse_args()

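model_dir = args.model_dir

# Device for the ONNX Runtime session, taken from --device
device = torch.device(args.device)

tokenizer = AutoTokenizer.from_pretrained(model_dir)

save_directory = "onnx_wiz/"
print("Loading")
# export=True converts the PyTorch checkpoint to ONNX before loading it into an
# ONNX Runtime session; .to(device) then moves that session to the GPU
ort_model = ORTModelForCausalLM.from_pretrained(
    model_dir, export=True).to(device)

print("Saving")
# Write the exported ONNX model and its config, plus the tokenizer files
ort_model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)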

To save and exit vim, press Esc, then Shift+Z twice (ZZ).
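One alternative to the script above, offered as an untested suggestion, is Optimum's command-line exporter, which writes the ONNX files straight to disk and can trace in half precision on the GPU. As I understand Optimum's CLI help, --fp16 requires --device cuda; the task is left for Optimum to infer. Whether a ~60 GB fp16 model plus tracing overhead fits on a single A100 depends on whether it is the 40 GB or 80 GB variant, which is not verified here. The helper build_export_cmd is hypothetical; it only assembles the command line.

```python
import shlex
from typing import List


def build_export_cmd(model_dir: str, output_dir: str, fp16: bool = True,
                     device: str = "cuda") -> List[str]:
    """Assemble an `optimum-cli export onnx` invocation as an argument list."""
    cmd = ["optimum-cli", "export", "onnx", "--model", model_dir]
    if fp16:
        cmd.append("--fp16")  # half-precision export; requires a CUDA device
    cmd += ["--device", device, output_dir]
    return cmd


cmd = build_export_cmd("WizardLM-30B-Uncensored", "onnx_wiz/")
print(shlex.join(cmd))
# prints: optimum-cli export onnx --model WizardLM-30B-Uncensored --fp16 --device cuda onnx_wiz/
# To run it (needs optimum installed): subprocess.run(cmd, check=True)
```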

4) Now, run the conversion:

python fp16_to_onnx.py WizardLM-30B-Uncensored

This takes about 45 minutes, which already seems wrong to me; I would expect about 5 minutes. For comparison, gpt2 converts in 30 seconds.
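As a rough sanity check on the timing, scaling the gpt2 figure by parameter count gives the sketch below. It assumes gpt2 (small) has about 124M parameters, the target has a nominal 30B, and export time grows linearly with model size; all three are approximations, and real export time also depends on I/O, graph size, and memory pressure.

```python
# Rough linear-scaling check of export time.
# Assumptions: gpt2 small ~124M params, target ~30B params, time linear in size.
GPT2_PARAMS = 124e6
TARGET_PARAMS = 30e9
GPT2_EXPORT_SECONDS = 30  # as reported above


def linear_estimate_seconds(target_params: float, ref_params: float,
                            ref_seconds: float) -> float:
    """Scale a reference export time linearly by parameter count."""
    return ref_seconds * target_params / ref_params


est = linear_estimate_seconds(TARGET_PARAMS, GPT2_PARAMS, GPT2_EXPORT_SECONDS)
print(f"scale factor: {TARGET_PARAMS / GPT2_PARAMS:.0f}x")
print(f"linear estimate: {est / 60:.0f} minutes")
```

Under these assumptions the estimate comes out at roughly two hours, so the 45-minute runtime on its own may not indicate a separate problem.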

Then it fails with the following error:

[Error screenshot not included]

Can you please help me get unblocked? I have been trying to convert this model to ONNX for days.

Many thanks

Versions

CPU min MHz:                     1500.0000
BogoMIPS:                        5600.16
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                  AMD-V
L1d cache:                       2 MiB (64 instances)
L1i cache:                       2 MiB (64 instances)
L2 cache:                        32 MiB (64 instances)
L3 cache:                        512 MiB (16 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-31,64-95
NUMA node1 CPU(s):               32-63,96-127
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1+cu118
[pip3] torchvision==0.15.1+cu118
[pip3] triton==2.0.0
[conda] No relevant packages
MaanavD commented 6 months ago

Hey @aleph65 ,

Is this still an issue you're interested in solving? If so, you could try exporting with the torch.onnx.dynamo_export() API, and if that doesn't work, you could also try fake mode.

Let me know if you have any success! Will otherwise close as stale in 2-3 weeks :)