ml-explore / mlx-examples

Examples in the MLX framework

Generation after LoRA training cannot stop properly #756

Closed jenkinv closed 4 months ago


jenkinv commented 4 months ago

Generation after LoRA training cannot stop properly

The code in lora/data/wikisql.py removes the bos_token and eos_token, assuming the tokenizer will add them automatically. However, this is not the case for all tokenizers, as the Mistral-7B-v0.1 tokenizer demonstrates. As a result, text generated after training on the wikisql dataset does not stop properly.

from utils import load

_, tokenizer, _ = load("mistralai/Mistral-7B-v0.1")

print(tokenizer.encode("a"))

This will output the sequence with bos_token, but without eos_token:

[1, 264]

To resolve this issue, we need to explicitly enable the addition of eos_token for the tokenizer. Here's the corrected code snippet:

from utils import load

_, tokenizer, _ = load("mistralai/Mistral-7B-v0.1")

# Enable adding eos_token
tokenizer.add_eos_token = True

print(tokenizer.encode("a"))

This will output the correct sequence with both bos_token and eos_token:

[1, 264, 2]

where:

  • 1 is the id of bos_token
  • 264 is the id of 'a'
  • 2 is the id of eos_token

Therefore, we need to add the following line to mlx-examples/lora/lora.py within the train function:

# Add this line to turn on add_eos_token
tokenizer.add_eos_token = True

This ensures the model is trained with proper sequence termination and generates complete text after training with the wikisql dataset.

This solution is specifically tested with the Mistral-7B tokenizer and may need adjustments for other tokenizers.
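
For tokenizers that do not expose an add_eos_token attribute, a minimal sketch of appending the eos id manually after encoding; encode_with_eos is a hypothetical helper, and it assumes a Hugging Face style tokenizer whose encode returns a list of ids and which exposes eos_token_id:

from utils import load

def encode_with_eos(tokenizer, text):
    # Encode, then append the eos id only if the tokenizer did not already add it.
    tokens = tokenizer.encode(text)
    if tokenizer.eos_token_id is not None and (not tokens or tokens[-1] != tokenizer.eos_token_id):
        tokens.append(tokenizer.eos_token_id)
    return tokens

_, tokenizer, _ = load("mistralai/Mistral-7B-v0.1")
print(encode_with_eos(tokenizer, "a"))  # [1, 264, 2], matching the output above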

python lora.py --model mistralai/Mistral-7B-v0.1 \
               --adapter-file adapters.npz \
               --max-tokens 50 \
               --temp 0.2 \
               --prompt "table: Order
columns: Name,City,Amount,Category,Date
Q: Tomorrow is 2024/05/01.What is the total amount of HangZhou yesterday?
A: "
Loading pretrained model
Fetching 10 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 82080.31it/s]
Total parameters 7243.436M
Trainable parameters 1.704M
Loading datasets
Generating
table: Order
columns: Name,City,Amount,Category,Date
Q: Tomorrow is 2024/05/01.What is the total amount of HangZhou yesterday?
A: SELECT SUM Amount FROM Order WHERE City = 'HangZhou' AND Date = '2024/05/01'
Q: What is the total amount of HangZhou yesterday?
A:
==========

This is a case where generation does not stop properly.
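
For context on why the run above keeps generating past the answer: a typical sampling loop only stops early when the model emits the eos id, so a model fine-tuned on sequences without eos rarely predicts it and generation runs until --max-tokens is exhausted. Below is a minimal sketch of that stopping condition, not the actual lora.py code; sample_step is a hypothetical callable that returns the next token id:

def generate_with_eos_stop(model, tokenizer, prompt_ids, sample_step, max_tokens=50):
    # Sample until the model predicts eos or max_tokens is reached.
    ids = list(prompt_ids)
    for _ in range(max_tokens):
        token = sample_step(model, ids)
        if token == tokenizer.eos_token_id:
            break  # clean termination once eos is predicted
        ids.append(token)
    return tokenizer.decode(ids[len(prompt_ids):])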

awni commented 4 months ago

It might be good to make it an option and enable it by default.

Also for the example case you showed, does training with the eos token fix it?

jenkinv commented 4 months ago

It might be good to make it an option and enable it by default.

Good advice, I will follow it.
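
For concreteness, a sketch of what such an option could look like in lora.py; the --add-eos-token flag name and its default are illustrative assumptions, not necessarily what was merged:

import argparse
from utils import load

parser = argparse.ArgumentParser(description="LoRA finetuning")
parser.add_argument(
    "--add-eos-token",
    type=int,
    default=1,
    help="Append the eos token when tokenizing training data (1 = on, 0 = off; default on)",
)
args = parser.parse_args()

_, tokenizer, _ = load("mistralai/Mistral-7B-v0.1")
tokenizer.add_eos_token = bool(args.add_eos_token)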

Also for the example case you showed, does training with the eos token fix it?

Yes, training with the eos token fixes it.

This solution has been confirmed to work for Mistral-7B, gemma-2b, and MiniCPM-2B. However, it may not be compatible with Meta-Llama-3-8B and Qwen1.5-4B. Further investigation is required to determine the specific reasons for this incompatibility and develop appropriate solutions.
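
One way to start narrowing this down is to check how each tokenizer handles eos and whether it exposes add_eos_token at all. A minimal sketch using transformers directly, with repo ids assumed from the model names above (the Llama 3 repo is gated, so access is required):

from transformers import AutoTokenizer

for name in ["mistralai/Mistral-7B-v0.1", "meta-llama/Meta-Llama-3-8B", "Qwen/Qwen1.5-4B"]:
    tok = AutoTokenizer.from_pretrained(name)
    ids = tok.encode("a")
    appended = bool(ids) and ids[-1] == tok.eos_token_id
    # Report the raw ids, the eos id, whether eos was appended, and whether
    # the add_eos_token attribute exists on this tokenizer class.
    print(name, ids, tok.eos_token_id, appended, hasattr(tok, "add_eos_token"))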

760