ELS-RD / kernl

Kernl lets you run PyTorch transformer models several times faster on GPU with a single line of code, and is designed to be easily hackable.
http://www.kernl.ai
Apache License 2.0

bug: Could not get kernl running on CodeT5 #283

Closed. TheSeamau5 closed this issue 1 year ago.

TheSeamau5 commented 1 year ago

Description

I called optimize_model on the encoder and decoder of the CodeT5 model: https://huggingface.co/Salesforce/codet5-large-ntp-py

After that, the call to model.generate hangs instead of returning a completion.
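
One way to see where a hanging call is stuck, without interrupting it by hand, is Python's standard faulthandler module. The following is a minimal sketch, not part of kernl: run_with_stack_dump and slow_call are hypothetical names used only for illustration, and slow_call stands in for the model.generate call from the script below.

```python
import faulthandler
import tempfile
import time


def run_with_stack_dump(fn, timeout_s):
    """Call fn(); if it is still running after timeout_s seconds, record every thread's stack.

    Returns (result, dump_text). dump_text is empty when fn finished in time.
    """
    with tempfile.TemporaryFile("w+") as dump_file:
        # faulthandler writes the stacks to the file from a watchdog thread
        # once the timeout expires; exit=False lets the call keep running.
        faulthandler.dump_traceback_later(timeout_s, file=dump_file, exit=False)
        try:
            result = fn()
        finally:
            # Always disarm the watchdog so it cannot fire after we return.
            faulthandler.cancel_dump_traceback_later()
        dump_file.seek(0)
        return result, dump_file.read()


def slow_call():
    # Hypothetical stand-in for model.generate(**inputs, ...) in the report.
    time.sleep(0.5)
    return "done"
```

Calling run_with_stack_dump(slow_call, 0.1) returns the result together with a text dump that names the frames active when the timeout expired. For a call that truly never returns, passing file=sys.stderr to faulthandler.dump_traceback_later directly would print the stacks while the process is still stuck.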

Steps to reproduce

Code to reproduce

# Standard Library imports
import os
import sys
import time

# Third-party Library imports
from rich import print
from rich.markdown import Markdown
from rich.syntax import Syntax
import torch
import torch._dynamo as torchdynamo
from transformers import AutoTokenizer, AutoModelForCausalLM, T5ForConditionalGeneration
from kernl.model_optimization import optimize_model

# Model Name
MODEL_NAME = "Salesforce/codet5-large-ntp-py"

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Download tokenizer from HuggingFace
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Download raw model from HuggingFace
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).eval().cuda()

# default cache size needs to be increased to store the many graphs with generative models
# torchdynamo.config.cache_size_limit = 512

# Optimize the model with kernl
optimize_model(model.encoder)
optimize_model(model.decoder)

input_prompt = """
# List all currently running unencrypted EC2 instances
# An unencrypted EC2 instance is an instance with at least one unencrypted EBS volume
# Step 1: List all unencrypted EBS volumes and get the list of attached EC2 instances
# Step 2: Return the set of unique EC2 instances with at least one unencrypted EBS volume
import boto3
""".strip()

# Tokenize the prompt 
inputs = tokenizer(
    input_prompt, 
    return_tensors="pt", 
    # pad_to_multiple_of=8, 
    # padding=True
).to(device)

# This is probably a hack
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
with torch.inference_mode(), torch.cuda.amp.autocast():
    torch.cuda.synchronize()
    start_time = time.perf_counter()

    # Generate the completion
    print("Generate completion")
    model_output = model.generate(
        **inputs, 
        # max_length=512
        min_length=22, max_length=22
    )
    print("Completion generated")
    torch.cuda.synchronize()
    completion = tokenizer.decode(model_output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    elapsed_time = time.perf_counter() - start_time
    print(f"Time Elapsed: {round(elapsed_time, 2)}s")
    print(Markdown("## Generated Code"))
    print(Syntax(input_prompt + completion, "python", theme="xcode", line_numbers=True))

Expected Behavior

If you comment out the two optimize_model lines, generation completes and prints the expected output:

Generate completion
Completion generated
Time Elapsed: 1.1s
Generated Code
  1 # List all currently running unencrypted EC2 instances
  2 # An unencrypted EC2 instance is an instance with at least one unencrypted EBS volume
  3 # Step 1: List all unencrypted EBS volumes and get the list of attached EC2 instances
  4 # Step 2: Return the set of unique EC2 instances with at least one unencrypted EBS volume
  5 import boto3
  6
  7 ec2 = boto3.resource('ec2')
  8
  9 # Step 3: List

Actual Behavior

With the optimize_model lines enabled, the call never returns. Warning and traceback printed after a keyboard interrupt:

/home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch/cuda/graphs.py:82: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super(CUDAGraph, self).capture_end()
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>                                                                                      │
│                                                                                                  │
│   19 │                                                                                           │
│   20 │   # Generate the completion                                                               │
│   21 │   print("Generate completion")                                                            │
│ ❱ 22 │   model_output = model.generate(                                                          │
│   23 │   │   **inputs,                                                                           │
│   24 │   │   # max_length=512                                                                    │
│   25 │   │   min_length=22, max_length=22                                                        │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /utils/_contextlib.py:115 in decorate_context                                                    │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trans │
│ formers/generation/utils.py:1391 in generate                                                     │
│                                                                                                  │
│   1388 │   │   │   │   )                                                                         │
│   1389 │   │   │                                                                                 │
│   1390 │   │   │   # 11. run greedy search                                                       │
│ ❱ 1391 │   │   │   return self.greedy_search(                                                    │
│   1392 │   │   │   │   input_ids,                                                                │
│   1393 │   │   │   │   logits_processor=logits_processor,                                        │
│   1394 │   │   │   │   stopping_criteria=stopping_criteria,                                      │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trans │
│ formers/generation/utils.py:2179 in greedy_search                                                │
│                                                                                                  │
│   2176 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  │
│   2177 │   │   │                                                                                 │
│   2178 │   │   │   # forward pass to get next token                                              │
│ ❱ 2179 │   │   │   outputs = self(                                                               │
│   2180 │   │   │   │   **model_inputs,                                                           │
│   2181 │   │   │   │   return_dict=True,                                                         │
│   2182 │   │   │   │   output_attentions=output_attentions,                                      │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /nn/modules/module.py:1488 in _call_impl                                                         │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trans │
│ formers/models/t5/modeling_t5.py:1663 in forward                                                 │
│                                                                                                  │
│   1660 │   │   │   │   decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_de  │
│   1661 │   │                                                                                     │
│   1662 │   │   # Decode                                                                          │
│ ❱ 1663 │   │   decoder_outputs = self.decoder(                                                   │
│   1664 │   │   │   input_ids=decoder_input_ids,                                                  │
│   1665 │   │   │   attention_mask=decoder_attention_mask,                                        │
│   1666 │   │   │   inputs_embeds=decoder_inputs_embeds,                                          │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /nn/modules/module.py:1488 in _call_impl                                                         │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /_dynamo/eval_frame.py:211 in _fn                                                                │
│                                                                                                  │
│   208 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   209 │   │   │   dynamic_ctx.__enter__()                                                        │
│   210 │   │   │   try:                                                                           │
│ ❱ 211 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   212 │   │   │   finally:                                                                       │
│   213 │   │   │   │   set_eval_frame(prior)                                                      │
│   214 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/kernl │
│ /model_optimization.py:64 in run                                                                 │
│                                                                                                  │
│   61 │                                                                                           │
│   62 │   @torchdynamo.optimize(_compiler)                                                        │
│   63 │   def run(*args, **kwargs):                                                               │
│ ❱ 64 │   │   return model.forward_original(*args, **kwargs)                                      │
│   65 │                                                                                           │
│   66 │   model.forward = run                                                                     │
│   67                                                                                             │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trans │
│ formers/models/t5/modeling_t5.py:980 in forward                                                  │
│                                                                                                  │
│    977 │   │                                                                                     │
│    978 │   │   # We can provide a self-attention mask of dimensions [batch_size, from_seq_lengt  │
│    979 │   │   # ourselves in which case we just need to make it broadcastable to all heads.     │
│ ❱  980 │   │   extended_attention_mask = self.get_extended_attention_mask(attention_mask, input  │
│    981 │   │                                                                                     │
│    982 │   │   # If a 2D or 3D attention mask is provided for the cross-attention                │
│    983 │   │   # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_lengt  │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trans │
│ formers/models/t5/modeling_t5.py:989 in <graph break in forward>                                 │
│                                                                                                  │
│    986 │   │   │   encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)          │
│    987 │   │   │   if encoder_attention_mask is None:                                            │
│    988 │   │   │   │   encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_  │
│ ❱  989 │   │   │   encoder_extended_attention_mask = self.invert_attention_mask(encoder_attenti  │
│    990 │   │   else:                                                                             │
│    991 │   │   │   encoder_extended_attention_mask = None                                        │
│    992                                                                                           │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trans │
│ formers/models/t5/modeling_t5.py:989 in <graph break in forward>                                 │
│                                                                                                  │
│    986 │   │   │   encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)          │
│    987 │   │   │   if encoder_attention_mask is None:                                            │
│    988 │   │   │   │   encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_  │
│ ❱  989 │   │   │   encoder_extended_attention_mask = self.invert_attention_mask(encoder_attenti  │
│    990 │   │   else:                                                                             │
│    991 │   │   │   encoder_extended_attention_mask = None                                        │
│    992                                                                                           │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /_dynamo/eval_frame.py:211 in _fn                                                                │
│                                                                                                  │
│   208 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   209 │   │   │   dynamic_ctx.__enter__()                                                        │
│   210 │   │   │   try:                                                                           │
│ ❱ 211 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   212 │   │   │   finally:                                                                       │
│   213 │   │   │   │   set_eval_frame(prior)                                                      │
│   214 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/kernl │
│ /optimizer/cuda_graph.py:128 in run                                                              │
│                                                                                                  │
│   125 │   │   nonlocal compiled_fn                                                               │
│   126 │   │   if compiled_fn is None:                                                            │
│   127 │   │   │   with dynamo_utils.preserve_rng_state():                                        │
│ ❱ 128 │   │   │   │   model(*new_inputs)  # additional warmup needed when input is mutated by    │
│   129 │   │   │   │   f = cudagraphify_impl(                                                     │
│   130 │   │   │   │   │   model=lambda args: model(*args), inputs=new_inputs, static_input_idx   │
│   131 │   │   │   │   )                                                                          │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /fx/graph_module.py:660 in call_wrapped                                                          │
│                                                                                                  │
│   657 │   │   │   cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined   │
│   658 │   │                                                                                      │
│   659 │   │   def call_wrapped(self, *args, **kwargs):                                           │
│ ❱ 660 │   │   │   return self._wrapped_call(self, *args, **kwargs)                               │
│   661 │   │                                                                                      │
│   662 │   │   cls.__call__ = call_wrapped                                                        │
│   663                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /fx/graph_module.py:269 in __call__                                                              │
│                                                                                                  │
│   266 │   │   │   if self.cls_call is not None:                                                  │
│   267 │   │   │   │   return self.cls_call(obj, *args, **kwargs)                                 │
│   268 │   │   │   else:                                                                          │
│ ❱ 269 │   │   │   │   return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[mi   │
│   270 │   │   except Exception as e:                                                             │
│   271 │   │   │   assert e.__traceback__                                                         │
│   272 │   │   │   topmost_framesummary: traceback.FrameSummary = \                               │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /nn/modules/module.py:1488 in _call_impl                                                         │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│ in forward                                                                                       │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/kernl │
│ /optimizer/attention.py:51 in attention_wrapper                                                  │
│                                                                                                  │
│    48 │   │   else:                                                                              │
│    49 │   │   │   attention_reference(q, k, v, output, sm_scale, is_causal=is_causal, attentio   │
│    50 │   else:                                                                                  │
│ ❱  51 │   │   attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask   │
│    52 │                                                                                          │
│    53 │   if extend_head:                                                                        │
│    54 │   │   output = output.squeeze(dim=1)                                                     │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/kernl │
│ /implementations/attention.py:550 in attention_forward                                           │
│                                                                                                  │
│   547 │   is_causal: bool = False,                                                               │
│   548 │   attention_mask: Optional[torch.Tensor] = None,                                         │
│   549 ):                                                                                         │
│ ❱ 550 │   return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)           │
│   551                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /autograd/function.py:508 in apply                                                               │
│                                                                                                  │
│   505 │   │   if not torch._C._are_functorch_transforms_active():                                │
│   506 │   │   │   # See NOTE: [functorch vjp and autograd interaction]                           │
│   507 │   │   │   args = _functorch.utils.unwrap_dead_wrappers(args)                             │
│ ❱ 508 │   │   │   return super().apply(*args, **kwargs)                                          │
│   509 │   │                                                                                      │
│   510 │   │   if cls.setup_context == _SingleLevelFunction.setup_context:                        │
│   511 │   │   │   raise RuntimeError(                                                            │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch │
│ /cuda/amp/autocast_mode.py:105 in decorate_fwd                                                   │
│                                                                                                  │
│   102 │   │   │   args[0]._fwd_used_autocast = False                                             │
│   103 │   │   │   if autocast_context:                                                           │
│   104 │   │   │   │   with autocast(enabled=False):                                              │
│ ❱ 105 │   │   │   │   │   return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))    │
│   106 │   │   │   else:                                                                          │
│   107 │   │   │   │   return fwd(*args, **kwargs)                                                │
│   108 │   return decorate_fwd                                                                    │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/kernl │
│ /implementations/attention.py:511 in forward                                                     │
│                                                                                                  │
│   508 │   │   │   HAS_MASK = True                                                                │
│   509 │   │   │   IS_MATRIX_MASK = attention_mask.size(2) != 1                                   │
│   510 │   │                                                                                      │
│ ❱ 511 │   │   _fwd_kernel[grid](  # can't use name args because of the way autotune is impleme   │
│   512 │   │   │   head_size,  # heads                                                            │
│   513 │   │   │   m_size,  # m_size                                                              │
│   514 │   │   │   n_size,  # n_size                                                              │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/runtime/jit.py:106 in launcher                                                                 │
│                                                                                                  │
│   103 │   │   memorizes the grid.                                                                │
│   104 │   │   """                                                                                │
│   105 │   │   def launcher(*args, **kwargs):                                                     │
│ ❱ 106 │   │   │   return self.run(*args, grid=grid, **kwargs)                                    │
│   107 │   │   return launcher                                                                    │
│   108                                                                                            │
│   109                                                                                            │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/runtime/autotuner.py:73 in run                                                                 │
│                                                                                                  │
│    70 │   │   │   │   # prune configs                                                            │
│    71 │   │   │   │   pruned_configs = self.prune_configs(kwargs)                                │
│    72 │   │   │   │   bench_start = time.time()                                                  │
│ ❱  73 │   │   │   │   timings = {config: self._bench(*args, config=config, **kwargs)             │
│    74 │   │   │   │   │   │      for config in pruned_configs}                                   │
│    75 │   │   │   │   bench_end = time.time()                                                    │
│    76 │   │   │   │   self.bench_time = bench_end - bench_start                                  │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/runtime/autotuner.py:73 in <dictcomp>                                                          │
│                                                                                                  │
│    70 │   │   │   │   # prune configs                                                            │
│    71 │   │   │   │   pruned_configs = self.prune_configs(kwargs)                                │
│    72 │   │   │   │   bench_start = time.time()                                                  │
│ ❱  73 │   │   │   │   timings = {config: self._bench(*args, config=config, **kwargs)             │
│    74 │   │   │   │   │   │      for config in pruned_configs}                                   │
│    75 │   │   │   │   bench_end = time.time()                                                    │
│    76 │   │   │   │   self.bench_time = bench_end - bench_start                                  │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/runtime/autotuner.py:63 in _bench                                                              │
│                                                                                                  │
│    60 │   │   │   │   config.pre_hook(self.nargs)                                                │
│    61 │   │   │   self.hook(args)                                                                │
│    62 │   │   │   self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,   │
│ ❱  63 │   │   return do_bench(kernel_call)                                                       │
│    64 │                                                                                          │
│    65 │   def run(self, *args, **kwargs):                                                        │
│    66 │   │   self.nargs = dict(zip(self.arg_names, args))                                       │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/testing.py:177 in do_bench                                                                     │
│                                                                                                  │
│   174 │   │   cache.zero_()                                                                      │
│   175 │   │   # record time of `fn`                                                              │
│   176 │   │   start_event[i].record()                                                            │
│ ❱ 177 │   │   fn()                                                                               │
│   178 │   │   end_event[i].record()                                                              │
│   179 │   # Record clocks                                                                        │
│   180 │   torch.cuda.synchronize()                                                               │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/runtime/autotuner.py:62 in kernel_call                                                         │
│                                                                                                  │
│    59 │   │   │   if config.pre_hook:                                                            │
│    60 │   │   │   │   config.pre_hook(self.nargs)                                                │
│    61 │   │   │   self.hook(args)                                                                │
│ ❱  62 │   │   │   self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages,   │
│    63 │   │   return do_bench(kernel_call)                                                       │
│    64 │                                                                                          │
│    65 │   def run(self, *args, **kwargs):                                                        │
│                                                                                                  │
│ /home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/trito │
│ n/runtime/autotuner.py:200 in run                                                                │
│                                                                                                  │
│   197 │   def run(self, *args, **kwargs):                                                        │
│   198 │   │   for v, heur in self.values.items():                                                │
│   199 │   │   │   kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})                │
│ ❱ 200 │   │   return self.fn.run(*args, **kwargs)                                                │
│   201                                                                                            │
│   202                                                                                            │
│   203 def heuristics(values):                                                                    │
│ in _fwd_kernel                                                                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyboardInterrupt

Your environment

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+



### Self-service

- [x] I would be willing to help fix this bug myself.

### Code of Conduct

- [X] I agree to follow this project's Code of Conduct
jonathlela commented 1 year ago

Hi @TheSeamau5 ,

How long does it hang? Without a warm-up phase, model.generate takes several minutes to finish. I tried this model on my side on an A10: after a warm-up phase of ~5 minutes, subsequent calls take less than a second.
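
To tell a real hang from a slow first call, it can help to time the first call separately from the following ones. This is only a minimal sketch: `time_calls` and `fake_generate` are hypothetical names, and `fake_generate` stands in for the real `model.generate(**inputs, ...)` call.

```python
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], repeats: int = 3) -> List[float]:
    """Run `fn` `repeats` times and return the wall-clock duration of each call."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        # For GPU work, `fn` should call torch.cuda.synchronize() before returning,
        # otherwise the clock may stop before the kernels have finished.
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def fake_generate() -> None:
    """Hypothetical stand-in for `model.generate(**inputs, ...)`."""
    time.sleep(0.01)


durations = time_calls(fake_generate)
print(f"first call: {durations[0]:.3f}s, later calls: {durations[1:]}")
```

With the real model, a first duration of several minutes followed by sub-second durations would point to warm-up rather than a hang.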

TheSeamau5 commented 1 year ago

Hi @jonathlela, thank you for your swift response.

I didn't realize that the warmup phase took multiple minutes. I guess I don't really understand how it works.

So, I tried it again on an A10 and got 443.72s (7min 24s) for the first call of

model_output = model.generate(
  **inputs, 
  min_length=22, max_length=22
)

I also noticed, several times, that subsequent calls are inconsistent: some take 7min and others take 0.1s.

Critically, I have failed to run it with max_length=512 within a reasonable time frame (< 30 min) even once, which at the end of the day is what I'm trying to do.

TheSeamau5 commented 1 year ago

Ok, so I rewrote the script to first pass a warmup prompt to the model and then pass other different prompts to the model.

It looks like re-running the model on the warmup prompt is fast but re-running the model on a new prompt is slow.

Is this how the library is supposed to be used? I was trying to follow this: https://github.com/ELS-RD/kernl/blob/main/tutorial/t5%20e2e.ipynb

The solution I'm looking for is one where I can spend any amount of time once, at "build time", running all sorts of warm-up phases and optimizations, but after that the model is fast(er) and only ever sees new inputs.

Thank you for your help, and I apologize if it feels like I didn't understand how the library works, I am new to optimizing PyTorch models.

# Standard Library imports
import os
import sys
import time

# Third-party Library imports
from rich import print
from rich.markdown import Markdown
from rich.syntax import Syntax
import torch
import torch._dynamo as torchdynamo
from transformers import AutoTokenizer, AutoModelForCausalLM, T5ForConditionalGeneration
from kernl.model_optimization import optimize_model
from tqdm import tqdm

# Model Name
MODEL_NAME = "Salesforce/codet5-large-ntp-py"

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Download tokenizer from HuggingFace
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Download raw model from HuggingFace
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).eval().cuda()

# default cache size needs to be increased to store the many graphs with generative models
torchdynamo.config.cache_size_limit = 512

# Optimize the model with kernl
optimize_model(model.encoder)
optimize_model(model.decoder)

# Function to generate a completion 
def lm(prompt: str, **kwargs) -> str:
    # Tokenize the input
    inputs = tokenizer(
        prompt, 
        return_tensors="pt", 
        # pad_to_multiple_of=8, 
        # padding=True
    ).to(device)

    # Compute the generation
    with torch.inference_mode(), torch.cuda.amp.autocast():
        torch.cuda.synchronize()
        model_output = model.generate(
            **inputs, 
            **kwargs
        )
        torch.cuda.synchronize()

    # Decode the output
    completion = tokenizer.decode(model_output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # Return the output
    return completion

# This is probably a hack
os.environ["TOKENIZERS_PARALLELISM"] = "true"

prompt1 = """
# List all currently running unencrypted EC2 instances
# An unencrypted EC2 instance is an instance with at least one unencrypted EBS volume
# Step 1: List all unencrypted EBS volumes and get the list of attached EC2 instances
# Step 2: Return the set of unique EC2 instances with at least one unencrypted EBS volume
import boto3
""".strip()

prompt2 = """
# Retrieve list of buckets from S3
import boto3
""".strip()

warmup_prompt = """
# Retrieve list of instances from EC2
import boto3
"""

# Prompts we will test with
prompts = [
    prompt1,
    prompt2
]

#############
# MAIN LOOP #
#############

# Main warm-up phase
# warmup (IRL, encoder and decoder should be warmed each on their own)
print(Markdown("# Warmup Phase"))
start = time.perf_counter()
lm(warmup_prompt, min_length=22, max_length=22)
print(f" - Warmup completed in: {time.perf_counter() - start}")

# Second warm-up
print(Markdown("# Second warm-up"))
for _ in tqdm(range(10), desc="Warm-up runs"):
    lm(warmup_prompt, min_length=22, max_length=22)

# Actual Run
print(Markdown("# Generations"))
for prompt in tqdm(prompts, desc="Generations"):
    # Compute the completion
    completion = lm(
        prompt, 
        # max_length=512
        min_length=22, max_length=22
    )
TheSeamau5 commented 1 year ago

Btw, the output I'm getting (output completions are correct)

(ubuntu-py3.9) ubuntu@152-70-121-233:~$ poetry run python kernl_repro2.py 
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                                                                                   Warmup Phase                                                                                                   ┃
┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
/home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch/cuda/graphs.py:82: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super(CUDAGraph, self).capture_end()
 - Warmup completed in: 404.04018353799984
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                                                                                  Second warm-up                                                                                                  ┃
┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
Warm-up runs: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.80it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                                                                                   Generations                                                                                                    ┃
┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
Generations:   0%|                                                                                                                                                                            | 0/2 [00:00<?, ?it/s]/home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch/cuda/graphs.py:82: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super(CUDAGraph, self).capture_end()
Generations:  50%|█████████████████████████████████████████████████████████████████████████████████▌                                                                                 | 1/2 [05:24<05:24, 324.80s/it]/home/ubuntu/.cache/pypoetry/virtualenvs/ubuntu-zk_aSFMD-py3.9/lib/python3.9/site-packages/torch/cuda/graphs.py:82: UserWarning: The CUDA Graph is empty. This ususally means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:191.)
  super(CUDAGraph, self).capture_end()
Generations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [11:29<00:00, 344.52s/it]
jonathlela commented 1 year ago

Maybe your new prompts have a different size from the prompts in your warm-up phase. The warm-up phase should cover every input shape you expect to see later. For example, if your maximum prompt size is 32 tokens, the warm-up phase should include one prompt of 0-8 tokens, one of 9-16 tokens, one of 17-24 tokens and one of 25-32 tokens.
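
A rough sketch of that idea, under stated assumptions: the step of 8 tokens is inferred from the ranges above and is not a documented constant, and `bucket_upper_bounds`, `warmup_all_buckets` and the `generate` callback are hypothetical helpers, not part of kernl.

```python
from typing import Callable, List


def bucket_upper_bounds(max_tokens: int, step: int = 8) -> List[int]:
    """Upper bound of each length bucket, e.g. 8, 16, 24, 32 for max_tokens=32."""
    if max_tokens <= 0 or step <= 0:
        raise ValueError("max_tokens and step must be positive")
    # Round max_tokens up to the next multiple of step so the last bucket covers it.
    last = -(-max_tokens // step) * step
    return list(range(step, last + 1, step))


def warmup_all_buckets(
    generate: Callable[[int], object], max_tokens: int, step: int = 8
) -> List[int]:
    """Call `generate(n_tokens)` once per bucket and return the lengths used.

    `generate` is a hypothetical callback that builds an input of `n_tokens`
    tokens and runs the optimized model on it.
    """
    lengths = bucket_upper_bounds(max_tokens, step)
    for n_tokens in lengths:
        generate(n_tokens)
    return lengths


# Dry run with a callback that only records the requested lengths.
seen: List[int] = []
warmup_all_buckets(seen.append, max_tokens=32)
print(seen)
```

In the script above, the callback would build a prompt of the requested length and pass it to `lm(...)`. Passing `pad_to_multiple_of=8, padding=True` to the tokenizer (both are commented out in the script) may also help real prompts land on one of the warmed-up lengths.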

jonathlela commented 1 year ago

Feel free to reopen if needed.