Lucas-TY opened this issue 4 months ago
Hello, may I ask how much memory your device has? You can try decreasing prefill from 124928 to 122880 to see if it still OOMs. The code can run on a single A100-80GB.
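If it helps, here is a minimal pre-flight sketch (my own, not part of the TriForce repo; it assumes PyTorch is installed and a CUDA device is visible) for checking how much headroom the GPU actually has before the run:

```python
# Hypothetical pre-flight check, not part of TriForce:
# print total and currently free memory on CUDA device 0.
import torch

props = torch.cuda.get_device_properties(0)
free_b, total_b = torch.cuda.mem_get_info(0)  # returns (free, total) in bytes
print(f"{props.name}: {total_b / 2**30:.1f} GiB total, {free_b / 2**30:.1f} GiB free")
```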
With prefill 122880 I got this error:
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.85s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00, 2.82it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################
[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in
nvidia-smi
Wed Jun 26 10:40:05 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H800 On | 00000000:61:00.0 Off | 0 |
| N/A 30C P0 70W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
What is your `transformers` version? Can you set it to `transformers==4.37.2`, since the `apply_rotary_pos_emb` API changed in recent versions?
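A quick way to confirm (a sketch on my side, assuming a standard pip environment):

```python
# Check the installed transformers version; if it differs, pin it with
#   pip install transformers==4.37.2
# since apply_rotary_pos_emb's signature changed in later releases.
import transformers
print(transformers.__version__)
```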
Thanks, I think it's more likely an environment problem.
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.65s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00, 3.15it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################
[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
[model verify] capturing graph for spec len 6 (probs=True, temp=0.6, top_p=0.9)...
[Full Cache] Cached: 0 | Budget: 123152
[Retrieval Cache] Budget: 4096 | PreFill: 122880 | Chunk Size: 8 | Chunks: 15360 | Select Sets: 512
[StreamingLLM Cache] Start Size: 16 | Recent Size: 234 | Gamma: 6 | Real Budget: 259 | Cached: 0
tokenized_prompts length: 20
Autoregressive Warmup: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.44s/it]
Autoregressive Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:35<00:00, 35.70s/it]
[Autoregressive] average latency: 31.812156550586224 ms
TriForce Warmup: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:44<00:00, 34.95s/it]
TriForce Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [11:08<00:00, 33.43s/it]
average acceptance rate (NOT per token): 0.6893210335795382
[TriForce] average latency: 17.238558956780214 ms
[E2E Speedup]: 1.8454069525384529
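As a side note, the reported numbers are internally consistent; a quick sanity-check sketch using only values copied from the log above:

```python
# Speedup is the ratio of the two average latencies reported above.
autoregressive_ms = 31.812156550586224
triforce_ms = 17.238558956780214
print(autoregressive_ms / triforce_ms)  # ~1.8454, matching [E2E Speedup]

# The retrieval-cache line also follows directly from the flags:
print(122880 // 8)  # prefill / chunk_size -> 15360 chunks
print(4096 // 8)    # budget  / chunk_size -> 512 select sets
```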
I changed the `transformers` version because there were other environment issues as well.
Yeah, I am using CUDA 12.1. Here is my `flash_attn` version:
>>> import torch
>>> torch.__version__
'2.2.1+cu121'
>>> import flash_attn
>>> flash_attn.__version__
'2.5.7'
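For anyone hitting similar issues, a small sketch (just to confirm the CUDA builds line up; the expected outputs in the comments are from my environment above):

```python
# torch and flash-attn must be built against compatible CUDA toolkits;
# here both correspond to cu121.
import torch
import flash_attn

print(torch.__version__)       # '2.2.1+cu121'
print(torch.version.cuda)      # '12.1'
print(flash_attn.__version__)  # '2.5.7'
```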
I have added a FAQ in the README :)