huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

RWKV produces erroneous output after fine-tuning with model.eval() enabled; inference with model.train() works correctly #25937

Closed Vectorrent closed 11 months ago

Vectorrent commented 1 year ago

System Info

Who can help?

@ArthurZucker @gante

Information

Tasks

Reproduction

I am not using the Hugging Face Trainer. I am using a very basic training loop, originally forked from AITextGen, which uses Lightning's automatic optimization. The bulk of that fine-tuning code can be found here. This code works correctly for fine-tuning other models, like GPT-Neo and Cerebras-GPT. Something about RWKV v4 (169m/430m) is different.
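For context, here is a minimal sketch of what such a Lightning loop with automatic optimization looks like. This is not the actual code from the linked repo; the model name, batch format, and hyperparameters are illustrative assumptions.

```python
# Minimal sketch of a Lightning fine-tuning loop with automatic optimization.
# Not the actual code from the linked repo; model name, batch format, and
# hyperparameters are illustrative assumptions.
import torch
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM


class CausalLMFineTuner(pl.LightningModule):
    def __init__(self, model_name="RWKV/rwkv-4-430m-pile", lr=1e-5):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # For causal LM fine-tuning, labels are the input ids themselves.
        outputs = self.model(input_ids=batch["input_ids"], labels=batch["input_ids"])
        self.log("train_loss", outputs.loss)
        # Returning the loss lets Lightning's automatic optimization handle
        # backward() and optimizer.step().
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
```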

To reproduce, you would have to implement my training logic (which isn't terribly "custom" or "complicated" at all), then toggle between eval/train modes while performing inference to see the difference. Alternatively, you could train in your own way and toggle between eval/train, just to tell me whether the problem is with my training code. I don't think it is.

I have tried both LoRA and traditional fine-tuning; both have the same result. I have tried all manner of learning rate adjustments, weight decay, and batch sizes, but hyperparameters don't seem to fix this problem. Nor would I really expect them to; if the problem can be fixed by toggling between eval/train modes, then I would expect the problem to lie in the HF implementation. I spoke to BlinkDL (the creator of RWKV) about this, and he said it sounds like a bug in the HF inference code.

Expected behavior

After fine-tuning, RWKV is unable to produce coherent output when self.model.eval() is enabled. If the model is set to self.model.train(), the output is as expected.

Take this sample data, which I've fine-tuned RWKV v4 430m on:

¶6064923396882833153:> I have been through everything.
¶251161711099248641:> Am with you...there buddy! lol Like that save? jk
¶1623339977514240393:> Nice, this gives me hope for the sake of being fully onboarded into our own reality. Lol

Within <1000 training steps, a fine-tuned model (with self.model.train() enabled) will be capable of producing output like this:

¶389988567908488489:> What is the point of this video?
¶747257279748767804:> Just be more careful
¶389988567908488489:> What is the point of this video?
¶747257279748767804:> The point is to make you think.
¶389988567908488489:> What is the point of this video?
¶747257279748767804:> Because it is a video. A video is a video.

However, that same model - with self.model.eval() enabled - will produce gibberish, like this:

¶ ¶ [")')\")", [ [ [ [3]3] ¶**A1\new!C$',!C$',!C$',!C$')!C$3\ndraw (((4.5+4@    
¶  Which 'YNUMC" is (((78740833245160745 WCHAR) + "Enter " +   
¶,iple!", [vi@ 1400! 0.post\n:> (((694509,632072,A"," - - -", - - -))) [r "'¶5,",   

I would expect RWKV to perform better in self.model.eval() mode, not worse than in self.model.train(). Clearly, the model is training correctly and it is learning; something about eval mode completely breaks generation, though.
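Here is a rough sketch of the comparison I'm describing; the checkpoint path, prompt, and sampling settings are placeholders.

```python
# Generate from the same fine-tuned checkpoint in train mode and in eval mode.
# The checkpoint path, prompt, and sampling settings are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "path/to/fine-tuned-rwkv-4-430m"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

prompt = "¶389988567908488489:> What is the point of this video?\n"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

for mode in ("train", "eval"):
    getattr(model, mode)()  # model.train() or model.eval()
    with torch.no_grad():
        out = model.generate(input_ids=input_ids, max_new_tokens=64,
                             do_sample=True, top_p=0.9)
    print(f"--- {mode} ---")
    print(tokenizer.decode(out[0], skip_special_tokens=True))
```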

gante commented 1 year ago

Hi @LuciferianInk 👋

Calling model.eval() turns off things like dropout and per-batch normalization and, as you wrote, should improve things at inference time. I'm afraid that without a short reproducer to load the fine-tuned model and reproduce the problem quickly, there is little we can do -- our bandwidth is limited, so we need your help too :)
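For illustration (not RWKV-specific), this is all that flag changes for a dropout layer:

```python
# model.eval() just flips the `training` flag on every submodule; layers like
# nn.Dropout check that flag at forward time.
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()
print(drop(x))  # roughly half the entries zeroed, the rest scaled by 1/(1-p)

drop.eval()
print(drop(x))  # identity: tensor([1., 1., 1., 1.])
```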

Vectorrent commented 1 year ago

Okay, thanks for the response. I might need a couple of days, but I'll try to put something together for you. I'll probably use Docker. Let me know if that's an issue.

Vectorrent commented 1 year ago

Hi @gante,

As requested, I have published a small project to reproduce this issue. It will load the RWKV-v4 430m model, attach a LoRA adapter, and quickly run inference. You may choose to use the provided Docker configs or not; either Docker or vanilla Python should work. Further instructions are in the README file.
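Roughly, the reproduction does something like the following. This is a sketch, not the repo's exact code; the adapter path and prompt are placeholders.

```python
# Sketch of the reproduction: load the base RWKV-v4 430m model, attach a LoRA
# adapter with PEFT, and run inference in eval mode. Not the repo's exact code;
# the adapter path and prompt are placeholders.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "RWKV/rwkv-4-430m-pile"
tokenizer = AutoTokenizer.from_pretrained(base_id)
base_model = AutoModelForCausalLM.from_pretrained(base_id)

# Attach a previously fine-tuned LoRA adapter (path is hypothetical).
model = PeftModel.from_pretrained(base_model, "path/to/lora-adapter")
model.eval()

input_ids = tokenizer("¶251161711099248641:> Am with you...", return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(input_ids=input_ids, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```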

I did not recreate the training loop, because you didn't ask for it (nor am I certain that training was the problem). If you'd like to see the training code, I linked to it above.

Thank you for your time and attention to this matter. Please let me know if you need anything else from me.

Vectorrent commented 11 months ago

Well, I've learned a few things, which make me lean towards this being a "quirk in the model," rather than an actual problem with Transformers' inference.

  1. I was able to train RWKV/rwkv-4-169m-pile with PEFT without running into this issue at all. However, both 430m and 1b5 immediately run into it.
  2. I suspect the model is overfitting. Perhaps larger RWKV models are more sensitive to fine-tuning with less data, or with data that contains certain kinds of patterns? Although I have several gigabytes of training data, less than 5% of it actually looks like the examples RWKV is overfitting on.
  3. I can negate this problem by setting PEFT's "r" argument to 1 (see the sketch after this list). I think this makes sense; with so few trainable parameters, the model is forced to learn more general representations, rather than memorizing the numbers and patterns you see above. Of course, the problem is that you can't encode very much information into such small weight matrices.
  4. I tried a full, regular fine-tuning of the 430m model, and the issue is not present there. Thus, LoRA is the problem.
  5. I still haven't found a great solution, but I'm sure I will continue to revisit this problem until I've landed on something.
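For reference, the r=1 configuration from point 3 looks roughly like this. The target module names follow the HF RWKV implementation's key/value/receptance linears and should be treated as assumptions.

```python
# Sketch of point 3: LoRA with rank r=1 on the RWKV linear layers.
# Target module names follow the HF RWKV implementation (key/value/receptance)
# and are assumptions here; adjust them to your checkpoint.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-430m-pile")
config = LoraConfig(
    r=1,                 # very low rank: few trainable params, less memorization
    lora_alpha=2,
    lora_dropout=0.05,
    target_modules=["key", "value", "receptance"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
```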

Not sure if the issue is still worth tracking here at this point. I really think I'm just fighting the challenge of training an RNN, versus the relative ease of a transformer. I'll leave it to the maintainers to decide whether they'd like to close the issue.

LysandreJik commented 11 months ago

Thanks for sharing your insights! Might be interesting for @pacman100 who has worked on PEFT (no actionable items right now AFAIU)

Vectorrent commented 11 months ago

Okay! I think we finally landed on a solution. It started with an explanation of the various RWKV modules from Google Bard:

Key module

The key module takes the input query and context as input and produces a representation that is used to retrieve the most relevant key-value pairs from the RWKV memory. This is done by transforming the input query and context into a common space, where they can be compared to the keys in the memory. The key module is typically implemented as a neural network, with parameters that are learned during training.

Value module

The value module takes the retrieved key-value pairs as input and produces a representation that is used to update the output query. This is done by transforming the key-value pairs into a common space, where they can be combined to produce an update to the output query. The value module is typically implemented as a neural network, with parameters that are learned during training.

Receptance module

The receptance module controls how much of the update produced by the value module is applied to the output query. This is done by multiplying the update by a scalar value, which is called the receptance. The receptance module is typically implemented as a single layer neural network, with parameters that are learned during training.

Long story short, I spent some time experimenting with asymmetric ranks and alpha on the different modules, and eventually landed on some settings that work. At this point, I'm tired of fighting with it, and ready to move on.
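In PEFT terms, "asymmetric ranks and alpha" can be expressed with rank_pattern / alpha_pattern (available in recent PEFT releases); the numbers below are illustrative, not the exact settings I landed on.

```python
# Sketch of per-module ("asymmetric") ranks and alphas via PEFT's
# rank_pattern / alpha_pattern (available in recent PEFT releases).
# The numbers are illustrative, not the exact settings I ended up with.
from peft import LoraConfig

config = LoraConfig(
    r=4,                                            # default rank
    lora_alpha=8,                                   # default alpha
    target_modules=["key", "receptance", "value"],
    rank_pattern={"receptance": 8, "value": 2},     # per-module rank overrides
    alpha_pattern={"receptance": 16, "value": 2},   # per-module alpha overrides
    task_type="CAUSAL_LM",
)
```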

I'll be sure to close this issue in a few days, after I'm positive the problem was resolved.

Vectorrent commented 11 months ago

Well, there is no doubt that RWKV is more difficult to work with than a transformer, but I've finally landed on some functional settings. At the end of the day, it required a larger training set, less weight decay, stochastic weight averaging (SWA), and a lot of other optimizations. But mostly: avoid training the "value", "output", and "head" modules, and you'll have a better time.
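A rough sketch of that final recipe (the hyperparameters are placeholders, and the SWA wiring below uses torch.optim.swa_utils rather than my actual training code):

```python
# Sketch of the final recipe: LoRA on key/receptance only (value, output, and
# the LM head left untouched), low weight decay, and stochastic weight averaging
# via torch.optim.swa_utils. Hyperparameters are placeholders, not exact settings.
import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-430m-pile")
config = LoraConfig(
    r=4,
    lora_alpha=8,
    target_modules=["key", "receptance"],  # deliberately skip value/output/head
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, config)

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4, weight_decay=0.001)
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)

# Inside the training loop, after each optimizer.step():
#     swa_model.update_parameters(model)
#     swa_scheduler.step()
```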

Going to close this issue now.

LysandreJik commented 11 months ago

Thanks for sharing, very insightful @LuciferianInk :)