huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Negative KL and generation args #355

Closed zwhe99 closed 1 year ago

zwhe99 commented 1 year ago

Hi! Thanks for your amazing work!

I am running rl_training.py for a machine translation task.

If I use the default generation args:

generation_kwargs = {
    # "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000,
}
output_min_length = 32
output_max_length = training_args.output_max_length
output_length_sampler = LengthSampler(output_min_length, output_max_length)

then the KL loss seems normal

[image: KL curve during training (looks normal)]
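(For context, these kwargs feed PPOTrainer.generate roughly as in the sketch below; this is an assumption about the surrounding stack_llama-style training loop, not a verbatim quote from the script.)

for batch in ppo_trainer.dataloader:
    question_tensors = batch["input_ids"]
    # length_sampler picks a target response length per query;
    # generation_kwargs are forwarded to the underlying generate call
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)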

However, I found severe hallucinations when showcasing some samples during training. For example, it continues to generate unrelated code fragments after the translation:

Source: 如果你生活的城市有多种饮酒文化,那就去那些你不常去的街区的酒吧或酒馆。

Output: If your city has a drinking culture, then go to the bars or pubs in neighbourhoods you wouldn't normally visit. packageio.classmags.aws.lambda.retro.controllers;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.Request;
import

Ground-truth: If you live in a city with a varied drinking culture, go to bars or pubs in neighborhoods you don't frequent.

If I modify the generation args like this:

generation_kwargs = {
    "temperature": 0.7,
    "do_sample": True,
    "num_beams": 1,
    "max_new_tokens": training_args.output_max_length,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}
output_length_sampler = None

Then the KL loss seems abnormal:

[image: KL curve during training (looks abnormal)]
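(For reference, the KL curve tracks roughly the per-token estimate log p_policy - log p_ref summed over the response tokens; below is a minimal sketch of that estimate, not trl's exact implementation.)

import torch

# Rough sketch of the KL estimate behind the curve: logprob of each sampled
# response token under the policy minus its logprob under the frozen reference
# model, summed over the response. It can turn negative when the sampled tokens
# are on average more likely under the reference, e.g. when generation kwargs
# (temperature, a forced eos_token_id) make sampling deviate from the policy
# distribution used to compute these logprobs.
def approx_kl(logprobs: torch.Tensor, ref_logprobs: torch.Tensor) -> torch.Tensor:
    return (logprobs - ref_logprobs).sum(dim=-1)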

Fortunately, the showcased samples look fine:

Source: 如果你生活的城市有多种饮酒文化,那就去那些你不常去的街区的酒吧或酒馆。
Output: If there's a drinking culture where you live, go to the bars or pubs that you wouldn't normally go to.
Ground-truth: If you live in a city with a varied drinking culture, go to bars or pubs in neighborhoods you don't frequent.

But training still fails in the later stages:

[image: training curves showing the run degrading in the later stages]

I'm new to reinforcement learning, and I've read the related discussions in other issues, but I'm still not sure what I should do.

zwhe99 commented 1 year ago

@younesbelkada @lvwerra Might you be able to offer any suggestions regarding this issue?

younesbelkada commented 1 year ago

Hi @zwhe99 (@lvwerra correct me if I am wrong). The hallucinations you see come from forcing eos_token_id to a value that does not exist in the model's vocabulary (here 100_000). First, you need to understand what setting a custom eos_token_id does when calling generate: the generate method keeps producing text until it encounters that token. Since the token you assigned does not exist, generation never stops early, and the model hallucinates because it keeps generating text past the "real" EOS token (</s>). If you print the generation with skip_special_tokens=False, I am pretty sure the </s> token will appear in the generated text.
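For example, something along these lines (illustrative only; response_tensors and tokenizer are whatever you already have in your training loop) should reveal the hidden EOS:

# Decode one generated response without stripping special tokens to check
# whether the real EOS token (</s>) appears before the hallucinated code.
print(tokenizer.decode(response_tensors[0], skip_special_tokens=False))
print(tokenizer.decode(response_tensors[0], skip_special_tokens=True))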

So at test time (when you showcase generations during training) you shouldn't set eos_token_id to 100_000. IMO you should have two sets of generation kwargs: one used for training that keeps 100_000 as the eos_token_id to make training stable (as observed in the stable KL), and one without that override so that you can showcase readable generations during training.
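A minimal sketch of what that could look like (variable names are illustrative, reusing the kwargs already shown above):

# Rollout kwargs: keep the non-existent eos_token_id so generation always runs
# to the sampled length, which keeps the KL stable during PPO training.
train_generation_kwargs = {
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000,  # never emitted, so generate fills the sampled length
}

# Showcase kwargs: stop at the real EOS token so printed samples stay readable.
showcase_generation_kwargs = {
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}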

zwhe99 commented 1 year ago

@younesbelkada Thank you for your prompt reply! However, I still have one doubt: although I can change the generation args to print normal model output, is it normal that the model is being RL-trained on hallucinated data?

mekaneeky commented 1 year ago

I am also having a similar issue with negative KL divergence for a translation task. The model diverges and starts generating gibberish within about 30 iterations.

hecongqing commented 1 year ago

I also have the same problem

vwxyzjn commented 1 year ago

I was able to create a minimal repro at https://github.com/lvwerra/trl/issues/235#issuecomment-1580674096. This issue is kind of a duplicate. Should we close this issue in favor of #235?

lvwerra commented 1 year ago

Sounds good!