Infini-AI-Lab / TriForce

[COLM 2024] TriForce: Lossless Acceleration of Long Sequence Generation with Hierarchical Speculative Decoding
https://infini-ai-lab.github.io/TriForce/

Out of memory on H800 #7

Open Lucas-TY opened 3 weeks ago

Lucas-TY commented 3 weeks ago
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.18it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 124928
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 144, in initialize_cuda_graph
    self.callables[gamma_offset] = draft_run_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 83, in draft_run_capture_graph
    static_logits = engine.draft_run(input_ids=static_input_ids, gamma_offset=gamma_offset, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 54, in draft_run
    logits = self.draft(input_ids=input_ids, kv_cache=self.draft_cache, graph_cache=self.draft_cache, gamma_offset=gamma_offset).logits
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 340, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 301, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 220, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 141, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 
preminstrel commented 3 weeks ago

Hello, may I ask how much memory your device has? You can try decreasing the prefill length from 124928 to 122880 to see if it still OOMs. The code can run on a single A100-80GB.
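
For reference, a back-of-envelope estimate shows why these prefill lengths sit right at the edge of 80 GB. This is a sketch assuming standard Llama-2-7B shapes (32 layers, hidden size 4096) and fp16 KV entries; the helper below is hypothetical, not part of the repo, and ignores weights, activations, and CUDA-graph buffers:

# Rough full-KV-cache size (GiB) for the target model at a given prefill length.
# The factor of 2 accounts for keys + values; 2 bytes per fp16 element.
def kv_cache_gib(seq_len, layers=32, hidden=4096, bytes_per_elem=2):
    return 2 * layers * seq_len * hidden * bytes_per_elem / 1024**3

print(kv_cache_gib(124928))  # ~61.0 GiB
print(kv_cache_gib(122880))  # ~60.0 GiB

Adding roughly 13 GB of fp16 weights for the 7B target puts the 124928-token run within a few GiB of the 80 GB limit, so small differences in allocator overhead can tip it into OOM.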

Lucas-TY commented 3 weeks ago

With 122880 I got this error:


CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.85s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.82it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 154, in initialize_cuda_graph
    self.callable_model_verify = model_verify_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 111, in model_verify_capture_graph
    static_logits = engine.model_verify(input_ids=static_input_ids, position_ids=static_position_ids, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 63, in model_verify
    logits = self.model(input_ids=input_ids, kv_cache=self.kv_cache, graph_cache=self.graph_cache, position_ids=position_ids, spec=True).logits
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 396, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 355, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 276, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 222, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 182, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (32) must match the size of tensor b (131072) at non-singleton dimension 1

Lucas-TY commented 3 weeks ago
nvidia-smi
Wed Jun 26 10:40:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H800                    On  | 00000000:61:00.0 Off |                    0 |
| N/A   30C    P0              70W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
preminstrel commented 3 weeks ago

What is your transformers version? Can you pin it to transformers==4.37.2? The apply_rotary_pos_emb API changed in recent versions.
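
Assuming a pip-managed environment, that would be:

pip install transformers==4.37.2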

Lucas-TY commented 3 weeks ago

Thanks, I think it was more likely an environment problem. After pinning the version it runs:

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.15it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
[model verify] capturing graph for spec len 6 (probs=True, temp=0.6, top_p=0.9)...
[Full Cache] Cached: 0 | Budget: 123152
[Retrieval Cache] Budget: 4096  | PreFill: 122880  | Chunk Size: 8  | Chunks: 15360  | Select Sets: 512
[StreamingLLM Cache] Start Size: 16 | Recent Size: 234 | Gamma: 6 | Real Budget: 259 | Cached: 0
tokenized_prompts length: 20
Autoregressive Warmup: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.44s/it]
Autoregressive Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:35<00:00, 35.70s/it]
[Autoregressive] average latency: 31.812156550586224 ms
TriForce Warmup: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:44<00:00, 34.95s/it]
TriForce Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [11:08<00:00, 33.43s/it]
average acceptance rate (NOT per token): 0.6893210335795382
[TriForce] average latency: 17.238558956780214 ms
[E2E Speedup]: 1.8454069525384529
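
As a sanity check, the reported E2E speedup is just the ratio of the two average per-token latencies from the log:

# 31.81 ms/token autoregressive vs 17.24 ms/token with TriForce
print(31.812156550586224 / 17.238558956780214)  # ~1.845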
Lucas-TY commented 3 weeks ago

I changed the transformers version because there were other environment issues, such as:

preminstrel commented 3 weeks ago

Yeah, I am using CUDA 12.1. Here are my torch and flash_attn versions.

>>> import torch
>>> torch.__version__
'2.2.1+cu121'
>>> import flash_attn
>>> flash_attn.__version__
'2.5.7'
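
For anyone reproducing this setup, a typical install sequence would be the following (a sketch assuming pip; the index URL is PyTorch's standard cu121 wheel channel, and flash-attn's documented install uses --no-build-isolation):

pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cu121
pip install flash-attn==2.5.7 --no-build-isolation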
preminstrel commented 3 weeks ago

I have added an FAQ to the README :)