replit / ReplitLM

Inference code and configs for the ReplitLM model family
https://huggingface.co/replit
Apache License 2.0
918 stars 75 forks

Speed up the inference #24

Open Symbolk opened 1 year ago

Symbolk commented 1 year ago

Hi, this model seems nice, but I find that the inference speed is very slow (70 ms/token on a single A100), so I want to speed it up.

It seems to be related to MPT itself: https://huggingface.co/mosaicml/mpt-7b-instruct/discussions/23

Any suggestions or best practices for speeding it up? E.g., FasterTransformer (a bit low-level), ONNX Runtime, or OneFlow?
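
For reference, here is a minimal sketch of how a per-token latency figure like the 70 ms/token above could be measured; the model name, prompt, and generation arguments are illustrative placeholders, not necessarily the exact setup used here:

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch of a per-token latency measurement; model name, prompt, and
# generation arguments are illustrative placeholders.
name = "replit/replit-code-v1-3b"
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True).to("cuda").eval()

x = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.to("cuda")

torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    out = model.generate(x, max_length=256, do_sample=True, num_return_sequences=1)
torch.cuda.synchronize()
elapsed = time.time() - start

new_tokens = out.shape[1] - x.shape[1]
print(f"{1000 * elapsed / new_tokens:.1f} ms/token over {new_tokens} generated tokens")
```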

madhavatreplit commented 1 year ago

Are you using Triton flash attention and bfloat16, as described in the model's Hugging Face README?

You can also accelerate this model with FasterTransformer.
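
For context, a minimal sketch of that configuration, following the MPT-style settings described in the model's Hugging Face README (the attn_config key and the 'triton' value are taken from that README; adjust to your environment):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Sketch of loading with Triton flash attention and bfloat16, per the
# MPT-style config in the model's Hugging Face README; requires the
# triton/flash-attn dependencies (e.g. the llm-foundry environment).
name = "replit/replit-code-v1-3b"
config = AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config["attn_impl"] = "triton"

model = AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to(device="cuda:0")
model.eval()
```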

Symbolk commented 1 year ago

> Are you using Triton flash attention and bfloat16, as described in the model's Hugging Face README?
>
> You can also accelerate this model with FasterTransformer.

Thanks for the reply! I will try the right configuration for Triton and bfloat16. With them enabled, how many milliseconds per token should I expect on an A100-80G or V100-32G?

Symbolk commented 1 year ago

Hi, I enabled Triton and bfloat16 inside the Docker image provided here: https://github.com/mosaicml/llm-foundry/, with the dependencies installed, but the following error is thrown:


0it [00:03, ?it/s]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ in _fwd_kernel:21                                                            │
╰──────────────────────────────────────────────────────────────────────────────╯
KeyError: 
('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-2b0c5161c53c71b37ae20a9996ee4bb8-c1
f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037
f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff51
98-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', 
(torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16,
torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 
'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 
'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, False, 
128, 128), (True, True, True, True, True, True, True, (False,), (True, False), 
(True, False), (True, False), (True, False), (True, False), (True, False), 
(True, False), (True, False), (True, False), (True, False), (False, False), 
(True, False), (True, False), (True, False), (True, False), (True, False), 
(False, False), (False, False), (True, False), (True, False), (False, False), 
(False, False)))

During handling of the above exception, another exception occurred:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /cache/pretrained_model/replit_new/eval.py:176 in <module>                   │
│                                                                              │
│   173 if __name__ == '__main__':                                             │
│   174 │                                                                      │
│   175 │   logger.info(f'CUDA version: {torch.version.cuda}')                 │
│ ❱ 176 │   run()                                                              │
│   177                                                                        │
│                                                                              │
│ /cache/pretrained_model/replit_new/eval.py:134 in run                        │
│                                                                              │
│   131 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │      │
│   132 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │      │
│   133 │   │   start_time = time.time()                                       │
│ ❱ 134 │   │   generated_snippet = model.generate(x, max_length=256,          │
│   135 │   │   │   │   │   │   │   │   │   │      do_sample=True, use_cache=F │
│   136 │   │   │   │   │   │   │   │   │   │      num_return_sequences=1, eos │
│   137 │   │   end_time = time.time()                                         │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/autograd/grad_mode.py:27 in             │
│ decorate_context                                                             │
│                                                                              │
│    24 │   │   @functools.wraps(func)                                         │
│    25 │   │   def decorate_context(*args, **kwargs):                         │
│    26 │   │   │   with self.clone():                                         │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                           │
│    28 │   │   return cast(F, decorate_context)                               │
│    29 │                                                                      │
│    30 │   def _wrap_generator(self, func):                                   │
│                                                                              │
│ /usr/lib/python3/dist-packages/transformers/generation/utils.py:1572 in      │
│ generate                                                                     │
│                                                                              │
│   1569 │   │   │   )                                                         │
│   1570 │   │   │                                                             │
│   1571 │   │   │   # 13. run sample                                          │
│ ❱ 1572 │   │   │   return self.sample(                                       │
│   1573 │   │   │   │   input_ids,                                            │
│   1574 │   │   │   │   logits_processor=logits_processor,                    │
│   1575 │   │   │   │   logits_warper=logits_warper,                          │
│                                                                              │
│ /usr/lib/python3/dist-packages/transformers/generation/utils.py:2619 in      │
│ sample                                                                       │
│                                                                              │
│   2616 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_i │
│   2617 │   │   │                                                             │
│   2618 │   │   │   # forward pass to get next token                          │
│ ❱ 2619 │   │   │   outputs = self(                                           │
│   2620 │   │   │   │   **model_inputs,                                       │
│   2621 │   │   │   │   return_dict=True,                                     │
│   2622 │   │   │   │   output_attentions=output_attentions,                  │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/mo │
│ deling_mpt.py:239 in forward                                                 │
│                                                                              │
│   236 │   def forward(self, input_ids: torch.LongTensor, past_key_values: Op │
│   237 │   │   return_dict = return_dict if return_dict is not None else self │
│   238 │   │   use_cache = use_cache if use_cache is not None else self.confi │
│ ❱ 239 │   │   outputs = self.transformer(input_ids=input_ids, past_key_value │
│   240 │   │   logits = F.linear(outputs.last_hidden_state, self.transformer. │
│   241 │   │   if self.logit_scale is not None:                               │
│   242 │   │   │   if self.logit_scale == 0:                                  │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/mo │
│ deling_mpt.py:185 in forward                                                 │
│                                                                              │
│   182 │   │   │   │   assert all_hidden_states is not None                   │
│   183 │   │   │   │   all_hidden_states = all_hidden_states + (x,)           │
│   184 │   │   │   past_key_value = past_key_values[b_idx] if past_key_values │
│ ❱ 185 │   │   │   (x, past_key_value) = block(x, past_key_value=past_key_val │
│   186 │   │   │   if past_key_values is not None:                            │
│   187 │   │   │   │   past_key_values[b_idx] = past_key_value                │
│   188 │   │   x = self.norm_f(x)                                             │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/bl │
│ ocks.py:36 in forward                                                        │
│                                                                              │
│   33 │                                                                       │
│   34 │   def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[t │
│   35 │   │   a = self.norm_1(x)                                              │
│ ❱ 36 │   │   (b, _, past_key_value) = self.attn(a, past_key_value=past_key_v │
│   37 │   │   x = x + self.resid_attn_dropout(b)                              │
│   38 │   │   m = self.norm_2(x)                                              │
│   39 │   │   n = self.ffn(m)                                                 │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194 in _call_impl │
│                                                                              │
│   1191 │   │   # this function, and just call forward.                       │
│   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                     │
│   1195 │   │   # Do not call functions when jit is used                      │
│   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1197 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/at │
│ tention.py:172 in forward                                                    │
│                                                                              │
│   169 │   │   │   past_key_value = (key, value)                              │
│   170 │   │   if attn_bias is not None:                                      │
│   171 │   │   │   attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1): │
│ ❱ 172 │   │   (context, attn_weights) = self.attn_fn(query, key, value, self │
│   173 │   │   return (self.out_proj(context), attn_weights, past_key_value)  │
│   174                                                                        │
│   175 class MultiQueryAttention(nn.Module):                                  │
│                                                                              │
│ /home/mosaicml/.cache/huggingface/modules/transformers_modules/replit_new/at │
│ tention.py:111 in triton_flash_attn_fn                                       │
│                                                                              │
│   108 │   │   key = key.expand(*key.shape[:2], n_heads, key.size(-1))        │
│   109 │   │   value = value.expand(*value.shape[:2], n_heads, value.size(-1) │
│   110 │   reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_ │
│ ❱ 111 │   attn_output = flash_attn_triton.flash_attn_func(query, key, value, │
│   112 │   output = attn_output.view(*attn_output.shape[:2], -1)              │
│   113 │   return (output, None)                                              │
│   114                                                                        │
│                                                                              │
│ /usr/lib/python3/dist-packages/flash_attn/flash_attn_triton.py:810 in        │
│ forward                                                                      │
│                                                                              │
│   807 │   │   """                                                            │
│   808 │   │   # Make sure that the last dimension is contiguous              │
│   809 │   │   q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in │
│ ❱ 810 │   │   o, lse, ctx.softmax_scale = _flash_attn_forward(               │
│   811 │   │   │   q, k, v, bias=bias, causal=causal, softmax_scale=softmax_s │
│   812 │   │   )                                                              │
│   813 │   │   ctx.save_for_backward(q, k, v, o, lse, bias)                   │
│                                                                              │
│ /usr/lib/python3/dist-packages/flash_attn/flash_attn_triton.py:623 in        │
│ _flash_attn_forward                                                          │
│                                                                              │
│   620 │   BLOCK = 128                                                        │
│   621 │   num_warps = 4 if d <= 64 else 8                                    │
│   622 │   grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch │
│ ❱ 623 │   _fwd_kernel[grid](                                                 │
│   624 │   │   q, k, v, bias, o,                                              │
│   625 │   │   lse, tmp,                                                      │
│   626 │   │   softmax_scale,                                                 │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/runtime/jit.py:106 │
│ in launcher                                                                  │
│                                                                              │
│   103 │   │   memorizes the grid.                                            │
│   104 │   │   """                                                            │
│   105 │   │   def launcher(*args, **kwargs):                                 │
│ ❱ 106 │   │   │   return self.run(*args, grid=grid, **kwargs)                │
│   107 │   │   return launcher                                                │
│   108                                                                        │
│   109                                                                        │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/runtime/autotuner. │
│ py:200 in run                                                                │
│                                                                              │
│   197 │   def run(self, *args, **kwargs):                                    │
│   198 │   │   for v, heur in self.values.items():                            │
│   199 │   │   │   kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwa │
│ ❱ 200 │   │   return self.fn.run(*args, **kwargs)                            │
│   201                                                                        │
│   202                                                                        │
│   203 def heuristics(values):                                                │
│ in _fwd_kernel:41                                                            │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/compiler.py:1268   │
│ in compile                                                                   │
│                                                                              │
│   1265 │   if warm_cache_only:                                               │
│   1266 │   │   return  # load_binary() requires a valid cuda context         │
│   1267 │                                                                     │
│ ❱ 1268 │   return CompiledKernel(name, so_cache_manager._make_path(so_name), │
│   1269                                                                       │
│   1270                                                                       │
│   1271 class CompiledKernel:                                                 │
│                                                                              │
│ /home/mosaicml/.local/lib/python3.10/site-packages/triton/compiler.py:1281   │
│ in __init__                                                                  │
│                                                                              │
│   1278 │   │   # initialize launcher                                         │
│   1279 │   │   import importlib.util                                         │
│   1280 │   │   spec = importlib.util.spec_from_file_location("launcher", so_ │
│ ❱ 1281 │   │   mod = importlib.util.module_from_spec(spec)                   │
│   1282 │   │   spec.loader.exec_module(mod)                                  │
│   1283 │   │   self.c_wrapper = getattr(mod, "launch")                       │
│   1284 │   │   # initialize metadata                                         │
│ in module_from_spec:571                                                      │
│ in create_module:1176                                                        │
│ in _call_with_frames_removed:241                                             │
╰──────────────────────────────────────────────────────────────────────────────╯
ImportError: 
/home/mosaicml/.triton/cache/ab77933fc177e6d77b0dd8896210d966/_fwd_kernel.so: 
undefined symbol: cuLaunchKernel