huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Model encapsulation #1902

Closed tomekrut closed 1 month ago

tomekrut commented 3 months ago

System Info

Hi guys, I have some complex models where I only use parts of the transformers sub-models. In the example below I use AutoModelForCausalLM.from_pretrained(), but normally it would be something like LlamaModel.from_pretrained().

import torch
import torch.nn as nn
from typing import List, Optional
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer
from datasets import load_dataset

class ModelWrap(nn.Module):
    """Thin wrapper that simply forwards all arguments to the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None):
        return self.model(
                input_ids = input_ids,
                attention_mask = attention_mask,
                position_ids = position_ids,
                past_key_values = past_key_values,
                inputs_embeds = inputs_embeds,
                labels = labels,
                use_cache = use_cache,
                output_attentions = output_attentions,
                output_hidden_states = output_hidden_states,
                return_dict = return_dict,
                cache_position = cache_position)

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
qcfg = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=qcfg, device_map="balanced")

# Comment the next line in or out to switch between the two scenarios
# model = ModelWrap(model)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = 0
dataset = load_dataset("imdb", split="train[:1%]")

lora_config = LoraConfig(
    r=16, # 16
    lora_alpha=32, # 32
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    inference_mode=False,
    bias="none")

model = prepare_model_for_kbit_training(model)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

trainer = SFTTrainer(
    model=peft_model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    tokenizer=tokenizer,
)

trainer.train()
peft_model.save_pretrained("some_test")

Inside ModelWrap() there is plenty of other stuff, but I wanted to simplify it here. Whether I use the SFT trainer or my own training loop, GPU memory utilization explodes and I always end up with a GPU OOM. Without the wrapper, the 8B model (8-bit) consumes 32 GB. Once I wrap the model, 80 GB is not enough. I have an A100 80GB.

export CUDA_VISIBLE_DEVICES=0
python above_script.py

Can you please comment on this? What am I doing wrong?

Who can help?

No response

Information

Tasks

Reproduction

Script added

Expected behavior

It should still work within the 32 GB memory threshold.

BenjaminBossan commented 3 months ago

I can reproduce the error. My suspicion is that it's somehow related to the Trainer class. When I wrote a vanilla PyTorch training loop, I saw the same memory consumption with and without wrapping.
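For reference, the manual loop was along these lines. This is a minimal sketch reusing the variables from the script above; the batch size, sequence length, learning rate, and step count are arbitrary placeholders rather than the exact values I used:

import torch
from torch.utils.data import DataLoader

# Tokenize the same IMDB subset used in the reproduction script
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)

tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
tokenized.set_format("torch")
loader = DataLoader(tokenized, batch_size=1)

# After get_peft_model, only the LoRA parameters require gradients
optimizer = torch.optim.AdamW(
    (p for p in peft_model.parameters() if p.requires_grad), lr=2e-4)

peft_model.train()
for step, batch in enumerate(loader):
    batch = {k: v.to("cuda") for k, v in batch.items()}
    # Using the input_ids as labels gives the standard causal LM loss
    out = peft_model(**batch, labels=batch["input_ids"])
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step >= 10:  # a few steps are enough to observe peak memory
        break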

I also tried passing the non-PEFT model directly to SFTTrainer together with peft_config=lora_config, since SFTTrainer knows how to deal with PEFT, but this made little difference.
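Roughly this variant, with the remaining arguments as in the reproduction script (sketch):

# Let SFTTrainer apply the LoRA adapter itself instead of calling get_peft_model beforehand
model = prepare_model_for_kbit_training(model)

trainer = SFTTrainer(
    model=model,              # base (optionally wrapped) model, not a PeftModel
    peft_config=lora_config,  # SFTTrainer wraps the model with PEFT internally
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    tokenizer=tokenizer,
)
trainer.train()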

Another thing I tried is to first create the PEFT model, and then wrap it with ModelWrap. This reduces memory consumption.
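I.e. roughly this order instead of wrapping first (sketch, reusing the names from the script above):

# Apply LoRA to the unwrapped transformers model first, then wrap the resulting PeftModel
model = prepare_model_for_kbit_training(model)
peft_model = get_peft_model(model, lora_config)
wrapped_model = ModelWrap(peft_model)  # wrapping happens after get_peft_model, not before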

As I'm not an expert when it comes to (SFT)Trainer, it's hard for me to tell what goes wrong. There are a lot of isinstance checks within that code, so I wonder whether there is a connection there.
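To illustrate the kind of check I mean (purely illustrative, not a specific line from Trainer, reusing the names from the script above): once the model is wrapped, it no longer looks like a PreTrainedModel to such checks, so code paths keyed on them behave differently.

from transformers import PreTrainedModel

# `model` here is the unwrapped AutoModelForCausalLM from the reproduction script
print(isinstance(model, PreTrainedModel))             # True
print(isinstance(ModelWrap(model), PreTrainedModel))  # False: the wrapper hides the model type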

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.