lopuhin opened 10 months ago
I'm not yet sure whether this is related to having an extra pad token in the model -- but if I remove it, the error is still there and all the weights have their original shape. HF is misbehaving, so I can't download the unmodified model right now.
Reopening as this does not look related to the custom tokenizer -- I can reproduce the issue on the original llama-2-13b-chat-hf model.
A possible fix, inspired by https://github.com/Lightning-AI/lit-gpt/issues/774, is to add a clone() call (committed in https://github.com/pytorch-labs/gpt-fast/commit/636cd767f0fa4d0e10ad456b67219a809f906dc2):
```diff
diff --git a/generate.py b/generate.py
index 7f30de0..cb4d7e6 100644
--- a/generate.py
+++ b/generate.py
@@ -161,7 +161,7 @@ def generate(
     seq = empty
     input_pos = torch.arange(0, T, device=device)
-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
+    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
         prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
```
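For background, the likely mechanism (as discussed in the lit-gpt issue linked above): with mode="reduce-overhead", torch.compile executes via CUDA graphs, and a compiled function's output tensor aliases a graph-owned memory pool that subsequent runs reuse. A standalone sketch of that pattern (illustrative, not gpt-fast code):

```python
import torch

# Sketch of the aliasing hazard behind the fix: with mode="reduce-overhead",
# torch.compile backs the function with CUDA graphs, and the returned tensor
# lives in a memory pool that subsequent runs reuse.
@torch.compile(mode="reduce-overhead")
def step(x):
    return x * 2

if torch.cuda.is_available():
    x = torch.ones(4, device="cuda")
    for _ in range(3):
        out = step(x)  # warm-up runs so the CUDA graph gets captured
    stale = step(x)
    step(x)  # this run may reuse the pool backing `stale`
    # Accessing `stale` now can fail with "accessing tensor output of CUDAGraphs
    # that has been overwritten by a subsequent run"; cloning right after the
    # call copies the data out of the graph-owned pool, which is what the
    # patch above does for prefill's output.
    safe = step(x).clone()
```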
This removes the error, but there is no speedup for a prompt of 708 tokens when using --compile_prefill, even though the prefill slow-down at that prompt length is clearly noticeable compared to a shorter prompt.
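For reference, prefill cost can be timed in isolation with something like the sketch below (assumptions: it runs from the gpt-fast repo root so prefill can be imported from generate.py, and model, prompt, and sampling_kwargs are already set up as in generate()):

```python
import time

import torch

from generate import prefill  # gpt-fast's generate.py (assumes repo root as cwd)

def time_prefill(model, prompt, device, **sampling_kwargs):
    """Time a single prefill over a 1-D tensor of prompt token ids."""
    input_pos = torch.arange(0, prompt.size(0), device=device)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
    torch.cuda.synchronize()
    return next_token, time.perf_counter() - t0
```

Comparing this for the compiled vs. eager prefill, at 708 tokens vs. a short prompt, should show whether --compile_prefill pays off at this prompt length.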
Running the 13b chat model on an L4 GPU with --compile_prefill, an error happens.
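The invocation was presumably along these lines (reconstructed from the gpt-fast README; the checkpoint path and prompt are placeholders):

```bash
python generate.py --compile --compile_prefill \
    --checkpoint_path checkpoints/meta-llama/Llama-2-13b-chat-hf/model.pth \
    --prompt "<a ~708-token prompt>"
```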
It works fine without --compile_prefill.

Library versions: