huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Training ByT5 for next response generation #23778

Closed: salokr closed this issue 1 year ago

salokr commented 1 year ago

Hi,

I am trying to train a ByT5 model for text-to-text generation. Specifically, given the previous chat history, the objective is to produce a response to the input. I understand that I could use a decoder-only model for this task, but we need byte-level information, which we will be using in the future. For training, I have obtained a dataset for fine-tuning and used the following configuration:

--model_name_or_path google/byt5-base \
    --do_train \
    --do_eval \
    --do_predict \
    --output_dir ./t5-base_50k_tast10 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=16 \
    --predict_with_generate \
    --eval_steps 1 \
    --greater_is_better True \
    --load_best_model_at_end True \
    --logging_steps 4 \
    --metric_for_best_model bleu_2 \
    --num_train_epochs 100 \
    --save_steps 1 \
    --save_total_limit 10 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --max_source_length 1000 \
    --max_target_length 200 \
    --learning_rate 5e-5
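Since ByT5 reads raw UTF-8 bytes, the --max_source_length 1000 and --max_target_length 200 limits above are byte budgets (EOS included), not word or subword counts, and with truncation_side='left' it is the oldest part of the chat history that gets cut. The following is a minimal pure-Python sketch of that behaviour, assuming ByT5Tokenizer's id scheme (pad/eos/unk = 0/1/2, each byte mapped to its value + 3, EOS appended at the end) and assuming that left truncation drops the oldest bytes while keeping EOS. byt5_encode is a hypothetical helper for illustration, not a transformers function.

```python
# Sketch of ByT5-style byte encoding with left truncation.
# ASSUMPTION: mirrors ByT5Tokenizer's scheme (ids 0/1/2 = pad/eos/unk,
# UTF-8 byte value + 3 for every byte, EOS appended) and the effect of
# truncation_side="left". Illustration only, not the transformers code.
from typing import List, Optional

PAD_ID, EOS_ID, UNK_ID = 0, 1, 2
BYTE_OFFSET = 3  # the three special tokens come before the 256 byte ids


def byt5_encode(text: str, max_length: Optional[int] = None) -> List[int]:
    """Encode text as ByT5-style ids, dropping the oldest bytes if too long."""
    if max_length is not None and max_length < 1:
        raise ValueError("max_length must leave room for EOS")
    ids = [b + BYTE_OFFSET for b in text.encode("utf-8")]
    if max_length is not None and len(ids) + 1 > max_length:
        # Keep only the most recent bytes, reserving one slot for EOS.
        ids = ids[len(ids) - (max_length - 1):]
    return ids + [EOS_ID]


history = "User: hi\nBot: hello!\nUser: how are you?"
print(len(history.encode("utf-8")))      # byte count, which is what the limits measure
print(byt5_encode("hi"))                 # [107, 108, 1]
print(byt5_encode("abc", max_length=3))  # [101, 102, 1]: oldest byte dropped
print(len(byt5_encode("\u00e9")))        # 3: "é" is two UTF-8 bytes, plus EOS
```

Cutting at an arbitrary byte can split a multi-byte UTF-8 character at the start of the kept history; that is harmless for ASCII chat logs but worth keeping in mind for other scripts.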

My code to fine-tune looks like the following:

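import evaluate
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    Seq2SeqTrainer,
)

config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)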

tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    use_fast=True,
    truncation_side='left')
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=model_args.cache_dir,
)

embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

max_target_length = data_args.max_target_length 
padding = "max_length" if data_args.pad_to_max_length else False
def preprocess(text):
    ... # some preprocessing code
def preprocess_function(examples):
    ... #call preprocess above and tokenize
    model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding='longest', truncation=True, return_tensors="pt")
    # note: padding='longest' pads these labels with tokenizer.pad_token_id (0 for ByT5), not with -100
    labels = tokenizer(text_target=targets, max_length=max_target_length, padding='longest', truncation=True, return_tensors="pt")
    ...

if training_args.do_train:
    train_dataset = train_dataset.map(preprocess_function, batched=True, num_proc=data_args.preprocessing_num_workers, desc="Running tokenizer on train dataset", remove_columns=column_names, load_from_cache_file=False)
if training_args.do_eval:
    eval_dataset = val_dataset.map(preprocess_function, batched=True, num_proc=data_args.preprocessing_num_workers, desc="Running tokenizer on validation dataset", remove_columns=column_names, load_from_cache_file=False)
if training_args.do_predict:
    test_dataset = test_dataset.map(preprocess_function, batched=True, num_proc=data_args.preprocessing_num_workers, desc="Running tokenizer on prediction dataset", remove_columns=column_names, load_from_cache_file=False)

label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=label_pad_token_id, pad_to_multiple_of=8 if training_args.fp16 else None)
metric = evaluate.load("bleu")
def postprocess_text(preds, labels):
    ...#post process stuff
    return preds, labels

def compute_metrics(eval_preds):
    ... #get bleu and other metrics
    return result

training_args.generation_max_length = training_args.generation_max_length if training_args.generation_max_length is not None else data_args.val_max_target_length
training_args.generation_num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
if training_args.do_train:
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
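One thing worth ruling out in the preprocess_function above: because the labels are padded with padding='longest' inside .map(), the padded positions hold tokenizer.pad_token_id (0 for ByT5) rather than -100. DataCollatorForSeq2Seq only uses label_pad_token_id for the extra positions it adds itself, so those 0s stay in the labels and are trained on as real targets. This may not explain the gibberish on its own, but it dilutes the loss. Below is a minimal sketch of the masking step, assuming plain Python lists of label ids and the -100 ignore index used by the model's cross-entropy loss. mask_pad_labels is a hypothetical helper; the transformers summarization example script does a similar replacement when it pads to max_length.

```python
# Sketch: replace padding in tokenized labels with the loss ignore index.
# ASSUMPTIONS: pad_token_id is 0 (ByT5's pad id) and -100 is the ignore
# index of the model's loss; mask_pad_labels is a hypothetical helper,
# not a transformers function.
from typing import List

IGNORE_INDEX = -100


def mask_pad_labels(label_ids: List[List[int]], pad_token_id: int = 0,
                    ignore_index: int = IGNORE_INDEX) -> List[List[int]]:
    """Return a copy of label_ids with every pad token set to ignore_index."""
    return [[ignore_index if tok == pad_token_id else tok for tok in seq]
            for seq in label_ids]


# Two targets padded to the longest one, as padding='longest' does in .map().
labels = [[107, 108, 1, 0, 0],
          [100, 101, 102, 103, 1]]
print(mask_pad_labels(labels))
# [[107, 108, 1, -100, -100], [100, 101, 102, 103, 1]]
```

Assuming the elided lines store labels["input_ids"] under model_inputs["labels"], this would correspond to something like model_inputs["labels"] = mask_pad_labels(labels["input_ids"], tokenizer.pad_token_id), with return_tensors="pt" dropped so the ids are lists. Alternatively, tokenizing with padding=False leaves all padding to DataCollatorForSeq2Seq, which then pads labels with label_pad_token_id (-100 here when ignore_pad_token_for_loss is set).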

However, even after extensive fine-tuning, the model generates text that repeats itself over and over, sometimes copies from the input, or produces responses that are unrelated to the input. I have also tried contrastive search, beam search, and other decoding strategies, but the generated responses are still gibberish. Do you have any suggestions for improving ByT5's ability to perform this task? As I understand it, T5-based models (including ByT5) perform well on many seq2seq tasks such as Text2SQL, so they should at least generate responses relevant to the input for this task too.
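Besides the decoding strategies mentioned, generate() also accepts no_repeat_ngram_size and repetition_penalty, which target this kind of looping. One caveat specific to ByT5: these n-grams are counted in tokens, and for ByT5 a token is a byte, so a small value that is common for subword models would forbid repeating any short character sequence. The sketch below is a simplified, single-sequence illustration of the n-gram banning idea, not the transformers logits processor itself; banned_next_tokens is a hypothetical helper, and the byte ids assume ByT5's offset of 3.

```python
# Sketch of the idea behind no_repeat_ngram_size: given the tokens generated
# so far, ban any next token that would complete an n-gram already seen.
# ASSUMPTION: simplified single-sequence illustration; banned_next_tokens
# is a hypothetical helper, not part of transformers.
from typing import List, Set


def banned_next_tokens(tokens: List[int], n: int) -> Set[int]:
    """Tokens that would repeat an existing n-gram if appended to tokens."""
    if n < 1 or len(tokens) < n - 1:
        return set()
    prefix = tuple(tokens[len(tokens) - (n - 1):])  # the last n-1 tokens
    banned = set()
    for i in range(len(tokens) - n + 1):
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


print(banned_next_tokens([1, 2, 3, 1, 2], n=3))  # {3}

# At byte level, n=3 already blocks writing the word "the" a second time.
text_so_far = "the dog saw th"
byte_ids = [b + 3 for b in text_so_far.encode("utf-8")]  # assumed ByT5 offset
print(sorted(chr(t - 3) for t in banned_next_tokens(byte_ids, n=3)))  # ['e']
```

So if you try no_repeat_ngram_size with ByT5, the value has to be chosen in bytes, i.e. considerably larger than for subword models; the best setting is something to tune on the validation set.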

Please let me know any suggestions you have. @ArthurZucker @younesbelkada

I am also attaching some sample responses generated by the model.

[Screenshot 2023-05-25 at 10:24:34 PM showing sample responses generated by the model]

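To track the repetition seen in samples like these during evaluation rather than only by eye, a distinct-n score (unique n-grams divided by total n-grams across all generated responses) could be reported next to BLEU in compute_metrics. Lower values mean more repetitive output. A minimal sketch, assuming simple whitespace tokenization; distinct_n is a hypothetical helper, not part of evaluate or transformers.

```python
# Sketch: distinct-n, the fraction of unique n-grams among all n-grams in a
# collection of generated responses. Low values indicate repetitive output.
# ASSUMPTIONS: whitespace tokenization; distinct_n is a hypothetical helper
# that could be reported alongside BLEU inside compute_metrics.
from typing import List


def distinct_n(responses: List[str], n: int = 2) -> float:
    total = 0
    unique = set()
    for response in responses:
        words = response.split()
        ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0


print(distinct_n(["i am fine i am fine i am fine"], n=2))  # 0.375
print(distinct_n(["i am fine thanks"], n=2))                # 1.0
```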
ArthurZucker commented 1 year ago

Hey! Thanks for reporting. However urgent this is, please refrain from pinging that many people. Questions about how to train a model or how to improve training should be asked on the forum, since they are not bugs and the community is better placed to help you there.