elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Support DeepSeek Coder Model #278

Closed: jonastemplestein closed this issue 8 months ago

jonastemplestein commented 8 months ago

Hey folks, I'm trying to use the deepseek-coder-1.3b-base model with bumblebee. I was delighted to find that the model, tokenizer and generation_config all load. But when trying to run inference I get the following error that's a bit hard for me to debug:

repo = {:hf, "deepseek-ai/deepseek-coder-1.3b-base"}
{:ok, model_info} = Bumblebee.load_model(repo, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
prompt = "hello world"
Nx.Serving.run(serving, prompt)
** (ErlangError) Erlang error: "Could not decode field on position 1"
    (tokenizers 0.4.0) Tokenizers.Native.encoding_transform(#Tokenizers.Encoding<[length: 2, ids: [31702, 1835]]>, [pad: {2, [pad_id: nil, pad_token: "</s>", direction: :left]}])
    (elixir 1.15.7) lib/enum.ex:1693: Enum."-map/2-lists^map/1-1-"/2
    (bumblebee 0.4.2) lib/bumblebee/utils/tokenizers.ex:51: Bumblebee.Utils.Tokenizers.apply/4
    (nx 0.6.2) lib/nx.ex:4510: Nx.with_default_backend/2
    (bumblebee 0.4.2) lib/bumblebee/text/generation.ex:882: anonymous fn/4 in Bumblebee.Text.Generation.generation/4
    (nx 0.6.2) lib/nx/serving.ex:1704: anonymous fn/3 in Nx.Serving.handle_preprocessing/2
    (telemetry 1.2.1) /Users/jonas/Library/Caches/mix/installs/elixir-1.15.7-erts-14.1.1/f67c01eefcd351fd5b5511a96e61c42d/deps/telemetry/src/telemetry.erl:321: :telemetry.span/3
    #cell:776q3ifvc2hexaoavrvlcde7ehfkvusl:7: (file)

I'm using bumblebee 0.4.2

Here's the model spec

spec: %Bumblebee.Text.Llama{
    architecture: :for_causal_language_modeling,
    vocab_size: 32256,
    max_positions: 16384,
    hidden_size: 2048,
    intermediate_size: 5504,
    num_blocks: 24,
    num_attention_heads: 16,
    activation: :silu,
    layer_norm_epsilon: 1.0e-6,
    initializer_scale: 0.02,
    output_hidden_states: false,
    output_attentions: false,
    num_labels: 2,
    id_to_label: %{},
    pad_token_id: 0
  }

And here's the tokenizer

%Bumblebee.Text.LlamaTokenizer{
  tokenizer: #Tokenizers.Tokenizer<[
    vocab_size: 32022,
    byte_fallback: false,
    continuing_subword_prefix: nil,
    dropout: nil,
    end_of_word_suffix: nil,
    fuse_unk: false,
    model_type: "bpe",
    unk_token: nil
  ]>,
  special_tokens: %{pad: "</s>", eos: "</s>", sep: "</s>", unk: "<unk>"},
  additional_special_tokens: []
}

For example, the vocab size in the model spec (32256) does not seem to match the tokenizer's (32022).

I think the tokenizer uses the correct vocabulary, because I can run this:

Bumblebee.Tokenizer.decode(tokenizer, [32015])

and it correctly returns <|fim▁hole|>, which is a DeepSeek-specific token.

Would be amazing if this model was supported, as deepseek-coder actually seems to be pretty good at Elixir out of the box 🙇

Thank you so much for your help!

seanmor5 commented 8 months ago

The reason this happens is that their repository doesn't have a special tokens map file. I added it here: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-base/discussions/1. You should be able to clone that branch and then load the tokenizer from {:local, "/path/to/local/model"}, and generation will work.

The special tokens (padding, eos, bos, etc.) are in the tokenizer_config.json. Not sure if we can load from there. cc @jonatanklosko

jonatanklosko commented 8 months ago

@seanmor5 apparently they want to migrate off of these files:

> All the added token information now lies in the tokenizer_config.json. Nuking the special_tokens_map.json and added_tokens.json.
> ~ https://github.com/huggingface/transformers/pull/23909

I will later revisit the loading logic.


@jonastemplestein meanwhile you can load from the PR commit directly with {:hf, "deepseek-ai/deepseek-coder-1.3b-base", revision: "7abe797c62ede1b47c4d00a17ed006be9659d657"}.

jonastemplestein commented 8 months ago

Thank you so much you two! <3

jonastemplestein commented 8 months ago

Hey folks, I've just had a chance to play with this and keep getting garbage output from the model compared to what I get using the HF transformers library.

For example, running this in Livebook:


# setup cell

Mix.install(
  [
    {:kino_bumblebee, "~> 0.4.0"},
    {:exla, ">= 0.0.0"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

# subsequent code cell

repo = {:hf, "deepseek-ai/deepseek-coder-1.3b-base", revision: "7abe797c62ede1b47c4d00a17ed006be9659d657"}
{:ok, model_info} = Bumblebee.load_model(repo, backend: {EXLA.Backend, client: :host})
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 1028],
    stream: false,
    defn_options: [compiler: EXLA, lazy_transfers: :never]
  )

# Should be supervised
Kino.start_child({Nx.Serving, name: Deepseek, serving: serving})

prompt = "<|fim▁begin|>def quick_sort(arr):\n  <|fim▁hole|> \n <|fim▁end|> "

Nx.Serving.batched_run(Deepseek, prompt)

results in this output:

<|fim▁begin|>def quick_sort(arr):\n  <|fim▁hole|> \n <|fim▁end|> ••••••••••••••••••••

In comparison, running the same model with HF transformers in Python gives me something completely different and sensible:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-1.3b-base", trust_remote_code=True).cuda()

input_text = "<|fim▁begin|>def quick_sort(arr):\n  <|fim▁hole|> \n <|fim▁end|> "

inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=328)
print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_text):])

Output:

 if len(arr) <= 1:
        return arr
    else:
        pivot = arr[0]
        less = [i for i in arr[1:] if i <= pivot]
        greater = [i for i in arr[1:] if i > pivot]
        return quick_sort(less) + [pivot] + quick_sort(greater)

Any ideas how to further debug this? 🤔 Any help is much appreciated

I'm on an M2 Pro, in case that is relevant, but I also saw the same behaviour when I briefly ran this on an A100 with CUDA.

jonastemplestein commented 8 months ago

Also, slightly related question: Do you think I can use the fine-tuning process described here to fine-tune this model directly from livebook? Can you think of any reasons why that might not work?

I think it'd be a very cool demo to train an Elixir code completion model for Livebook in Livebook :)

jonatanklosko commented 8 months ago

@jonastemplestein I tracked down the output difference and it turns out this checkpoint uses scaling for rotary embedding, which we don't have yet. I inlined scaling just to check and the output looks the same, except there are a bunch of newlines at the end (which makes the generation run for way too long). I will send a PR once I figure it out and get everything in sync :)
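
To illustrate what scaling for rotary embedding means, here is a minimal sketch in plain Elixir. It assumes the linear variant, in which position indices are divided by a constant factor before the rotation angles are computed; the thread does not say which scaling strategy the checkpoint uses. The module name, the factor of 4.0 and the base of 10_000 are illustrative assumptions, not values taken from this thread or from Bumblebee's implementation.

```elixir
defmodule RotarySketch do
  @moduledoc "Hypothetical helper illustrating linear scaling for rotary embeddings."

  # Returns the rotation angles for a single position.
  # With linear scaling the position is divided by `factor`, so longer
  # sequences are squeezed into the position range seen during training.
  def angles(position, head_dim, opts \\ []) do
    base = Keyword.get(opts, :base, 10_000)
    factor = Keyword.get(opts, :factor, 1.0)
    scaled_position = position / factor

    for i <- 0..(div(head_dim, 2) - 1) do
      inv_frequency = 1.0 / :math.pow(base, 2 * i / head_dim)
      scaled_position * inv_frequency
    end
  end
end

# Position 8 with factor 4.0 yields the same angles as position 2 unscaled.
RotarySketch.angles(8, 4, factor: 4.0) == RotarySketch.angles(2, 4)
```

If a checkpoint was trained with scaling and inference omits it, every position gets different angles from the ones the model learned, which would be consistent with the garbage output reported above.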

jonastemplestein commented 8 months ago

Amazing, thank you so much!

Regarding the excessive newlines, I've found that one way to reduce that in python land is to set a higher "repeat penalty" for token generation. Is there such a parameter in bumblebee? I saw there is :no_repeat_ngram_length but that works a bit differently.

Along similar lines, are there any plans to support a temperature parameter? Or can I somehow simulate the way temperature works using the sampling strategy?

jonatanklosko commented 8 months ago

@jonastemplestein with #285 it should match the Python implementation. You also need to load the tokenizer from this revision until the upstream PR is merged:

{:ok, tokenizer} =
  Bumblebee.load_tokenizer(
    {:hf, "deepseek-ai/deepseek-coder-1.3b-base",
     revision: "e94f2b11bc28abbd67ecadfaad058c30b24a589f"}
  )

We don't have repetition penalty, only :no_repeat_ngram_length.
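
For context, a repetition penalty as commonly implemented (for example in HF transformers) rescales the logits of tokens that already appear in the sequence: positive logits are divided by the penalty and negative ones are multiplied by it. The plain-Elixir sketch below shows only that idea. The module and function names are hypothetical, and nothing here is a Bumblebee API.

```elixir
defmodule RepetitionPenaltySketch do
  # `logits` maps token id => raw score; `seen_ids` lists the tokens
  # already present in the sequence. A penalty above 1.0 makes seen
  # tokens less likely, and 1.0 leaves the logits unchanged.
  def penalize(logits, seen_ids, penalty) when penalty > 0 do
    seen = MapSet.new(seen_ids)

    Map.new(logits, fn {token_id, score} ->
      cond do
        not MapSet.member?(seen, token_id) -> {token_id, score}
        score < 0 -> {token_id, score * penalty}
        true -> {token_id, score / penalty}
      end
    end)
  end
end

logits = %{0 => 2.0, 1 => -1.0, 2 => 0.5}
RepetitionPenaltySketch.penalize(logits, [0, 1], 2.0)
# => %{0 => 1.0, 1 => -2.0, 2 => 0.5}
```

Unlike :no_repeat_ngram_length, which forbids exact n-gram repeats outright, this only lowers the score of repeated tokens.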

We don't have temperature either, but supporting that is trivial. I will open an issue and implement it soon.
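
As a rough illustration of what temperature does during sampling, the sketch below divides the logits by the temperature before applying softmax. It is plain Elixir over lists rather than Nx tensors, the module name is hypothetical, and it is not how Bumblebee implements (or will implement) the option.

```elixir
defmodule TemperatureSketch do
  # Converts raw logits into probabilities after dividing by `temperature`.
  # Values below 1.0 sharpen the distribution, values above 1.0 flatten it.
  def probabilities(logits, temperature) when temperature > 0 do
    scaled = Enum.map(logits, &(&1 / temperature))
    # Subtract the maximum before exponentiating for numerical stability.
    max_logit = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - max_logit))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end

logits = [2.0, 1.0, 0.0]
sharp = TemperatureSketch.probabilities(logits, 0.5)
flat = TemperatureSketch.probabilities(logits, 2.0)

# The top token gets more probability mass at the lower temperature.
hd(sharp) > hd(flat)
```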

jonastemplestein commented 8 months ago

Amazing, thank you so much!

If you'd like, you can see the model in action in this Livebook instance, which has rudimentary code completion: https://livebookjonas.fly.dev/ (password elixircopilot)

The Livebook instance powers down after some inactivity and then takes 30 seconds or so to come back up. If it just booted, you can open the starred notebook called Livebook Copilot Playground to play around.

Two questions:

  1. Could you make the same huggingface PR for the 6.7b version, please?
  2. Can you think of a reason why this model would return not just new tokens, but also the entire prompt? I had to do this workaround but that wasn't necessary e.g. for codellama or even GPT2

Thanks again for your help with this, super amazing <3

jonatanklosko commented 8 months ago

> Could you make the same huggingface PR for the 6.7b version, please?

Done. Note that the tokenizer should be the same, so you can load the tokenizer as above and model from any other checkpoint :)

jonatanklosko commented 8 months ago

> Can you think of a reason why this model would return not just new tokens, but also the entire prompt? I had to do this workaround but that wasn't necessary e.g. for codellama or even GPT2

That's how text completion models usually work: we pass the input sequence and they generate tokens one by one, which we effectively keep appending to the input sequence for subsequent inferences. There are also encoder-decoder models (e.g. BART), where the input first goes through an encoder and is treated as the context for generation, separate from the generated sequence. That said, we should have an option to strip the prompt tokens, tracked by #247 :)
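
Until such an option exists, one possible workaround is to strip the prompt from the decoded text yourself. The sketch below is plain Elixir with a hypothetical module name, and it is not necessarily the workaround referred to above. It assumes the decoded output begins with the prompt exactly as typed, which may not hold if the tokenizer normalizes whitespace or drops special tokens.

```elixir
defmodule CompletionSketch do
  # Returns only the newly generated text when `generated` starts with
  # `prompt`, and returns `generated` unchanged otherwise.
  def strip_prompt(generated, prompt) do
    String.replace_prefix(generated, prompt, "")
  end
end

prompt = "def quick_sort(arr):\n"
generated = prompt <> "    if len(arr) <= 1:\n        return arr"

CompletionSketch.strip_prompt(generated, prompt)
# => "    if len(arr) <= 1:\n        return arr"
```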