axolotl-ai-cloud / axolotl


chat_template masking is broken with Mistral Small (possibly others) #2089

Open kubernetes-bad opened 6 days ago

kubernetes-bad commented 6 days ago

Please check that this issue hasn't been reported before.

Expected Behavior

When using the chat_template dataset type with a dataset in sharegpt format and roles_to_train: ["gpt"] set, human turns should be masked (label -100) and gpt turns should NOT be masked. This is how it should look, with turns broken out one per line for readability (a minimal sketch of this rule follows the listing):

```
<s>(-100, 1)
[INST](-100, 3) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
```
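For clarity, here is a minimal sketch of that rule in plain Python (illustrative only, not axolotl's implementation; build_labels and its inputs are made-up names):

```python
# Illustrative sketch of the expected labeling rule -- not axolotl's code.
IGNORE_INDEX = -100

def build_labels(turns, roles_to_train=("gpt",)):
    """turns: list of (role, token_ids) covering the whole sequence,
    where each turn's token_ids already include its template/EOS tokens."""
    labels = []
    for role, token_ids in turns:
        if role in roles_to_train:
            labels.extend(token_ids)                        # train on gpt turns
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))  # mask everything else
    return labels

# Using the token ids from the listing above:
# build_labels([("human", [3, 1119, 17572, 2674, 4]),
#               ("gpt",   [6225, 1119, 17572, 2674, 2])])
# -> [-100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2]
```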

Current behaviour

Turns are chaotically masked out mid-conversation, or the whole conversation is masked out, depending on which tokenizer type is used.

How it is - LlamaTokenizer:

Also note the extra space tokens (-100, 29473):

```
<s>(-100, 1)
[INST](-100, 3) (-100, 29473) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) (-100, 29473)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) (-100, 29473) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) (-100, 29473)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) (-100, 29473) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) (-100, 29473)
NOT(-100, 6225) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3) (-100, 29473) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) (-100, 29473)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
[INST](-100, 3) (-100, 29473) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) (-100, 29473)
NOT(-100, 6225) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3) (-100, 29473) M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) (-100, 29473)
NOT(6225, 6225) M(1119, 1119) ASK(17572, 17572) ED(2674, 2674) </s>(2, 2)
```

How it is - LlamaTokenizerFast/AutoTokenizer:

Everything but the very last </s> is masked out!

```
<s>(-100, 1)
[INST](-100, 3)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4) 
NOT(-100, 6225)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(-100, 6225)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(-100, 6225)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(-100, 6225)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(-100, 6225)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(-100, 2)
[INST](-100, 3)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) [/INST](-100, 4)
NOT(-100, 6225)  M(-100, 1119) ASK(-100, 17572) ED(-100, 2674) </s>(2, 2)
```

Steps to reproduce

Use the following dataset:

{"conversations":[{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"}]}
{"conversations":[{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"},{"from":"human","value":"MASKED"},{"from":"gpt","value":"NOT MASKED"}]}

(It has TWO identical samples because load_datasets in the CLI is off by one and won't start with just one sample.)

Run `python3 -m axolotl.cli.preprocess config.yaml --debug`

Check out turn 6: it should not be masked (it's a gpt turn), yet it is. There are also extra space tokens when using the slow tokenizer.
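To dump the per-token (label, token_id) view shown above, something along these lines should work (a rough sketch; it assumes the prepared dataset written to dataset_prepared_path can be opened directly with datasets.load_from_disk, and the exact on-disk layout may differ):

```python
# Rough sketch for dumping (label, token_id) pairs from the prepared dataset.
# Assumes dataset_prepared_path: ./debug_dataset from the config below, and that
# the prepared arrow dataset can be opened with load_from_disk (layout may differ).
from datasets import load_from_disk
from transformers import AutoTokenizer

ds = load_from_disk("./debug_dataset")
tok = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-Instruct-2409")

sample = ds[0]
for token_id, label in zip(sample["input_ids"], sample["labels"]):
    print(f"{tok.convert_ids_to_tokens(token_id)}({label}, {token_id})")
```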

Config yaml

```yaml
base_model: mistralai/Mistral-Small-Instruct-2409
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true

datasets:
  - path: ./datasets/debug.json
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value
    roles_to_train: ["gpt"]
    chat_template: mistral_v2v3
    train_on_eos: "turn"

dataset_prepared_path: ./debug_dataset
val_set_size: 0.0
output_dir: ./output/last

sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false

gradient_accumulation_steps: 1
micro_batch_size: 1
learning_rate: 0.000004

train_on_inputs: false
group_by_length: false
```

Possible solution

I don't have a solution, but I do have some debugging tips.

If you modify src/axolotl/prompt_strategies/chat_template.py and change LOG.setLevel(logging.INFO) to LOG.setLevel(logging.DEBUG), you can see some of the decisions the chat_template logic is making.
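That is, the change is just:

```python
# In src/axolotl/prompt_strategies/chat_template.py:
LOG.setLevel(logging.DEBUG)  # was: LOG.setLevel(logging.INFO)
```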

Depending on which tokenizer is used, you'll get two very different outputs. Here's output from LlamaTokenizer:

Debug log ``` [tokenize_prompt:276] Processing turn 0: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=3, end=6 [tokenize_prompt:330] EOS token missing after turn {'from': 'human', 'value': 'MASKED'}. eos_idx: 12, turn_end_idx: 6 [tokenize_prompt:276] Processing turn 1: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=20, end=24 [tokenize_prompt:316] Labels set for range 20:24 [tokenize_prompt:318] Labels after processing turn 1: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100] [tokenize_prompt:328] EOS token set for training at index 24 [tokenize_prompt:276] Processing turn 2: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=27, end=30 [tokenize_prompt:330] EOS token missing after turn {'from': 'human', 'value': 'MASKED'}. eos_idx: 36, turn_end_idx: 30 [tokenize_prompt:276] Processing turn 3: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=44, end=48 [tokenize_prompt:316] Labels set for range 44:48 [tokenize_prompt:318] Labels after processing turn 3: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100] [tokenize_prompt:328] EOS token set for training at index 48 [tokenize_prompt:276] Processing turn 4: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=51, end=54 [tokenize_prompt:330] EOS token missing after turn {'from': 'human', 'value': 'MASKED'}. 
eos_idx: 60, turn_end_idx: 54 [tokenize_prompt:276] Processing turn 5: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=68, end=72 [tokenize_prompt:316] Labels set for range 68:72 [tokenize_prompt:318] Labels after processing turn 5: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, -100] [tokenize_prompt:328] EOS token set for training at index 72 [tokenize_prompt:276] Processing turn 6: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 7: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=8, end=12 [tokenize_prompt:316] Labels set for range 8:12 [tokenize_prompt:318] Labels after processing turn 7: [-100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2] [tokenize_prompt:328] EOS token set for training at index 12 [tokenize_prompt:276] Processing turn 8: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=3, end=6 [tokenize_prompt:330] EOS token missing after turn {'from': 'human', 'value': 'MASKED'}. eos_idx: 12, turn_end_idx: 6 [tokenize_prompt:276] Processing turn 9: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=8, end=12 [tokenize_prompt:316] Labels set for range 8:12 [tokenize_prompt:318] Labels after processing turn 9: [-100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2] [tokenize_prompt:328] EOS token set for training at index 12 [tokenize_prompt:276] Processing turn 10: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=3, end=6 [tokenize_prompt:330] EOS token missing after turn {'from': 'human', 'value': 'MASKED'}. 
eos_idx: 12, turn_end_idx: 6 [tokenize_prompt:276] Processing turn 11: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=8, end=12 [tokenize_prompt:316] Labels set for range 8:12 [tokenize_prompt:318] Labels after processing turn 11: [-100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2] [tokenize_prompt:328] EOS token set for training at index 12 [tokenize_prompt:339] Final labels: [-100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6225, 1119, 17572, 2674, 2] ```

Note how turn 6 has Turn indices: start=-1, end=-1. From that turn on, every single operation is wrong.

Here is another log from when LlamaTokenizerFast/AutoTokenizer is used:

Debug log ``` [tokenize_prompt:276] Processing turn 0: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 1: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:328] EOS token set for training at index -1 [tokenize_prompt:276] Processing turn 2: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 3: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:328] EOS token set for training at index -1 [tokenize_prompt:276] Processing turn 4: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 5: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:328] EOS token set for training at index -1 [tokenize_prompt:276] Processing turn 6: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 7: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:328] EOS token set for training at index -1 [tokenize_prompt:276] Processing turn 8: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 9: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:328] EOS token set for training at index -1 [tokenize_prompt:276] Processing turn 10: role=human, content=MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: False [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:276] Processing turn 11: role=gpt, content=NOT MASKED, train_turn=None, train_detail=None [tokenize_prompt:290] Should train: True [tokenize_prompt:296] Turn indices: start=-1, end=-1 [tokenize_prompt:328] EOS token set for training at index -1 [tokenize_prompt:339] Final labels: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 2] ```

Here, it couldn't find a single turn, and masked everything out but the final token.

I verified that this happens regardless of whether chat_template: mistral_v2v3 or chat_template: chatml is set. It also happens with the tokenizer's default chat template string.
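My guess at the failure mode, purely as an illustrative sketch (find_sublist and the token lists below are made up; the ids come from the listings above): if the turn locator does a token-level substring search, it returns -1 whenever a turn tokenized on its own doesn't match the same text as tokenized inside the full template, e.g. because of the extra space token 29473 that the slow tokenizer inserts.

```python
# Illustrative only -- not axolotl's implementation. Shows how a token-level
# sublist search returns -1 when a turn tokenized in isolation does not match
# the same turn as it appears inside the fully templated conversation.
def find_sublist(haystack, needle):
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1  # same sentinel as the "start=-1, end=-1" in the debug logs

# Token ids taken from the listings above; the slow tokenizer injects 29473.
turn_in_context   = [3, 29473, 1119, 17572, 2674, 4, 29473]  # "[INST] MASKED [/INST] "
turn_in_isolation = [3, 1119, 17572, 2674, 4]                # "[INST]MASKED[/INST]"

print(find_sublist(turn_in_context, turn_in_isolation))      # -> -1
```

That would line up with the start=-1, end=-1 entries in both debug logs.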

Which Operating Systems are you using?

Python Version

Python 3.11.10

axolotl branch-commit

d356740

Acknowledgements

NanoCode012 commented 5 days ago

Thank you for the report. Let me look into this.