Thanks for opening this issue, it's a really good catch!
I dug into the code a bit; here are some observations:
Surprisingly, the falcon model seems to be quite sensitive to (left) padding. I tried the prompt "Hi! I would like to know if you can roleplay a human for me." (from the OASST dataset), which gives the following answers:
Hello, I am a human named h2oGPT. How may I assist you today?
Hello, I am a human. How can I help you?
When I add manual left padding in the chat generation code, full_prompt = tokenizer.pad_token * 20 + full_prompt, the chat window outputs the same prediction as during evaluation. This could be an issue with falcon's implementation of attention; it looks as if the attention mask is missing when computing F.scaled_dot_product_attention. I will look into this further.
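To illustrate the padding effect in isolation (a minimal toy sketch in plain PyTorch, not falcon's actual attention code): if the padding mask is not passed to F.scaled_dot_product_attention, the real tokens also attend to the left-pad positions, so their outputs change.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# toy tensors of shape (batch, heads, seq_len, head_dim); the first two positions are left padding
q = torch.randn(1, 1, 5, 8)
k = torch.randn(1, 1, 5, 8)
v = torch.randn(1, 1, 5, 8)

# boolean key mask: False = padding position that should not be attended to
key_mask = torch.tensor([[False, False, True, True, True]])
attn_mask = key_mask[:, None, None, :]  # broadcast to (batch, 1, query_len, key_len)

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
without_mask = F.scaled_dot_product_attention(q, k, v)

# outputs of the real (non-pad) tokens differ once the pad positions leak into attention
print(torch.allclose(with_mask[..., 2:, :], without_mask[..., 2:, :]))  # False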
When I test with batch size = 1 for inference, chat responses and validation predictions match.
Thanks for looking into it!
Mine doesn't match even if I set the batch size to 1, but I am trying this after finetuning. Did you finetune before trying?
The attention mask is used unless alibi is None. (What is alibi?) Falcon's tokenizer doesn't have a mask token, and its last vocabulary token is ' victoria'. So the code below probably inserts a lot of ' victoria' tokens?
if tokenizer.unk_token_id is not None:
    cfg._tokenizer_mask_token_id = tokenizer.unk_token_id
elif tokenizer.mask_token_id is not None:
    cfg._tokenizer_mask_token_id = tokenizer.mask_token_id
elif tokenizer.pad_token_id is not None:
    cfg._tokenizer_mask_token_id = tokenizer.pad_token_id
else:
    # setting the mask token id to the last token in the vocabulary
    # this usually is a safe choice and mostly refers to eos token
    cfg._tokenizer_mask_token_id = len(tokenizer) - 1
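For reference, a quick way to check which branch falcon ends up in (a hedged sketch; it assumes the tokenizer loads via transformers' AutoTokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
# per the discussion above, the unk/mask/pad ids are expected to be missing for falcon
print(tokenizer.unk_token_id, tokenizer.mask_token_id, tokenizer.pad_token_id)
# the token the last-vocab-id fallback would pick
print(tokenizer.convert_ids_to_tokens(len(tokenizer) - 1))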
Alibi is a specific attention-bias technique one can use: https://arxiv.org/abs/2108.12409
For falcon it is None by default: https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json#L2
It seems the attention mask is ignored if alibi is None, which sounds like the issue here.
cfg._tokenizer_mask_token_id is just something we use in LLM Studio for mask augmentation, and it is irrelevant for the issue at hand.
I am surprised you cannot match bs=1 with chat after training. How exactly are you trying it?
Did you finetune before trying?
I was testing with the h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3 model and training epochs == 0.
I set the batch size to 1 in the UI and train for 2 epochs. I checked the logs, and inference was indeed running one sample at a time. Maybe you can train on one of your datasets for 1 epoch with batch size 1 and see if you hit the same issue.
Without training, the model doesn't put #s at the end. When I train with my data, which contains no code, it prints a lot of #s and some code fragments at the end during validation. Chat mode never does this; its results are meaningful.
Its tokenizer doesn't have a pad_token, and there is this in the code:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
I would think this might cause the model to learn something wrong during training, but if that were the case, it would also affect chat mode.
Another discussion I found on GitHub that may be relevant: https://github.com/lm-sys/FastChat/issues/1588
No, this is not the issue; nearly no LLM has a dedicated pad_token, so setting it to eos is common practice. With an attention mask those tokens are ignored anyway, provided the mask is actually used, which does not seem to be the case for falcon and can cause these issues.
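As a small illustration of why pad == eos is harmless as long as the attention mask is honored (a sketch, not LLM Studio's exact collation code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common practice when no dedicated pad token exists

batch = tokenizer(
    ["short prompt", "a considerably longer prompt with more tokens"],
    padding=True,
    return_tensors="pt",
)
# zeros in the attention mask mark the pad (= eos) positions; a model that respects
# the mask never attends to them, so the choice of pad token carries no signal
print(batch["attention_mask"])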
I now tried training with bs=1 and epochs=2, and chat and validation predictions match 100% on the few samples I tried, including coding prompts.
Having an extended attention mask that is both causal and masks left pad tokens is probably not possible as-is, since attention would then be undefined for the left pad tokens (they could not attend to anything). (Currently checking whether it is doable; see the sketch below.)
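One common workaround for that (a sketch of the idea, not code from the repo): combine the causal mask with the padding mask and let every position, including the pads, attend to at least itself, so softmax stays defined; the outputs at pad positions are discarded anyway.

import torch

def combined_causal_pad_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for real tokens, 0 for left padding
    seq_len = attention_mask.shape[1]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # a query may attend to a key only if it is causally visible AND not padding
    mask = causal[None, :, :] & attention_mask[:, None, :].bool()
    # let every position (including pads) attend to itself so no row is all False
    mask |= torch.eye(seq_len, dtype=torch.bool)[None, :, :]
    return mask  # (batch, query_len, key_len), True = may attend

# two left-pad tokens followed by three real tokens
print(combined_causal_pad_mask(torch.tensor([[0, 0, 1, 1, 1]])).int())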
During the finetuning process, the model should effectively learn to disregard the pad tokens, as evidenced by the varying but generally consistent results obtained from h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3. Still, padding clearly affects the output distribution and may result in different generation predictions.
Output of a model trained with batch size 1.
Input Text: What are the top 5 countries with the largest population?
Target Text: China, India, United States, Indonesia, Pakistan.
Predicted Text: The top 5 countries with the largest population are China, India, the United States, Indonesia, and Pakistan. ##################################################################################################################################################################################################################################
Chat mode: The top 5 countries with the largest population are China, India, the United States, Indonesia, and Pakistan.
Could you please share your cfg.yaml, @aerdem4?
architecture:
    backbone_dtype: float16
    force_embedding_gradients: false
    gradient_checkpointing: true
    intermediate_dropout: 0.0
    pretrained: true
    pretrained_weights: ''
augmentation:
    random_parent_probability: 0.0
    skip_parent_probability: 0.0
    token_mask_probability: 0.0
dataset:
    add_eos_token_to_answer: true
    add_eos_token_to_prompt: true
    answer_column: "Answer\r"
    chatbot_author: H2O.ai
    chatbot_name: h2oGPT
    data_sample: 1.0
    data_sample_choice:
    - Train
    - Validation
    limit_chained_samples: false
    mask_prompt_labels: true
    parent_id_column: None
    personalize: false
    prompt_column:
    - Question
    text_answer_separator: <|answer|>
    text_prompt_start: <|prompt|>
    train_dataframe: data/user/TEST/TEST - train_all.csv
    validation_dataframe: data/user/TEST/TEST - valid_all.csv
    validation_size: 0.01
    validation_strategy: custom
environment:
    compile_model: false
    find_unused_parameters: false
    gpus:
    - '0'
    huggingface_branch: main
    mixed_precision: true
    number_of_workers: 8
    seed: 5
    trust_remote_code: true
    use_fsdp: false
experiment_name: magic-labradoodle.1.1.2
llm_backbone: tiiuae/falcon-7b
logging:
    logger: None
    neptune_project: ''
    number_of_texts: 20
output_directory: output/user/magic-labradoodle.1.1.2/
prediction:
    batch_size_inference: 0
    do_sample: false
    max_length_inference: 256
    metric: BLEU
    metric_gpt_model: gpt-3.5-turbo-0301
    min_length_inference: 2
    num_beams: 1
    num_history: 2
    repetition_penalty: 1.2
    stop_tokens: ''
    temperature: 0.3
    top_k: 0
    top_p: 1.0
tokenizer:
    add_prefix_space: false
    add_prompt_answer_tokens: false
    max_length: 512
    max_length_answer: 512
    max_length_prompt: 256
    padding_quantile: 1.0
    use_fast: true
training:
    adaptive_kl_control: true
    advantages_gamma: 0.99
    advantages_lambda: 0.95
    batch_size: 1
    differential_learning_rate: 1.0e-05
    differential_learning_rate_layers: []
    drop_last_batch: true
    epochs: 2
    evaluate_before_training: true
    evaluation_epochs: 1.0
    grad_accumulation: 1
    gradient_clip: 1.0
    initial_kl_coefficient: 0.2
    kl_horizon: 10000
    kl_target: 6.0
    learning_rate: 0.0001
    lora: true
    lora_alpha: 16
    lora_dropout: 0.05
    lora_r: 4
    lora_target_modules: ''
    loss_function: TokenAveragedCrossEntropy
    offload_reward_model: false
    optimizer: AdamW
    ppo_batch_size: 1
    ppo_clip_policy: 0.2
    ppo_clip_value: 0.2
    ppo_epochs: 4
    ppo_generate_temperature: 1.0
    reward_model: OpenAssistant/reward-model-deberta-v3-large-v2
    save_best_checkpoint: false
    scaling_factor_value_loss: 0.1
    schedule: Cosine
    train_validation_data: false
    use_rlhf: false
    warmup_epochs: 1.0
    weight_decay: 0.0
Looks reasonable; it is hard to debug since we cannot replicate it.
answer_column: "Answer\r"
This is strange, but should have no impact if it works.
Also, this is not ideal: max_length: 512, max_length_answer: 512, max_length_prompt: 256.
With these settings it can happen that only the answer is presented to the model; max_length should be at least the sum of the prompt and answer lengths (see the small illustration below).
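A toy illustration of that concern (assuming truncation keeps the end of the concatenated sample so the answer and its eos token survive; the exact truncation logic in LLM Studio may differ):

# pretend token ids: a 300-token prompt followed by a 512-token answer
prompt_ids = list(range(300))
answer_ids = list(range(1000, 1512))
max_length = 512

sample = prompt_ids + answer_ids      # 812 tokens total
truncated = sample[-max_length:]      # keep the tail so the answer stays intact
prompt_tokens_left = sum(t < 1000 for t in truncated)
print(prompt_tokens_left)             # 0 -> only the answer is presented to the model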
Is it private data or would you be willing to share it?
Thanks for your help. Unfortunately it is private data. I will probably debug it myself and let you know if I can find some time. Since chat mode and validation mode don't use the same code, it is possible there is a difference between them.
Could you maybe try training on another backbone, such as https://huggingface.co/openlm-research/open_llama_3b?
That would help us understand whether it has something to do with falcon (which we suspect) or with your data / our code.
Could replicate it now, but only with falcon on a handful of samples. Will dig more into it.
I have just tried open_llama_3b; it doesn't have the same issue.
@aerdem4
Pretty sure I found the issue after going down too many rabbit holes :)
The generation_config shipped with falcon is wrong; I opened a PR in their repo here: https://huggingface.co/tiiuae/falcon-7b/discussions/55
To make sure a wrong generation_config cannot cause issues, we are now fetching these tokens from the tokenizer. Could you please try this PR to confirm it solves the issue? https://github.com/h2oai/h2o-llmstudio/pull/248
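The gist of the workaround (a hedged sketch, not the exact code in the PR): take the special token ids from the tokenizer rather than trusting the checkpoint's generation config.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)

# overwrite potentially wrong ids in the checkpoint's generation config with the tokenizer's
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)
# the patched model.generation_config is then picked up by model.generate(...)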
Thanks!
All good now! gg wp
🐛 Bug
For the same question, the validation set shows different answers compared to asking it on the Chat screen. Neither has any previous context. The Chat screen answers look better.
I have read the code extensively. Different pipelines are used for run_eval and chat, but I couldn't find what is actually different between them.
To Reproduce
I use falcon-7b, and validation set examples have a tendency to end answers with a lot of hashtags. The same doesn't happen during chat.