pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

index out of bounds for --compile_prefill with int4 and int8 #14

Open lopuhin opened 10 months ago

lopuhin commented 10 months ago

Running the 13b chat model on an L4 GPU with

python generate.py --checkpoint_path .../model_int4.g32.pth --compile --compile_prefill

the following error occurs:

Traceback (most recent call last):                                                                                                                                                                                                                                            
  File "/home/user/gpt-fast/generate.py", line 407, in <module>                                                                                                                                                                                         
    main(                                                                                                                                                                                                                                                                     
  File "/home/user/gpt-fast/generate.py", line 346, in main                                                                                                                                                                                             
    y, metrics = generate(                                                                                                                                                                                                                                                    
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                                                                                                                
    return func(*args, **kwargs)                                                                                                                                                                                                                                              
  File "/home/user/gpt-fast/generate.py", line 190, in generate                                                                                                                                                                                         
    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)                                                                                                                                 
  File "/home/user/gpt-fast/generate.py", line 62, in decode_n_tokens                                                                                                                                                                                   
    next_token, next_prob = decode_one_token(                                                                                                                                                                                                                                 
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn                                                                                                                                            
    return fn(*args, **kwargs)                                                                                                                                                                                                                                                
  File "/home/user/gpt-fast/generate.py", line 52, in decode_one_token                                                                                                                                                                                  
    def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:                                                                                                                               
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn                                                                                                                                            
    return fn(*args, **kwargs)                                                                                                                                                                                                                                                
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner                                                                                                                                       
    return fn(*args, **kwargs)                                                                                                                                                                                                                                                
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 899, in forward                                                                                                                                   
    return compiled_fn(full_args)                                                                                                                                                                                                                                             
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g                                                                                                                                   
    return f(*args)                                                                                                                                                                                                                                                           
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94, in runtime_wrapper                                                                                                          
    all_outs = call_func_at_runtime_with_args(                                                                                                                                                                                                                                
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args                                                                                                     
    out = normalize_as_list(f(args))                                                                                                                                                                                                                                          
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper                                                                               
    return compiled_fw(args)                                                                                                                                                                                                                                                  
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 861, in __call__                                                                                                                                      
    return self.get_current_callable()(inputs)                                                                                                                                                                                                                                
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 665, in run                                                                                                                                          
    return compiled_fn(new_inputs)                                                                                                                                                                                                                                            
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 380, in deferred_cudagraphify                                                                                                                   
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)                                                                                                                                                                                             
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 408, in cudagraphify
    return manager.add_function(                                                                                                       
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1941, in add_function
    return fn, fn(inputs)                                                                                                              
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1755, in run
    out = self._run(new_inputs, function_id)                                                                                           
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1796, in _run
    return self.run_eager(new_inputs, function_id)                                                                                     
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 1911, in run_eager
    return node.run(new_inputs)                                                                                                        
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 611, in run
    out = self.wrapped_function.model(new_inputs)                                                                                      
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 889, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)                                                                                    
  File "/var/tmp/torchinductor_konstantin_scrapinghub_com/vm/cvm4fmeqocm6isnl6ojvwbbu2fgaupg577dne7od5jgoyyzjij4y.py", line 2469, in call
    triton_poi_fused_mul_silu_6.run(buf201, buf200, 13824, grid=grid(13824), stream=stream0)                    
  File "/home/user/gpt-fast/venv/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 568, in run
    return launcher(                                                                                                                   
  File "<string>", line 8, in launcher                                                                                                 
RuntimeError: Triton Error [CUDA]: device-side assert triggered                                                                        
unknown:0: unknown: block: [0,0,0], thread: [128,0,0] Assertion `index out of bounds: 0 <= tmp4 < 32000` failed.
unknown:0: unknown: block: [0,0,0], thread: [129,0,0] Assertion `index out of bounds: 0 <= tmp4 < 32000` failed.
unknown:0: unknown: block: [0,0,0], thread: [130,0,0] Assertion `index out of bounds: 0 <= tmp4 < 32000` failed.
unknown:0: unknown: block: [0,0,0], thread: [131,0,0] Assertion `index out of bounds: 0 <= tmp4 < 32000` failed.
...

Library versions:

pytorch-triton==2.1.0+6e4932cda8
torch==2.2.0.dev20231201+cu121

It works fine without --compile_prefill.
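For context on the failed assertion in the traceback: 32000 is the Llama-2 vocabulary size, so `0 <= tmp4 < 32000` is the bounds check on a token id being gathered from the embedding table, and the assert firing means a garbage token id reached the embedding lookup. A pure-Python analogue of that check (the `check_token_ids` helper and its name are illustrative, not part of gpt-fast):

```python
# Illustrative analogue of the device-side assert in the traceback:
# every token id must index into the 32000-entry Llama-2 embedding table.
VOCAB_SIZE = 32000  # Llama-2 vocabulary size

def check_token_ids(token_ids):
    """Raise if any id would trip the `0 <= tmp4 < 32000` device assert."""
    bad = [t for t in token_ids if not (0 <= t < VOCAB_SIZE)]
    if bad:
        raise IndexError(f"index out of bounds: {bad}")
    return token_ids

check_token_ids([0, 15, 31999])       # in range: fine
try:
    check_token_ids([31999, 784219])  # garbage id, e.g. read from stale memory
except IndexError as e:
    print("device-side assert analogue:", e)
```

On the GPU this check runs per-thread inside the Triton kernel, which is why the traceback shows one assertion message per thread.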

lopuhin commented 10 months ago

I'm not yet sure whether this is related to having an extra pad token in the model, but if I remove it the error is still there and all the weights have their original shapes. HF is misbehaving, so I can't download the unmodified model right now.

lopuhin commented 10 months ago

Reopening, as this does not look related to the custom tokenizer: I can reproduce the issue with the original llama-2-13b-chat-hf model.

lopuhin commented 10 months ago

A possible fix, inspired by https://github.com/Lightning-AI/lit-gpt/issues/774, is to add a clone() call (committed in https://github.com/pytorch-labs/gpt-fast/commit/636cd767f0fa4d0e10ad456b67219a809f906dc2):

diff --git a/generate.py b/generate.py
index 7f30de0..cb4d7e6 100644
--- a/generate.py
+++ b/generate.py
@@ -161,7 +161,7 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
+    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
         prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)

This removes the error. However, --compile_prefill gives no speedup for a 708-token prompt, even though the prefill slow-down compared to a shorter prompt is clearly noticeable.
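A plausible reading of why the clone() above helps (this is an assumption based on the linked lit-gpt issue, not confirmed in this thread): with CUDA graphs, the compiled prefill returns a tensor that aliases graph-owned memory, and a later graph replay overwrites that memory, so decode gathers a stale or garbage token id and trips the bounds assert. clone() copies the token out of the reused buffer. A pure-Python sketch of the aliasing-vs-copy distinction:

```python
# Sketch of the suspected aliasing bug (names are illustrative).
buffer = [0]  # stands in for graph-owned output memory

def prefill_like():
    buffer[0] = 1234   # write the "next token" into the reused slot
    return buffer      # return a view of the slot, not a copy

next_token = prefill_like()        # aliases the slot
buffer[0] = 784219                 # later graph replay clobbers the slot
stale = next_token[0]              # reads the clobbered value: 784219

next_token = list(prefill_like())  # the .clone() analogue: copy out
buffer[0] = 784219                 # clobbering the slot no longer matters
safe = next_token[0]               # still 1234
print(stale, safe)
```

The diff applies the same idea at the boundary between the compiled prefill and the compiled decode loop, so the decode graph never reads from memory the prefill graph may reuse.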