ELS-RD / kernl

Kernl lets you run PyTorch transformer models several times faster on GPU with a single line of code, and is designed to be easily hackable.
http://www.kernl.ai
Apache License 2.0

bug: Llama model optimization failing #317

Closed AndrewMead10 closed 1 year ago

AndrewMead10 commented 1 year ago

Description

When trying to use kernl with Llama 7B, I get an error when capturing the graph.

Steps to reproduce

from kernl.model_optimization import optimize_model
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch._dynamo as torchdynamo
import time

torchdynamo.config.cache_size_limit = 512

tokenizer = AutoTokenizer.from_pretrained("path/to/weights")

model = AutoModelForCausalLM.from_pretrained(
    "path/to/weights", torch_dtype=torch.float16, use_cache=True, load_in_8bit=False, use_auth_token=True)
model = model.eval().cuda()
# Apply kernl's optimizations (patches model.forward in place)
optimize_model(model)

text = "Hi how are you doing today?"

input_ids = tokenizer(text, return_tensors="pt",
                      pad_to_multiple_of=64).input_ids

input_ids = input_ids.cuda()

def warmup(model, input_ids, length):
    start = time.perf_counter()
    with torch.inference_mode():
        for i in range(10):
            model.generate(input_ids, max_length=length, min_length=length)
        torch.cuda.synchronize()

    return time.perf_counter() - start

optimized128 = warmup(model, input_ids, 128)

Expected Behavior

An optimized llama model

Actual Behavior

/opt/conda/lib/python3.9/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super().capture_end()
/opt/conda/lib/python3.9/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super().capture_end()
/opt/conda/lib/python3.9/site-packages/torch/cuda/graphs.py:79: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super().capture_end()
Traceback (most recent call last):                                                                                                                                                                                                                      
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 271, in __call__                                                                                                                                                         
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]                                                                                                                                                                         
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                                
  File "<eval_with_key>.149", line 5, in forward                                                                                                                                                                                                        
    to = _stack0.to(device(type='cuda', index=1));  _stack0 = None                                                                                                                                                                                      
RuntimeError: CUDA error: operation not permitted when stream is capturing                                                                                                                                                                              
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.                                                                                                                                                                                     

Call using an FX-traced Module, line 5 of the traced Module's generated forward function:                                                                                                                                                               
def forward(self, _stack0 : torch.Tensor, attention_mask : torch.Tensor):                                                                                                                                                                               
    to = _stack0.to(device(type='cuda', index=1));  _stack0 = None                                                                                                                                                                                      

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getitem = attention_mask[(slice(None, None, None), None, None, slice(None, None, None))];  attention_mask = None

    expand = getitem.expand(1, 1, 8, 8);  getitem = None

Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 328, in cudagraphify_impl
    static_outputs = model(list(static_inputs))
  File "/opt/conda/lib/python3.9/site-packages/kernl/optimizer/cuda_graph.py", line 130, in <lambda>
    model=lambda args: model(*args), inputs=new_inputs, static_input_idxs=tuple(range(len(inputs)))
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 662, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/fx/graph_module.py", line 279, in __call__
    raise e.with_traceback(None)
RuntimeError: CUDA error: operation not permitted when stream is capturing
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

jonathlela commented 1 year ago

Hi @AndrewMead10 ,

Thank you for reporting this bug. I was able to reproduce it and identified two issues: one concerning our use of CUDA graphs and another in our Triton kernels. We'll investigate further in the near future.

hemildesai commented 1 year ago

For now, you can just remove the CUDA graphs utility. I faced a similar issue with GPT2 from Transformers and got around it by changing return cuda_graphs_wrapper(gm, example_inputs) to return gm here. However, I didn't get much speedup after doing this.
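The workaround described in this comment can be sketched as a TorchDynamo-style backend that optionally skips CUDA graph capture. This is only an illustration, not kernl's actual code: `make_backend`, `use_cuda_graphs`, and `graph_wrapper` are hypothetical names, with `graph_wrapper` standing in for the `cuda_graphs_wrapper` mentioned above. The demo at the end runs on CPU and uses `torch.fx.symbolic_trace` to produce a graph module by hand.

```python
from typing import Callable, List, Optional

import torch
import torch.fx


def make_backend(use_cuda_graphs: bool,
                 graph_wrapper: Optional[Callable] = None) -> Callable:
    """Build a backend with the (gm, example_inputs) signature TorchDynamo expects.

    graph_wrapper is a hypothetical stand-in for a CUDA graphs wrapper
    such as the cuda_graphs_wrapper discussed in this thread.
    """
    def backend(gm: torch.fx.GraphModule,
                example_inputs: List[torch.Tensor]) -> Callable:
        if use_cuda_graphs and graph_wrapper is not None:
            # Original behavior: capture the traced module in a CUDA graph.
            return graph_wrapper(gm, example_inputs)
        # Workaround: return the traced module itself, with no graph capture.
        return gm

    return backend


def double_plus_one(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


# CPU-only demo: trace a small function and pass it through the backend by hand.
gm = torch.fx.symbolic_trace(double_plus_one)
example = torch.arange(4, dtype=torch.float32)
compiled = make_backend(use_cuda_graphs=False)(gm, [example])
result = compiled(example)
```

A callable with this signature can be passed as `backend=` to `torch.compile`. With capture disabled the traced graph runs eagerly, so any speedup has to come from the other optimizations, which matches the limited gains reported here.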

jonathlela commented 1 year ago

Yes, optimization takes a very long time (more than 1h) and brings no benefit afterwards. However, I think we'll come up with other optimizations in the future to speed up Llama models.

CorentinJ commented 1 year ago

> For now, you can just remove the CUDA graphs utility. I faced a similar issue with GPT2 from Transformers and got around it by changing return cuda_graphs_wrapper(gm, example_inputs) to return gm here. However, I didn't get much speedup after doing this.

I'm hitting the same issue and getting the same result for GPT2. Is there any update on this problem?