johnsmith0031 / alpaca_lora_4bit

MIT License

High perplexity while lower loss after LoRA finetuning (how?) #140

alex4321 closed this issue 1 year ago

alex4321 commented 1 year ago

Hi. While experimenting with changes to the model architecture, I noticed that I was getting gibberish when generating text, even though TensorBoard showed the loss decreasing over time.

Shortly after that, I came up with the following test using plain Vicuna fine-tuning, without any of my experimental changes: https://github.com/alex4321/alpaca_lora_4bit/blob/9398fd39cb2369a9aed2d9e5a73909c882d1c894/notebooks/test-finetune-perplexity.ipynb

Basically, what it does:

Now for some details.

First, the perplexity implementation. I used the formula from here: https://huggingface.co/docs/transformers/perplexity

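So, for each position we know the correct next token and its log-probability (the logits produced by the model, then log-softmax), so for a tokenized sequence $X = (x_0, x_1, \dots, x_t)$ we can calculate the following expression:

$$\mathrm{PPL}(X) = \exp\left\{ -\frac{1}{t} \sum_{i=1}^{t} \log p_\theta(x_i \mid x_{<i}) \right\}$$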
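A minimal sketch of that computation in plain Python, assuming the logits are given as nested lists where row `i` is the model output that should predict `targets[i]` (i.e. rows and targets are already aligned). `log_softmax` and `perplexity` are hypothetical helper names for illustration, not the notebook's code:

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def perplexity(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    # Assumption: logits[i] is the model output that should predict targets[i].
    if not targets or len(logits) != len(targets):
        raise ValueError("need one logit row per target token")
    # Sum of negative log-probabilities of the correct tokens.
    nll = -sum(log_softmax(row)[tgt] for row, tgt in zip(logits, targets))
    # exp of the mean negative log-likelihood: exp{-(1/t) * sum log p}.
    return math.exp(nll / len(targets))


# A model that is uniform over a 4-token vocabulary has perplexity 4 (up to float rounding).
print(perplexity([[0.0] * 4] * 3, [0, 1, 2]))
```

A confident model drives the value toward 1, and a uniform one gives the vocabulary size, which makes these two cases handy sanity checks.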

In a few sources I have also seen another formula:

$$\mathrm{PPL}(X) = \left( \prod_{i=1}^{t} p_\theta(x_i \mid x_{<i}) \right)^{-1/t}$$

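But they're equal due to the following conversion:

$$\exp\left\{ -\frac{1}{t} \sum_{i=1}^{t} \log p_\theta(x_i \mid x_{<i}) \right\} = \exp\left\{ -\frac{1}{t} \log \prod_{i=1}^{t} p_\theta(x_i \mid x_{<i}) \right\} = \left( \prod_{i=1}^{t} p_\theta(x_i \mid x_{<i}) \right)^{-1/t}$$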
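A quick numeric sanity check of that identity, using a made-up list of per-token probabilities (not data from the notebook):

```python
import math
import random

# Made-up probabilities p(x_i | x_<i) of the correct tokens (not notebook data).
random.seed(0)
probs = [random.uniform(0.05, 0.95) for _ in range(20)]
t = len(probs)

# Form 1: exp of the negative mean log-probability.
ppl_from_logs = math.exp(-sum(math.log(p) for p in probs) / t)

# Form 2: inverse t-th root of the product of the probabilities.
ppl_from_product = math.prod(probs) ** (-1 / t)

print(ppl_from_logs, ppl_from_product)  # equal up to float rounding
```

The log form is the one to use in practice: for long texts the raw product underflows to `0.0` in float64 (e.g. 200 tokens at p = 0.001 gives 1e-600), while the sum of logs stays well-behaved.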

So the implementation should be okay.

Next: what is our loss function? Categorical cross-entropy, which, unless I'm misunderstanding, is basically:

$$\mathcal{L}_{\mathrm{CE}} = -\frac{1}{t} \sum_{i=1}^{t} \log p_\theta(x_i \mid x_{<i})$$

This is essentially the same quantity as the term inside the exponent in the perplexity formula, which means lower categorical cross-entropy should mean lower perplexity and vice versa, right? We're calculating both on the same texts, so length normalization is not an issue.
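To make that relationship concrete, here is a sketch assuming both numbers are computed from the same (prediction, correct token) pairs: perplexity is then just `exp` of the mean cross-entropy. The probabilities below are invented for illustration:

```python
import math
from typing import Sequence


def mean_cross_entropy(target_log_probs: Sequence[float]) -> float:
    # Mean negative log-probability of the correct tokens; this is what
    # CrossEntropyLoss(reduction="mean") averages over non-ignored tokens.
    return -sum(target_log_probs) / len(target_log_probs)


def perplexity(target_log_probs: Sequence[float]) -> float:
    # Perplexity over the same tokens is exp of that mean.
    return math.exp(mean_cross_entropy(target_log_probs))


# Made-up probabilities of the correct tokens for one text, before and after tuning.
before = [math.log(p) for p in (0.2, 0.1, 0.3, 0.25)]
after = [math.log(p) for p in (0.5, 0.4, 0.6, 0.45)]

for name, lp in (("before", before), ("after", after)):
    print(name, round(mean_cross_entropy(lp), 4), round(perplexity(lp), 4))
```

Since `exp` is strictly increasing, loss and perplexity can only move in opposite directions if they are not actually scoring the same token pairs, for example if the labels used by the loss are aligned differently from the targets used in the perplexity function.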

But instead of this, I see:

So the loss went down, while the perplexity went up instead.

Library versions:

alpaca_lora_4bit==0.1.2
transformers==4.28.1
peft==0.3.0
alex4321 commented 1 year ago

I guess I'll dig deeper into the training process a bit later to check whether the loss is calculated correctly.

alex4321 commented 1 year ago
```python
from torch.nn import CrossEntropyLoss

# Excerpt from the causal LM forward() in transformers (uses self.lm_head / self.config)
logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
```

Okay, I guess I see it:

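`shift_logits = logits[..., :-1, :].contiguous()`

so we cut the logits down to only the positions for which we know the label (the next token)...

`shift_labels = labels[..., 1:].contiguous()`

but I forgot that the labels are shifted here, inside the model, so the labels my dataset had already shifted ended up shifted twice.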
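To see why a double shift gives a falling loss but gibberish generations, here is a plain-Python sketch of which label each position ends up trained on. It assumes, as a hypothetical, that the dataset built labels as `input_ids[1:]` padded with `-100`, and that the model then applies the `labels[..., 1:]` shift shown above; the token ids are made up:

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # CrossEntropyLoss's default ignore_index


def trained_targets(input_ids: List[int], labels: List[int]) -> List[Tuple[List[int], int]]:
    # The model shifts internally: the logits at position i (which have seen
    # input_ids[: i + 1]) are scored against labels[i + 1].
    return [(input_ids[: i + 1], labels[i + 1]) for i in range(len(input_ids) - 1)]


ids = [10, 11, 12, 13]  # made-up token ids for a 4-token text

# Correct: labels == input_ids, the model does the one-step shift itself.
correct = trained_targets(ids, ids)

# Buggy (hypothetical dataset code): labels already shifted left by one,
# then shifted again inside the model.
pre_shifted = ids[1:] + [IGNORE_INDEX]
double_shifted = trained_targets(ids, pre_shifted)

for (ctx, good), (_, bad) in zip(correct, double_shifted):
    print(ctx, "-> correct target:", good, "| double-shifted target:", bad)
```

With the double shift, each position is trained to predict the token two steps ahead, so the loss on that objective can keep falling while ordinary next-token perplexity, and generation quality, get worse.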

alex4321 commented 1 year ago

I changed the dataset and the perplexity calculation logic accordingly; rechecking.

P.S. Really strange. I thought I remembered manual shifting from some older Hugging Face generative models (I haven't done anything with generative models for a long time, so it's quite possible I remembered some outdated bugs), but all the models I checked do the shifting internally, at least nowadays.

alex4321 commented 1 year ago

Yes, after removing the manual shifting from the dataset (and adding the shift to the perplexity calculation function):

```text
Original perplexities: [6.494434, 5.0919514, 4.490959]
Initial LoRA perplexities: [6.494434, 5.0919514, 4.490959]
Initial LoRA loss: 1.6666971445083618
Trained LoRA perplexities: [2.724882, 1.6267716, 1.964895]
Trained LoRA loss: 0.721564769744873
```

So it works as expected.
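The evaluation-side half of the fix, as a sketch: when the model returns logits for every position of `input_ids`, the perplexity function has to apply the same one-step shift itself. This is plain Python with a hypothetical function name, not the notebook's actual code:

```python
import math
from typing import Sequence


def shifted_perplexity(logits: Sequence[Sequence[float]], input_ids: Sequence[int]) -> float:
    # logits has one row per position of input_ids; the row at position i
    # predicts input_ids[i + 1], so drop the last row and the first token.
    if len(logits) != len(input_ids) or len(input_ids) < 2:
        raise ValueError("need one logit row per token and at least two tokens")
    nll = 0.0
    for row, target in zip(logits[:-1], input_ids[1:]):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # log-softmax normalizer
        nll -= row[target] - log_z
    return math.exp(nll / (len(input_ids) - 1))


# Made-up example: vocabulary of 5, the model strongly predicts each next token.
ids = [0, 1, 2, 3]
logits = [[10.0 if v == ids[i + 1] else 0.0 for v in range(5)] for i in range(3)]
logits.append([0.0] * 5)  # last position: its prediction is never scored
print(shifted_perplexity(logits, ids))
```

Keeping the shift in one place (here and inside the model's loss) and passing unshifted labels from the dataset means loss and perplexity score exactly the same token pairs.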

Well, at least debugging the issue in a separate "environment" was easier than with all the architecture changes in my pet-project model.

Fixed notebook version: https://github.com/alex4321/alpaca_lora_4bit/blob/test_finetuning/notebooks/test-finetune-perplexity.ipynb

alex4321 commented 1 year ago

I guess I need to update the corresponding autotest as well, because as it stands it tests that the model can memorize its training texts, not that generation is being tuned properly.