huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Missing modules in prompt-based PEFT when re-loading model #2043

Open martin-wey opened 2 weeks ago

martin-wey commented 2 weeks ago

System Info

python 3.10.10, transformers 4.44.2, peft 0.12.0

Who can help?

@BenjaminBossan @sayak

Reproduction

The following is a simplified script to reproduce the bug; I have experienced the same issue using transformers.Trainer. Fine-tuning with PEFT using p-tuning or prompt tuning works perfectly. However, when reloading the model from a saved PEFT checkpoint for generation, some modules are missing. As a result, the model does not generate the expected content.

Basic script to save the adapter:

import torch
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, get_peft_model, AutoPeftModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config)
model.save_pretrained("meta-llama-ptuning")
model.print_trainable_parameters()
print(model)

Output:

trainable params: 1,151,232 || all params: 8,031,412,480 || trainable%: 0.0143
PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  )
  (prompt_encoder): ModuleDict(
    (default): PromptEncoder(
      (embedding): Embedding(20, 4096)
      (mlp_head): Sequential(
        (0): Linear(in_features=4096, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): ReLU()
        (4): Linear(in_features=128, out_features=4096, bias=True)
      )
    )
  )
  (word_embeddings): Embedding(128256, 4096)
)

Load the PEFT model from checkpoint (method 1):

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
peft_config = PromptEncoderConfig.from_pretrained("meta-llama-ptuning")
model = get_peft_model(model, peft_config)
print(model)

Load the PEFT model from checkpoint (method 2):

model = AutoPeftModelForCausalLM.from_pretrained("meta-llama-ptuning")
print(model)

Output: The MLP part of the prompt encoder is missing

PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  )
  (prompt_encoder): ModuleDict(
    (default): PromptEncoder(
      (embedding): Embedding(20, 4096)
    )
  )
  (word_embeddings): Embedding(128256, 4096)
)

Perhaps something's wrong here: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/tuners/p_tuning/model.py#L82

Expected behavior

I believe the entire prompt encoder should be reloaded from the PEFT checkpoint. When I reload the model using the aforementioned methods, it generates content that is close to the base model, meaning the prompt encoder is not loaded properly.

BenjaminBossan commented 2 weeks ago

Thanks for reporting this issue.

First of all, please don't load a model using get_peft_model; that function is for creating new models. Always use from_pretrained to load models. Second, please ensure that the same dtype is used when loading the model. Initially, you use bfloat16, but when you call model = AutoPeftModelForCausalLM.from_pretrained("meta-llama-ptuning"), you don't specify the dtype.
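
For reference, a minimal loading snippet following the advice above (the checkpoint path is the one from the reproduction script, and the dtype matches the one used at training time):

import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "meta-llama-ptuning",
    torch_dtype=torch.bfloat16,
)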

That said, even with these adjustments, as well as ensuring that the model is in eval mode, I could reproduce the error, i.e. there is a small discrepancy after loading (in my tests, the mean abs diff of the logits was ~0.02, depending on the model). However, this discrepancy is not caused by the mlp_head not being loaded. When you check these lines:

https://github.com/huggingface/peft/blob/679bcd8777fc8215208bc46b7f54f1f4061791ae/src/peft/utils/save_and_load.py#L144-L155

You can see that the prompt embeddings are saved as part of the state_dict. This is an optimization: for pure inference, the parameters are fixed, so the output of the prompt encoder never changes and there is no need to load the mlp_head.
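
To make this concrete, here is a minimal, self-contained sketch (not the actual save_and_load.py code) of how the prompt embeddings can be precomputed once so that the mlp_head is no longer needed at inference time:

import torch

num_virtual_tokens, token_dim, encoder_hidden_size = 20, 4096, 128
embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
mlp_head = torch.nn.Sequential(
    torch.nn.Linear(token_dim, encoder_hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(encoder_hidden_size, encoder_hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(encoder_hidden_size, token_dim),
)

with torch.no_grad():
    prompt_tokens = torch.arange(num_virtual_tokens)
    # this fixed tensor is what effectively ends up in the checkpoint; at
    # inference it can simply be prefixed to the inputs, so the mlp_head
    # itself never has to run (and hence is not restored)
    prompt_embeddings = mlp_head(embedding(prompt_tokens))

print(prompt_embeddings.shape)  # torch.Size([20, 4096])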

Still, I'm not sure where the difference comes from and will investigate further. Just wanted to share some insights I had so far.

PS: Interestingly, for facebook/opt-125m, I did not find any discrepancy. However, when I checked meta-llama/Meta-Llama-3-8B, bigscience/bloomz-560m, Qwen/Qwen2-1.5B, and microsoft/Phi-3.5-mini-instruct, they all had a small difference.

martin-wey commented 2 weeks ago

First of all, please don't load a model using get_peft_model

I may have pasted the wrong snippet. I'm having the same issue when first loading the base model and then using PeftModel.from_pretrained to load the model together with the adapter.

You can see that the prompt embedding is saved as part of the state_dict. This is an optimization, because for pure inference, since parameters are fixed, this output does not change anyway, so there is no need to load the mlp_head.

Got it, thanks!

Second, please ensure that the same dtype is used when loading the model. Initially, you use bfloat16

Even so, the PEFT-tuned model's responses to input prompts are very similar to those of the base model. Besides, I'm also having trouble with Phi-3.5-mini-128k-instruct and CodeQwen1.5-7B-Chat. However, LoRA-based methods (including QLoRA and DoRA) work just fine. The entire fine-tuning phase with prompt-based tuning seems normal, i.e., the validation loss is good and almost matches that of LoRA-tuned models.

Thanks for your quick reply. I will also keep trying with other LLMs.

martin-wey commented 2 weeks ago

@BenjaminBossan here are some more details about the forward pass of the PEFT and base models. I hope it helps :)

Using forward hooks, I checked whether the input goes through all the modules of the network.

# ...
model = AutoPeftModelForCausalLM.from_pretrained(
    "../runs/codellama/CodeLlama-7b-Instruct-hf_conala_p-tuning_3e-3/checkpoint-198/", 
    torch_dtype="bfloat16", 
    device_map="auto"
)

def forward_hook(module, input, output):
    print(f"Module: {module.__class__.__name__}")
    print(f"Input: {input}")
    print(f"Output: {output}")
    print("-" * 50)

hook_word_embeddings = model.word_embeddings.register_forward_hook(forward_hook)
hook_prompt_embeddings = model.prompt_encoder['default'].embedding.register_forward_hook(forward_hook)

for sample in test_set:
    tokenized_sample = tokenizer.apply_chat_template(
        sample["messages"],
        return_tensors="pt",
    ).to(model.device)

    output = model(tokenized_sample)

Output:

Module: Embedding
Input: (tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
           526,   263,  8444, 20255, 29889,    13, 29966,   829, 14816, 29903,
          6778,    13,    13,  4563,   680,   278,  1819,   411,  1021,  6611,
           310,  1023,  8600,   421, 29881, 29896, 29952,   322,   421, 29881,
         29906, 29952,   518, 29914, 25580, 29962]], device='cuda:0'),)
Output: tensor([[[ 0.0069,  0.0031, -0.0013,  ...,  0.0003, -0.0031, -0.0026],
         [ 0.0170, -0.0242,  0.0211,  ..., -0.0048, -0.0002,  0.0183],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         ...,
         [ 0.0142, -0.0182, -0.0065,  ..., -0.0210,  0.0099, -0.0132],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         [-0.0237, -0.0297,  0.0014,  ...,  0.0427, -0.0008, -0.0029]]],
       device='cuda:0', dtype=torch.bfloat16)

The input never goes through model.prompt_encoder.default.embedding. I am not sure that's the expected behaviour. Therefore, I tried the same thing with the base model by registering a hook for embed_tokens:

# ...
model = AutoModelForCausalLM.from_pretrained(
    "codellama/CodeLlama-7b-Instruct-hf", 
    torch_dtype="bfloat16", 
    device_map="auto"
)
hook_embeddings = model.model.embed_tokens.register_forward_hook(forward_hook)
# ...

Output (identical):

Module: Embedding
Input: (tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
           526,   263,  8444, 20255, 29889,    13, 29966,   829, 14816, 29903,
          6778,    13,    13,  4563,   680,   278,  1819,   411,  1021,  6611,
           310,  1023,  8600,   421, 29881, 29896, 29952,   322,   421, 29881,
         29906, 29952,   518, 29914, 25580, 29962]], device='cuda:0'),)
Output: tensor([[[ 0.0069,  0.0031, -0.0013,  ...,  0.0003, -0.0031, -0.0026],
         [ 0.0170, -0.0242,  0.0211,  ..., -0.0048, -0.0002,  0.0183],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         ...,
         [ 0.0142, -0.0182, -0.0065,  ..., -0.0210,  0.0099, -0.0132],
         [-0.0312, -0.0112, -0.0510,  ...,  0.0070,  0.0281, -0.0294],
         [-0.0237, -0.0297,  0.0014,  ...,  0.0427, -0.0008, -0.0029]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<EmbeddingBackward0>)

Basically, it seems like the prompt_encoder is bypassed during the forward pass. That's why I'm getting the same output when generating code with both the PEFT and the base model.

BenjaminBossan commented 2 weeks ago

Thanks for digging deeper into this.

The input never goes through model.prompt_encoder.default.embedding. I am not sure that's the expected behaviour.

Yes, this is expected; it relates to the optimization I mentioned above. The prompt embeddings to be prefixed are precomputed, so the embedding's forward is never called.

That's why I'm getting the same output when generating code using both PEFT and base models.

I cannot observe this. For me, the outputs of the loaded PEFT model are very close to the outputs of the original PEFT model, but different enough that the generations start to diverge at some point.

Investigating this further sent me down a rabbit hole, but I think I have figured out the issue. To cut it short: sending the same input repeated 10 times through the MLP does not produce exactly the same individual outputs as sending it through the MLP once. Let me illustrate:

import torch
torch.manual_seed(0);
device = 0
input_size = 128
hidden_size = 32
output_size = 64
layers = [
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, output_size),
]
mlp_head = torch.nn.Sequential(*layers)
mlp_head.to(device).eval();
x = torch.randn(20, input_size).to(device)

# x repeats 10x along the batch dimension
x = x.repeat(10, 1, 1)
# output with all 10 identical samples
out0 = mlp_head(x)
# output with only 1 of the samples
out1 = mlp_head(x[:1])

for i in range(10):
    # this should be 0 but it is 2.0311928139449265e-08
    print((out0[i:i+1] - out1).abs().mean().float().item())

When the involved sizes are small enough, the difference is actually 0, which may explain why opt-125m showed no problems but the bigger LLMs did.

So how does this translate to p-tuning? Let's check these lines, which are executed at training time:

https://github.com/huggingface/peft/blob/679bcd8777fc8215208bc46b7f54f1f4061791ae/src/peft/peft_model.py#L658-L663

https://github.com/huggingface/peft/blob/679bcd8777fc8215208bc46b7f54f1f4061791ae/src/peft/peft_model.py#L695

You see that we repeat the same input batch_size times and then send it through the prompt encoder.

Now, when we load the model, we go through a slightly different code path:

https://github.com/huggingface/peft/blob/679bcd8777fc8215208bc46b7f54f1f4061791ae/src/peft/peft_model.py#L693

Here, we just take the output of a single sample and repeat it batch_size times. In theory, that should be the same thing, but as I showed above, there are slight differences.
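
As a simplified illustration of the two code paths (a plain Linear stands in for the real PromptEncoder; the names are not the actual PeftModel attributes):

import torch

torch.manual_seed(0)
batch_size, num_virtual_tokens, dim = 10, 20, 4096
prompt_encoder = torch.nn.Linear(dim, dim).eval()  # stand-in for the real PromptEncoder
prompt_tokens_emb = torch.randn(1, num_virtual_tokens, dim)

with torch.no_grad():
    # training-time path: repeat the input batch_size times, then run the encoder
    prompts_train = prompt_encoder(prompt_tokens_emb.repeat(batch_size, 1, 1))
    # loading/inference path: run the encoder once, then repeat the output
    prompts_infer = prompt_encoder(prompt_tokens_emb).repeat(batch_size, 1, 1)

# mathematically identical, but numerically the batched matmul can differ slightly
print((prompts_train - prompts_infer).abs().max().item())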

To address this, I created the following patch:

@@ -692,7 +692,13 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
                 if peft_config.inference_mode:
                     prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
                 else:
-                    prompts = prompt_encoder(prompt_tokens)
+                    prompt_tokens = (
+                        self.prompt_tokens[self.active_adapter]
+                        .unsqueeze(0)
+                        .expand(1, -1)
+                        .to(prompt_encoder.embedding.weight.device)
+                    )
+                    prompts = prompt_encoder(prompt_tokens).repeat(batch_size, 1, 1)
             return prompts

     def get_nb_trainable_parameters(self) -> tuple[int, int]:

With this patch, the error vanishes for me. For completeness, here is the script I used to check it:

import torch
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel

inputs = torch.arange(10).view(-1, 1).to(0)

model_id = "meta-llama/Meta-Llama-3-8B"
#model_id = "bigscience/bloomz-560m"
#model_id = "Qwen/Qwen2-1.5B"
#model_id = "microsoft/Phi-3.5-mini-instruct"
#model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=0,
)
peft_config = PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
model = get_peft_model(model, peft_config)
model.eval();

torch.manual_seed(0)
with torch.inference_mode():
    output_peft = model(inputs).logits
    gen_peft = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)

model.save_pretrained("/tmp/peft/2043")
del model
torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained("/tmp/peft/2043", device_map=0, torch_dtype=torch.bfloat16)
# using `model = AutoModelForCausalLM.from_pretrained(...); model = PeftModel.from_pretrained(...)` also works

torch.manual_seed(0)
with torch.inference_mode():
    output_loaded = model(inputs).logits
    gen_loaded = model.generate(inputs, min_new_tokens=10, max_new_tokens=10)

torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)

If you can confirm that this patch solves your original issue, I will create a PR to fix this in PEFT.

martin-wey commented 2 weeks ago

Thanks a lot for the detailed explanation.

On my end, I am still seeing both models generate exactly the same content for a given prompt, even with this fix. It makes no sense to me, as the training phase using p-tuning or prompt tuning works properly, with the validation loss decreasing.

I use the following dataset: https://huggingface.co/datasets/neulab/docprompting-conala. The model learns to generate 1-2 lines of code for a given instruction.

To rule out potential issues when reloading the model from a checkpoint, I implemented a TrainerCallback that generates code for a given test sample after each epoch.

class GenerateAfterEpochCallback(TrainerCallback):
    def __init__(self, test_example, tokenizer):
        self.test_example = test_example
        self.tokenizer = tokenizer

    def on_epoch_end(self, args, state, control, **kwargs):
        model = kwargs['model']
        tokenizer = self.tokenizer

        model.eval()

        inputs = tokenizer.apply_chat_template(
            self.test_example["messages"][:-1] # remove assistant's solution, 
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=128,
            )

        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print(f"\nGenerated text after epoch {state.epoch}:\n{generated_text}\n")

        model.train()

        return control
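
A possible way to wire this callback into the Trainer (variable names such as test_set, training_args, and the tokenized datasets are assumed from the surrounding setup):

callback = GenerateAfterEpochCallback(test_example=test_set[0], tokenizer=tokenizer)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    callbacks=[callback],
)
trainer.train()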

Regardless of the PEFT hyperparameter configuration, the model generates the same content all the time, which is also identical to the base model's output. I tried a few things that all result in identical outputs (see below):

Output (screenshot: base_model): the output makes sense for the prompt, but it is not what is expected from a fine-tuned model. It's as if p-tuning and prompt tuning have zero impact.

Interestingly, when generating without attention_mask, the generated content and the input prompt change completely (screenshot: without_attn_mask).
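
A sketch of the comparison described above (inputs as produced by tokenizer.apply_chat_template(..., return_dict=True, return_tensors="pt") in the callback; the only difference is whether attention_mask is passed):

with torch.no_grad():
    with_mask = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=128,
    )
    without_mask = model.generate(input_ids=inputs["input_ids"], max_length=128)

print(tokenizer.decode(with_mask[0], skip_special_tokens=True))
print(tokenizer.decode(without_mask[0], skip_special_tokens=True))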

This is weird. I used prompt-based tuning with PEFT in May/June 2023 for another project (with a similar setup), and I never had such an issue.

BenjaminBossan commented 1 week ago

Thanks for the additional information. I was afraid that something more could be going on, since you wrote that the results are very different.

For me to further assist in this, it would be very helpful if you could share a bit more information. Ideally, you could upload the adapter to HF so that I can try it myself. Please also provide the code that you use to check if the outputs are as expected or not. If you cannot share the checkpoint, would it be possible to share the training code instead?

Another thing that could be going on is that there have been some recent changes to transformers that interfere with prompt tuning methods in PEFT. Therefore, it would be helpful if you could test whether your issue resolves when using an older transformers version. Maybe you can figure out which one you used back then? If you downgrade transformers, you may also have to downgrade PEFT to a version from that time. If you can determine that your checkpoint works with version X but not Y, this would greatly increase the chances of figuring out what's going wrong.

PS: Just checked, something like transformers v4.29 and PEFT v0.3.0 would correspond to the time you mentioned.
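
When re-running the comparison, a quick sanity check of the versions actually installed in each environment can help:

import peft
import transformers

print("transformers:", transformers.__version__)
print("peft:", peft.__version__)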

martin-wey commented 1 week ago

@BenjaminBossan thanks again. Please ignore my last reply if you have seen it. I am still working on it to make sure I provide you with 100% accurate information. I'll reply ASAP.

BenjaminBossan commented 4 days ago

@martin-wey Do you have any updates?

martin-wey commented 2 days ago

@BenjaminBossan Yes, as you suggested, I compared fine-tuning using prompt tuning with (1) transformers v4.29 / PEFT v0.3.0 and (2) transformers v4.44.2 / PEFT v0.12.0. I used deepseek-coder-6.7b-instruct.

Summary of my findings:

Here's the code, compatible with both versions:

import argparse

import torch

from datasets import load_dataset
from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorForLanguageModeling,
    TrainerCallback,
    set_seed
)

from collator import CustomDataCollatorForCompletionOnlyLM

def load_model_and_tokenizer(args):
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    peft_config = PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        prompt_tuning_init=PromptTuningInit.RANDOM,
        num_virtual_tokens=20, 
        tokenizer_name_or_path=args.model_name_or_path,
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    if getattr(tokenizer, "pad_token_id") is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    return model, tokenizer

def main(args):
    dataset = load_dataset("neulab/docprompting-conala")
    model, tokenizer = load_model_and_tokenizer(args)

    def tokenize(example):
        prompt = f"{tokenizer.bos_token}\n"
        prompt += f"### Instruction:\n{example['nl']}\n"
        prompt += f"### Response:\n{example['cmd']}\n<|EOT|>"

        model_inputs = tokenizer(prompt, truncation=True, max_length=128, padding="max_length")

        return model_inputs

    tokenized_dataset = dataset.map(
        tokenize,
        batched=False,
        remove_columns=[cn for cn in dataset["train"].column_names if cn not in ["input_ids", "attention_mask"]],
    )

    training_args = TrainingArguments(
        output_dir=args.run_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        logging_strategy="steps",
        bf16=True,
        logging_steps=1,
        save_total_limit=10,
        load_best_model_at_end=True,
        report_to="wandb" if args.use_wandb else "none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        data_collator=CustomDataCollatorForCompletionOnlyLM("### Response", tokenizer=tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]
    )

    trainer.train()
    trainer.model.save_pretrained(f"{args.run_dir}/best_model_checkpoint")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default="deepseek-ai/deepseek-coder-6.7b-instruct", type=str)
    parser.add_argument("--output_dir", default=".", type=str)

    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--patience", type=int, default=2)

    parser.add_argument("--learning_rate", type=float,  default=3e-3)
    parser.add_argument("--lr_scheduler_type", type=str, default="linear")

    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    set_seed(args.seed)

    args.model_name = args.model_name_or_path.split('/')[-1]
    args.run_dir = f"{args.output_dir}/{args.model_name}_conala_prompt-tuning_new/"
    main(args)

Data collator (an extension of the trl completion-only data collator that keeps the EOS token in the labels):

from typing import Any, Dict, List, Union

from trl import DataCollatorForCompletionOnlyLM


class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # ensure the last token is taken into account for loss computation,
        # otherwise the model may never stop generating at inference
        batch["labels"][:, -1] = batch["input_ids"][:, -1]

        return batch
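
A minimal illustration, independent of trl, of what the override above does: completion-only collators typically mask everything outside the response with -100, so restoring the final label keeps the EOS-like token in the loss.

import torch

input_ids = torch.tensor([[1, 2, 3, 4, 5]])
labels = torch.tensor([[-100, -100, 4, 5, -100]])  # last token masked by the base collator

labels[:, -1] = input_ids[:, -1]  # the fix applied in torch_call above
print(labels)  # tensor([[-100, -100,    4,    5,    5]])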

Learning curves are drastically different: (image attached)

Eval curves: (image attached)