mobiusml / hqq

Official implementation of Half-Quadratic Quantization (HQQ)
https://mobiusml.github.io/hqq_blog/
Apache License 2.0

Question about fine-tuning a 1-bit quantized model #115

Closed: zxbjushuai closed this issue 2 months ago

zxbjushuai commented 2 months ago
Hi, following up on https://github.com/mobiusml/hqq/issues/107, I want to reproduce your HQQ+ results, but I ran into some problems during training:

1. The loss doesn't decrease; it just fluctuates up and down.
2. The text generated by the model is very poor.

I think the reason is that the parameters are not set well enough. Here is my code; could you give me some suggestions?

Quantization code:

quant_config = HqqConfig(nbits=1, group_size=64 ,quant_zero=False, quant_scale=False) 
max_memory={0: "16GiB", 1: "16GiB", 2: "16GiB", 3: "16GiB"}
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
model = AutoModelForCausalLM.from_pretrained(
    model_id,#llama3-8b
    device_map="auto",
    torch_dtype=compute_dtype,#bfloat16
    quantization_config=quant_config,
    max_memory = max_memory,
)

model = prepare_model_for_kbit_training(model)

Adapter code:

config = LoraConfig(
    r=32,#I set this value to 32 because it works better than 16.
    lora_alpha=32,
    target_modules=["q_proj","k_proj","v_proj", "o_proj", "gate_proj","up_proj","down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=False,
)
model = get_peft_model(model, config)

Training code:

trainer = Trainer(
    model=model,
    train_dataset=data,#wikitext-2-raw-v1(full),num_rows =36718
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=1e-5,
        logging_steps=20,
        output_dir=adapter_dir,
        bf16 = True,
        save_steps=100,
        save_total_limit=5,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.is_parallelizable       = False
trainer.is_model_parallel     = False
trainer.place_model_on_device = False
model.config.use_cache = False
trainer.train() 
model.save_pretrained(adapter_dir)

The model performs much better than before fine-tuning, but still not well enough. Is it because I didn't train long enough?

Originally posted by @zxbjushuai in https://github.com/mobiusml/hqq/issues/107#issuecomment-2330509939

mobicham commented 2 months ago

Hm, I used my own custom code for training, so I don't know what the differences vs. peft are. Here are the parameters I used (as far as I remember):

- compute_dtype = bfloat16
- group_size = 8
- LoRA size 32 for ["q_proj","k_proj","v_proj", "o_proj"], LoRA size 8 for the rest
- n_epochs: 1
- the learning rate should probably be higher, maybe ~2e-4, and linearly decreased
- axis=0 was actually used in the HQQ config
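As a rough illustration of the recipe above, the plain-Python sketch below maps each projection name to the LoRA rank listed (32 for attention, 8 for the MLP) and computes a linearly decaying learning rate that starts at 2e-4. The helper names `lora_rank_for` and `linear_lr` are hypothetical, and the zero-warmup, decay-to-zero schedule is an assumption about what "linearly decreased" means here.

```python
# Hypothetical helpers illustrating the recipe above; not part of hqq or peft.
ATTENTION_PROJS = {"q_proj", "k_proj", "v_proj", "o_proj"}
MLP_PROJS = {"gate_proj", "up_proj", "down_proj"}


def lora_rank_for(module_name: str) -> int:
    """LoRA rank for a module name such as 'model.layers.0.self_attn.q_proj'."""
    leaf = module_name.rsplit(".", 1)[-1]
    if leaf in ATTENTION_PROJS:
        return 32  # attention projections
    if leaf in MLP_PROJS:
        return 8  # MLP projections
    raise ValueError(f"no LoRA adapter planned for {module_name!r}")


def linear_lr(step: int, total_steps: int, base_lr: float = 2e-4) -> float:
    """Linear decay from base_lr at step 0 down to 0 at total_steps (no warmup)."""
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / total_steps


if __name__ == "__main__":
    print(lora_rank_for("model.layers.0.self_attn.q_proj"))  # 32
    print(lora_rank_for("model.layers.0.mlp.down_proj"))  # 8
    for step in (0, 250, 500, 1000):
        print(step, linear_lr(step, total_steps=1000))
```

With the Hugging Face Trainer, `learning_rate=2e-4` together with the default `lr_scheduler_type="linear"` and no warmup follows the same shape.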
zxbjushuai commented 2 months ago

Thank you! I will try it. How do you evaluate the model after fine-tuning?

mobicham commented 2 months ago

https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py#L12

By the way, make sure that the whole model is frozen and only the LoRA layers are trainable
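One way to check this is to walk over `model.named_parameters()` and list anything that still requires gradients but is not a LoRA weight. The sketch below assumes LoRA parameters have "lora" in their names (as with `lora_A`/`lora_B`); if biases are trained too, as `'train_bias': True` in the hqq script later in this thread suggests, pass `allowed=("lora", "bias")`. The helper names and the `FakeParam` stand-in are hypothetical; the helpers only use `.requires_grad` and `.numel()`, so real PyTorch parameters work the same way.

```python
from dataclasses import dataclass
from typing import Any, Iterable, List, Tuple


def non_lora_trainable(named_parameters: Iterable[Tuple[str, Any]],
                       allowed: Tuple[str, ...] = ("lora",)) -> List[str]:
    """Names of parameters that require grad but contain none of the allowed substrings."""
    return [name for name, p in named_parameters
            if p.requires_grad and not any(key in name for key in allowed)]


def trainable_fraction(named_parameters: Iterable[Tuple[str, Any]]) -> float:
    """Share of parameter elements (numel) that require grad."""
    total = trainable = 0
    for _, p in named_parameters:
        total += p.numel()
        if p.requires_grad:
            trainable += p.numel()
    return trainable / total if total else 0.0


@dataclass
class FakeParam:
    """Stand-in for torch.nn.Parameter, exposing only what the helpers use."""
    size: int
    requires_grad: bool

    def numel(self) -> int:
        return self.size


if __name__ == "__main__":
    params = [
        ("layers.0.self_attn.q_proj.lora_A", FakeParam(4096 * 32, True)),
        ("layers.0.self_attn.q_proj.lora_B", FakeParam(32 * 4096, True)),
        ("layers.0.input_layernorm.weight", FakeParam(4096, True)),  # forgot to freeze
    ]
    print(non_lora_trainable(params))  # ['layers.0.input_layernorm.weight']
    print(trainable_fraction(params))
```

With a real model, `leftover = non_lora_trainable(model.named_parameters())` followed by `assert not leftover, leftover` fails loudly on anything left unfrozen. Since `named_parameters()` returns a generator, call it again for `trainable_fraction`.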

zxbjushuai commented 2 months ago

I don't know how to set 'lora-size 32 for ["q_proj","k_proj","v_proj", "o_proj"], lora size 8 for the rest', so I set them all to 32. Another question: why set the group size to 8? Does it influence the result?

mobicham commented 2 months ago

> Another question: why set the group size to 8? Does it influence the result?

Yes, a lot actually. For 1-bit you need a lower group-size, since you lose a lot of accuracy at higher group-sizes.

On a separate note, why do you want to do 1-bit? It's totally useless for modern GPUs, you'd be better off using 2-bit which is also supported for fast inference via the BitBlas backend
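To see why 1-bit needs a lower group-size, here is a stdlib-only sketch using plain min-max rounding. This is a simpler stand-in for HQQ's optimized zero-points, not HQQ itself. With 1 bit, every weight in a group is snapped to either the group minimum or the group maximum, so smaller groups have endpoints closer to their values. The function name and the Gaussian test weights are illustrative assumptions.

```python
import random


def minmax_1bit_error(weights, group_size):
    """Mean absolute error of 1-bit min-max quantization applied per group.

    Each group gets q in {0, 1} with scale = max - min and zero = min, and is
    dequantized as q * scale + min. This is plain round-to-nearest, not HQQ.
    """
    total = 0.0
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = hi - lo
        for w in group:
            q = 0 if scale == 0 else round((w - lo) / scale)  # q is 0 or 1
            total += abs(w - (q * scale + lo))
    return total / len(weights)


random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(4096)]
for g in (8, 64):
    print(f"group_size={g:>2}: mean abs error {minmax_1bit_error(weights, g):.4f}")
```

Because each group of 8 sits inside a group of 64, its min/max endpoints can only move closer to the values, so the per-weight error can only shrink.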

zxbjushuai commented 2 months ago

> Yes, a lot actually. For 1-bit you need a lower group-size, since you lose a lot of accuracy at higher group-sizes.

I see. I also find that the model with group-size 8 is 3.2GB and the one with group-size 64 is 6GB.

> On a separate note, why do you want to do 1-bit? It's totally useless for modern GPUs, you'd be better off using 2-bit which is also supported for fast inference via the BitBlas backend.

1-bit models are popular, and I want to do some research on low-bit models. Anyway, thanks for your advice🤗

mobicham commented 2 months ago

> I see. I also find that the model with group-size 8 is 3.2GB and the one with group-size 64 is 6GB.

That can't be right. Group-size 8 should be larger than group-size 64 in terms of GB, since the meta-data will be larger.
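The meta-data effect can be made concrete with a back-of-the-envelope estimate. Assume one 16-bit scale and one 16-bit zero-point per group (quant_scale=False, quant_zero=False) and ignore the layers that stay unquantized (embeddings, lm_head, norms). Each weight then costs nbits + 32 / group_size bits. This simplified storage model and the Llama-3-8B layer shapes used for `N_LINEAR` are assumptions, not hqq's exact on-disk layout.

```python
def bits_per_weight(nbits: int, group_size: int, meta_bits: int = 32) -> float:
    """Cost per weight: the code itself plus the group's scale+zero spread over the group.

    meta_bits=32 assumes one 16-bit scale and one 16-bit zero-point per group.
    """
    return nbits + meta_bits / group_size


def quantized_gb(n_weights: int, nbits: int, group_size: int, meta_bits: int = 32) -> float:
    """Approximate size in GB (1e9 bytes) of n_weights quantized weights plus metadata."""
    return n_weights * bits_per_weight(nbits, group_size, meta_bits) / 8 / 1e9


# Assumed count of weights in Llama-3-8B's quantized linear layers:
# 32 layers of q/k/v/o + gate/up/down projections (embeddings and lm_head excluded).
N_LINEAR = 32 * (2 * 4096 * 4096 + 2 * 4096 * 1024 + 3 * 4096 * 14336)

for g in (8, 64):
    print(f"group_size={g}: {bits_per_weight(1, g)} bits/weight, "
          f"~{quantized_gb(N_LINEAR, 1, g):.2f} GB for the linear layers")
```

Under these assumptions, group-size 8 costs 5 bits per weight (about 4.4 GB for the linear layers) versus 1.5 bits (about 1.3 GB) for group-size 64. The unquantized layers add the same fixed amount to both. Storing the scales/zeros in 8 bits (`meta_bits=16`) would bring group-size 8 down to 3 bits per weight.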

> 1-bit models are popular, and I want to do some research on low-bit models. Anyway, thanks for your advice🤗

I think you are referring to 1.58-bit, not 1-bit. 1.58-bit is basically 2-bit in terms of the inference kernel. 1-bit with HQQ is binary [0, 1], not ternary [-1, 0, 1]. So 1-bit with HQQ will be much worse, and on GPUs 2-bit will be better than ternary at the same inference speed as 1.58-bit.
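A quick illustration of the binary vs. ternary difference, assuming HQQ's dequantization form (W_q - z) * s. With 1-bit codes, a group can only take two values. A ternary weight carries log2(3) ≈ 1.585 bits of information, so it has to be packed like 2-bit weights, and a 2-bit grid with zero=1 and scale=1 already contains the ternary set {-1, 0, 1}. The helper `dequant_levels` is hypothetical, not hqq code.

```python
import math


def dequant_levels(nbits: int, scale: float, zero: float) -> list:
    """Every value a group can take after dequantization as (q - zero) * scale."""
    return [(q - zero) * scale for q in range(2 ** nbits)]


print(dequant_levels(1, scale=0.5, zero=0.5))  # [-0.25, 0.25]: binary, two levels
print(dequant_levels(2, scale=1.0, zero=1.0))  # [-1.0, 0.0, 1.0, 2.0]: 2-bit grid covers ternary
print(math.log2(3))  # ~1.585 bits of information per ternary weight
```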

zxbjushuai commented 2 months ago

> That can't be right. Group-size 8 should be larger than group-size 64 in terms of GB, since the meta-data will be larger.

I made a mistake. In fact, the model with group-size 8 is 6GB and the one with group-size 64 is 3.2GB.

> I think you are referring to 1.58-bit, not 1-bit. 1.58-bit is basically 2-bit in terms of the inference kernel. 1-bit with HQQ is binary [0, 1], not ternary [-1, 0, 1]. So 1-bit with HQQ will be much worse, and on GPUs 2-bit will be better than ternary at the same inference speed as 1.58-bit.

I chose HQQ first because it achieves true 1-bit, [0, 1]. I know about 1.58-bit and appreciate that work, but unfortunately their paper doesn't link to any code. So I am first reproducing results that are easier to obtain, such as llama.cpp at 2-bit. 1.58-bit is also in my plan, and I will compare its results with the others at the end.

mobicham commented 2 months ago

Note that 1.58-bit trains the whole model, while HQQ+ only fine-tunes a fraction of the weights via LoRA, so it's not a direct comparison, but I understand your motivation 👍!
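To put "a fraction of the weights" in numbers: a LoRA adapter adds r × (in + out) weights to each adapted linear layer. The sketch below counts them assuming the public Llama-3-8B shapes (hidden size 4096, MLP size 14336, 1024-dimensional k/v projections from grouped-query attention, 32 layers) and the ranks from the recipe earlier in the thread (32 for attention, 8 for the MLP). The shapes come from the model config, not from this thread.

```python
# Assumed Llama-3-8B projection shapes: (in_features, out_features)
SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
RANKS = {"q_proj": 32, "k_proj": 32, "v_proj": 32, "o_proj": 32,
         "gate_proj": 8, "up_proj": 8, "down_proj": 8}
N_LAYERS = 32


def lora_params(shapes, ranks, n_layers):
    """LoRA adds A (r x in) and B (out x r) per adapted layer: r * (in + out) weights."""
    per_layer = sum(ranks[name] * (fan_in + fan_out)
                    for name, (fan_in, fan_out) in shapes.items())
    return per_layer * n_layers


def base_linear_params(shapes, n_layers):
    """Weights in the quantized linear layers (embeddings and lm_head excluded)."""
    return n_layers * sum(fan_in * fan_out for fan_in, fan_out in shapes.values())


extra = lora_params(SHAPES, RANKS, N_LAYERS)
base = base_linear_params(SHAPES, N_LAYERS)
print(f"LoRA parameters: {extra:,} ({100 * extra / base:.2f}% of the quantized linear weights)")
```

Under these assumptions, HQQ+ trains about 41.4M weights, roughly 0.6% of the ~7B quantized linear weights, whereas full training updates all of them.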

zxbjushuai commented 2 months ago

Yes, another repository (https://github.com/xuyuzhuang11/OneBit) also trains the whole model using DeepSpeed and is binary [0, 1]. But I only have 4 x 16GB GPUs, so I hit OOM😭 when training llama3-8b. That is another reason why I use HQQ+.

zxbjushuai commented 2 months ago

Sorry to bother you again. I did not use all the datasets mentioned in the blog for fine-tuning, but I still have some questions.

1. What is the difference between AutoHQQHFModel and HQQModelForCausalLM if I want to load the example model? I loaded the example model successfully with HQQModelForCausalLM but encountered an error:

device = torch.device("cuda:0")
model_id = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq' 
model     = HQQModelForCausalLM.from_quantized(model_id, adapter='adapter_v0.1.lora').to(device)
input_text = "Once upon a time"
input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)#cuda:0
model.eval()
model.generate(input_ids, max_length=200,do_sample=True)

Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

and when loading with AutoHQQHFModel, I get a bad result:

model     = AutoHQQHFModel.from_quantized(model_id, adapter='adapter_v0.1.lora',device= device)

Model output: в Once upon a time- isalt... mejorntilombvaicus,ichistrelysing spring eprintist stalig leche fl mh Chartavel whateverms Out elliht point Dcman uniquem in onewh of fufme fixale Searchandmpyes @end RavoomhHecпоre off Ranptaebahcapa ...

2. I mentioned that the llama3-8b model with config HqqConfig(nbits=1, group_size=8, axis=0, quant_zero=False, quant_scale=False) occupies 6GB of memory, but a 4-bit quantized model occupies approximately 5GB using AWQ or llama.cpp. How can I get a low-bit model that has high accuracy and takes up little memory?

Thank you in advance.🫡

zxbjushuai commented 2 months ago

> I loaded the example model successfully with HQQModelForCausalLM but encountered an error:

The error happens when generating text.

zxbjushuai commented 2 months ago

@mobicham

mobicham commented 2 months ago

Hi! If you are using the newer version, please use AutoHQQHFModel to both save and load, instead of HQQModelForCausalLM.

Regarding the 2 other questions, I think I know what the problem is: in the latest versions of hqq (>v0.1.8) we dropped support for quantized scales/zeros and CPU offloading. That specific model you are trying to load needs quantized scales/zeros since the group-size is very small. Additionally, you need CPU offloading for the scales/zeros. While the model on disk looks like 3GB, when you actually run it, it only uses 1.76GB, so I would recommend checking the VRAM usage, not the model size. So if you want to run it, please use

Below is the output I just got when running it:

In [2]: outputs = chat_processor("What is the solution to x^2 - 1 = 0", max_new_tokens=1000, do_sample=False)
User:  What is the solution to x^2 - 1 = 0
Assistant:

The equation $x^2 - 1 = 0$ can be factored as $(x-1)(x+1) = 0$.
You want to find a value of $x$ that makes this true for all values of $x$. This means that either $x=1$ or $-1$, or $x=-1$. So, there are two solutions: $x=\boxed{1}$ and $x=\boxed{-1}$. The answer is: 1
zxbjushuai commented 2 months ago

I see! The version of hqq in my environment is 0.2.0. What are quantized scales/zeros used for?

> That specific model you are trying to load needs quantized scales/zeros since the group-size is very small.

The group-size of the model I am fine-tuning is also very small. Do I need to roll back the hqq version to enable quantized scales/zeros? Also, after fine-tuning on the wikitext-2-raw-v1 dataset, the perplexity of my model is quite high (thousands). I don't know if that's normal. Thanks for your great work and help. Looking forward to your reply.🫡

mobicham commented 2 months ago

Yeah, if you want to use quantized zeros/scales, you need hqq==0.1.8 at most, since later versions can't use quantized scales/zeros. We dropped support for it in the newer versions because it made model serialization too complicated, and our priority is to make the models fully compatible with transformers serialization. But you can train without quantized scales/zeros and then quantize them later when you save the model; I think this is how I did it, as far as I remember.
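The "quantize them later" step can be pictured with a stdlib-only sketch. The per-group scales (or zeros) are themselves just a list of numbers, so they can be stored as 8-bit codes plus one float minimum and step. This uses plain min-max rounding for illustration only; the quant_scale/quant_zero path in hqq 0.1.x uses its own settings, which are not reproduced here. The helper names and the stand-in scale values are hypothetical.

```python
def quantize_8bit(values):
    """Min-max quantize a list of floats to codes in 0..255; returns (codes, lo, step)."""
    lo, hi = min(values), max(values)
    step = (hi - lo) / 255 if hi > lo else 1.0
    codes = [round((v - lo) / step) for v in values]
    return codes, lo, step


def dequantize_8bit(codes, lo, step):
    """Map 8-bit codes back to floats."""
    return [c * step + lo for c in codes]


scales = [0.012 + 0.0003 * i for i in range(64)]  # stand-in per-group scales
codes, lo, step = quantize_8bit(scales)
restored = dequantize_8bit(codes, lo, step)
max_err = max(abs(a - b) for a, b in zip(scales, restored))
print(f"max abs error {max_err:.2e} (bounded by step/2 = {step / 2:.2e})")
```

Halving the metadata width roughly halves its per-weight overhead. For 1-bit weights with group size 8, the 2 × 16 / 8 = 4 extra bits per weight become about 2 × 8 / 8 = 2.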

Can you share your training code? I can run it locally and try to see what's going on.

zxbjushuai commented 2 months ago

Sure. Here is my code. I load the model and dataset from my disk, so you will need to modify some of the code. I hope this doesn't bother you too much. The model is meta-llama/Meta-Llama-3-8B, with transformers == 4.44.0 and hqq == 0.2.0.

import torch
from alpaca_lora.utils.prompter import Prompter
from datasets import load_from_disk,load_dataset
from hqq.engine.hf import HQQModelForCausalLM
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.peft import PeftUtils
from hqq.core.quantize import *
from datasets import load_dataset
from transformers import HqqConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import PeftModel,PeftModelForCausalLM
from peft.utils import SAFETENSORS_WEIGHTS_NAME
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
model_id="../models/llama3"
adapter_dir = "llama3_hqq_peft/adapters"
model_save_dir = "llama3_hqq_peft/model"
compute_dtype = torch.bfloat16

quant_config = HqqConfig(nbits=1, group_size=8 ,axis=0,quant_zero=False, quant_scale=False) 
max_memory={0: "16GiB", 1: "16GiB", 2: "16GiB", 3: "16GiB"}
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",#device
    torch_dtype=compute_dtype,
    quantization_config=quant_config,
    max_memory = max_memory,
)

for param in model.parameters():
    param.requires_grad = False

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token     = tokenizer.eos_token 
tokenizer.padding_side  = "right" 

use_dora = False
config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj","k_proj","v_proj", "o_proj", "gate_proj","up_proj","down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=use_dora,
)
model = prepare_model_for_kbit_training(model) #first train
model = get_peft_model(model, config) #first train

data = load_from_disk("../dataset/wikitext-2-raw-v1")['train']
data = data.shuffle(seed=42)
data = data.map(lambda samples: tokenizer(samples["text"], padding=True), batched=True)
trainer = Trainer(
    model=model,
    train_dataset=data,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-4,
        logging_steps=20,#1
        output_dir=adapter_dir,
        bf16 = True,
        save_steps=50,
        save_total_limit=5,
        resume_from_checkpoint=True, 
        neftune_noise_alpha=0.1,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.is_parallelizable       = False
trainer.is_model_parallel     = False
trainer.place_model_on_device = False

model.config.use_cache = False
trainer.train() 
model.save_pretrained(adapter_dir)
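For comparison with the per-row `padding=True` tokenization above: wikitext-2 has many empty or very short rows, so padding each row separately spends compute on pad tokens. One common alternative, which is conceptually what `packing=True` in trl's SFTTrainer does, is to concatenate the token ids of all rows and cut them into fixed-length blocks. The stdlib sketch below is hypothetical: it uses plain integer ids instead of a real tokenizer, with `eos_id` marking document boundaries.

```python
from typing import List


def pack(token_rows: List[List[int]], block_size: int, eos_id: int) -> List[List[int]]:
    """Concatenate rows (each followed by eos_id) and split the stream into full blocks.

    Empty rows are skipped, and a trailing partial block is dropped.
    """
    stream: List[int] = []
    for row in token_rows:
        if row:
            stream.extend(row)
            stream.append(eos_id)
    n_blocks = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]


rows = [[5, 6, 7], [], [8, 9], [10, 11, 12, 13]]
print(pack(rows, block_size=4, eos_id=2))  # [[5, 6, 7, 2], [8, 9, 2, 10], [11, 12, 13, 2]]
```

Every position in a packed block is a real token, so the labels can simply be the input ids, with no pad tokens to mask.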
mobicham commented 2 months ago

Something like this, feel free to play with the parameters

#Settings
#pip install hqq==0.1.8
#pip install trl==
#pip install transformers==4.40.0

#OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 ipython3 

######################################################################################
import torch
cache_path    = '' 
model_id      = "meta-llama/Llama-2-7b-hf" 
compute_dtype = torch.bfloat16
device        = 'cuda:0'

#HQQ Quantize
######################################################################################
from transformers import AutoModelForCausalLM, AutoTokenizer
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

model     = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_path) 

#Quantize the model
from hqq.core.quantize import *
quant_config = BaseQuantizeConfig(nbits=2, group_size=8, quant_scale=False, quant_zero=False, axis=0)
AutoHQQHFModel.setup_model(model)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)

#Add Peft
######################################################################################
from hqq.core.peft import PeftUtils

train_dtype       = torch.float32
atten_lora_params = {'lora_type':'default', 'r':32, 'lora_alpha':32, 'dropout':0.05, 'train_dtype':train_dtype, 'train_bias':True}
mlp_lora_params   = {'lora_type':'default', 'r':8,  'lora_alpha':8,  'dropout':0.05, 'train_dtype':train_dtype, 'train_bias':True}

lora_params       = {'self_attn.q_proj': atten_lora_params,
                     'self_attn.k_proj': atten_lora_params,
                     'self_attn.v_proj': atten_lora_params,
                     'self_attn.o_proj': atten_lora_params,
                     'mlp.gate_proj'   : mlp_lora_params,
                     'mlp.up_proj'     : mlp_lora_params,
                     'mlp.down_proj'   : mlp_lora_params}
#Apply LoRA
PeftUtils.add_lora(model, lora_params)
HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
model.config.use_cache = False

#Dataset 
######################################################################################
from datasets import load_dataset, Dataset
from tqdm import tqdm
import transformers
import numpy as np 
import random

tokenizer.pad_token     = tokenizer.unk_token #tokenizer.eos_token 
tokenizer.padding_side  = "right" 
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
#####################################################################################
#Train
from trl import SFTTrainer

grad_acc   = 1
logging_st = 1
max_steps  = -1
lr         = 1e-5 #1e-5 cosine x 2: 5.5009
batch_size = 1
n_epochs   = 2
max_tokens = 1024 

training_args = transformers.TrainingArguments(
    output_dir='.', 
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=grad_acc,
    learning_rate=lr,
    logging_steps=logging_st,
    num_train_epochs=n_epochs,
    max_steps=max_steps,
    remove_unused_columns=False,
    bf16=True,
    max_grad_norm=1.0,
    save_steps=10000000,
    lr_scheduler_type= "cosine", 
)

#Wrap model to avoid accelerate issues 
class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def parameters(self):
        return self.model.parameters()

trainer = SFTTrainer(
    model=WrappedModel(model),
    tokenizer=tokenizer,
    max_seq_length=max_tokens,
    train_dataset=dataset,
    eval_dataset=None,
    peft_config=None,
    args=training_args,
    dataset_text_field="text",
    packing=True,
)

model.is_parallelizable       = False
trainer.is_model_parallel     = False
trainer.place_model_on_device = False
model.train()
trainer.train()

# #Prediction/Eval
# ######################################################################################
from datasets import load_dataset
import torch, time
import numpy as np
from tqdm import tqdm
import gc

tokenizer.add_bos_token = True
tokenizer.add_eos_token = False
PeftUtils.cast_lora_weights(model, dtype=compute_dtype)
model.eval()

#Save lora weights
#PeftUtils.save_lora_weights(model, filename)

def cleanup():
    torch.cuda.empty_cache()
    gc.collect()

#Adapted from https://huggingface.co/transformers/v4.2.2/perplexity.html
def eval_wikitext2(model, tokenizer, max_length=1024, stride=512, verbose=True):
    model.eval()
    tokenizer.pad_token     = tokenizer.eos_token 
    tokenizer.padding_side  = "right" 
    tokenizer.add_eos_token = False

    dataset   = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt')

    encodings['input_ids'] = encodings['input_ids'].to('cuda')

    lls, t = [], []
    for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose):
        begin_loc  = max(i + stride - max_length, 0)
        end_loc    = min(i + stride, encodings['input_ids'].size(1))
        trg_len    = end_loc - i  
        input_ids  = encodings['input_ids'][:,begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:,:-trg_len] = -100 #ignore context 

        t1 = time.time()
        with torch.no_grad():
            log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
        torch.cuda.synchronize()
        t2 = time.time()
        t.append((t2-t1))
        lls.append(log_likelihood)

        del input_ids, target_ids

    ppl       = np.round(float(torch.exp(torch.stack(lls).sum() / end_loc)), 4)
    pred_time = np.round(np.mean(t), 3)
    if(verbose):
        print('perplexity', ppl)
        print('time', str(pred_time) + '  sec')

    del encodings
    cleanup()

    return {'perplexity':ppl, 'prediction_time':pred_time}

print('perplexity',eval_wikitext2(model, tokenizer, max_length=1024, stride=512, verbose=True))
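The windowing in `eval_wikitext2` can be checked without a GPU. The sketch below reproduces its index arithmetic (`begin_loc`, `end_loc`, `trg_len`) in plain Python and shows that, with max_length=1024 and stride=512, every token is scored exactly once while each window still sees up to 512 tokens of earlier context. The perplexity is then exp(sum of per-window loss × trg_len / number of scored tokens), which is what dividing by the final `end_loc` computes. The function names are illustrative.

```python
import math
from typing import List, Tuple


def windows(n_tokens: int, max_length: int = 1024, stride: int = 512) -> List[Tuple[int, int, int]]:
    """(begin_loc, end_loc, trg_len) for each window, mirroring eval_wikitext2's loop."""
    out = []
    for i in range(0, n_tokens, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, n_tokens)
        out.append((begin_loc, end_loc, end_loc - i))  # the last trg_len positions are scored
    return out


def perplexity(mean_nlls: List[float], trg_lens: List[int]) -> float:
    """exp(total NLL / total scored tokens), given each window's mean loss and trg_len."""
    total_nll = sum(nll * n for nll, n in zip(mean_nlls, trg_lens))
    return math.exp(total_nll / sum(trg_lens))


print(windows(2000))  # [(0, 512, 512), (0, 1024, 512), (512, 1536, 512), (1024, 2000, 464)]
```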
zxbjushuai commented 2 months ago

Thank you. I will try it🤗

zxbjushuai commented 2 months ago

Hi @mobicham, I successfully fine-tuned the llama3-8b model and achieved great results with my code. Here is the result:

perplexity 23.0261
time 1.063  sec
{'perplexity': 23.0261, 'prediction_time': 1.063}

Then I found that using different methods to load the model results in different perplexities. The first method is:

quant_config = HqqConfig(nbits=1, group_size=8 ,axis=0,quant_zero=False, quant_scale=False) 
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",#device
    torch_dtype=compute_dtype,
    quantization_config=quant_config,
)
model = PeftModel.from_pretrained(model, adapter_1,adapter_name= "adapter_1")
model.set_adapter("adapter_1")
eval_wikitext2(model,tokenizer)

Result: {'perplexity': 23.0261, 'prediction_time': 1.063}

The second method is:

device = "cuda:0"
quant_config = HqqConfig(nbits=1, group_size=8 ,axis=0,quant_zero=False, quant_scale=False) 
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=compute_dtype,
    quantization_config=quant_config,
)
AutoHQQHFModel.save_quantized(model,model_save_dir)
del model
new_model = AutoHQQHFModel.from_quantized(model_save_dir,device = device)
new_model = PeftModel.from_pretrained(new_model, adapter_1,adapter_name= "adapter_1")
new_model.set_adapter("adapter_1")
eval_wikitext2(new_model,tokenizer,data)

Result: {'perplexity': 9823.5264, 'prediction_time': 0.97}

Two questions:

1. Is it because AutoHQQHFModel cannot be used to load a model quantized by AutoModelForCausalLM when loading a LoRA adapter? I also found why the perplexity was high before: every time I evaluated the perplexity, I used AutoHQQHFModel to load the quantized model instead of using AutoModelForCausalLM to quantize the original model.
2. How do I correctly save the model quantized with AutoModelForCausalLM and then load it with the adapter? Or I could just roll back the version and use your code. It seems that the transformers and hqq libraries do not work well together.

I found your great work in the transformers issue: https://github.com/mobiusml/hqq/pull/93#issuecomment-2230466271. Can it be used to save/load HQQ-quantized models with HF transformers now? By the way, I think maintaining a repo is really hard work.

mobicham commented 2 months ago

I would recommend you use hqq's PeftUtils (https://github.com/mobiusml/hqq/tree/master?tab=readme-ov-file#peft-training) instead of transformers peft:

1. After training, use PeftUtils.save_lora_weights to save the adapters.
2. You can then either load the quantized model via AutoHQQHFModel.from_quantized or quantize on the fly.
3. After the quantized model is loaded, use PeftUtils.load_lora_weights.

If you do it this way, it should work without any problem. That's how I have been doing it for all the models.

Regarding save/load, we actually have a fully working PR that adds full saving/loading support, but it's not merged yet: https://github.com/huggingface/transformers/pull/33141

zxbjushuai commented 2 months ago

Thank you very much for your advice🫡. I am browsing that PR now. It will be convenient once transformers fully supports saving/loading. I hope your work goes well. I will follow your code👌.

zxbjushuai commented 2 months ago

> pip install trl==

One last question: which trl version are you using? Mine are trl==0.9.6, hqq==0.1.8, and transformers==4.40.0, which may lead to some errors. Using tokenizer.pad_token = tokenizer.unk_token #tokenizer.eos_token reports:

ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.

and tokenizer.pad_token = tokenizer.eos_token reports:

TypeError: WrappedModel.train() takes 1 positional argument but 2 were given
zxbjushuai commented 2 months ago

I solved that problem by changing the WrappedModel code:

class WrappedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)

    def train(self,*args): #here I change
        self.model.train(*args)

    def eval(self):
        self.model.eval()

    def parameters(self):
        return self.model.parameters()

But sadly, my device's memory is not big enough, so I hit OOM because the model is only loaded on one GPU. I hope you can complete saving/loading of HQQ-quantized models in transformers👍. Distributed running is very helpful for consumer-grade GPUs🫡.

mobicham commented 2 months ago

You need ~16GB for 1024 tokens. You can reduce the VRAM usage by using the 8-bit AdamW optimizer, for example, or by reducing the sequence length from 1024 to 512 tokens.
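A rough sketch of the optimizer side of that budget. AdamW keeps two state values per trainable parameter: 8 bytes with fp32 states versus about 2 bytes with 8-bit states (ignoring their small per-block constants). The sketch also counts fp32 LoRA weights and gradients, as `train_dtype=torch.float32` in the script above suggests. The trainable-parameter count assumes Llama-3-8B with rank 32 on q/k/v/o and rank 8 on gate/up/down, derived from the public model shapes. Activations, which grow with the 1024-token sequence length, and the quantized base model are not modeled.

```python
def lora_training_state_bytes(n_trainable: int, weight_bytes: int = 4,
                              grad_bytes: int = 4, optim_state_bytes: int = 8) -> int:
    """Bytes for trainable weights + gradients + optimizer state (activations excluded)."""
    return n_trainable * (weight_bytes + grad_bytes + optim_state_bytes)


N_LORA = 41_418_752  # assumed: Llama-3-8B, r=32 on q/k/v/o, r=8 on gate/up/down

fp32_adamw = lora_training_state_bytes(N_LORA, optim_state_bytes=8)  # two fp32 moments
int8_adamw = lora_training_state_bytes(N_LORA, optim_state_bytes=2)  # two 8-bit moments
print(f"fp32 AdamW: {fp32_adamw / 2**20:.0f} MiB, 8-bit AdamW: {int8_adamw / 2**20:.0f} MiB")
# fp32 AdamW: 632 MiB, 8-bit AdamW: 395 MiB
```

With the Hugging Face Trainer, the 8-bit variant is selected with `TrainingArguments(optim="adamw_bnb_8bit")` (requires bitsandbytes). When only LoRA weights are trained, the arithmetic suggests this saves a few hundred MiB, so shortening the sequences may be the larger lever.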

zxbjushuai commented 2 months ago

Great!

zxbjushuai commented 2 months ago

😭 When saving the checkpoint, trainer.train() reported an error: HQQLinearLoRA.state_dict() got an unexpected keyword argument 'destination'. So I modified the HQQLinearLoRA.state_dict() function:

    def state_dict(self,destination,prefix,keep_vars,*args):#(self):
        return {
            "lora_A": self.lora_A.data,
            "lora_B": self.lora_B.data,
            "scaling": self.scaling,
            "bias": self.bias,
        }

The trainer can now save/load the checkpoints correctly, but I don't know whether this has any impact on the fine-tuning results.

mobicham commented 2 months ago

Oh, thanks for mentioning that 🤔. Which version of transformers are you using?

zxbjushuai commented 2 months ago

transformers == 4.44.2, trl == 0.9.6, hqq == 0.1.8. I also tried transformers == 4.40.0 before and got the same error.

mobicham commented 2 months ago

Alright, I will take a look at it later 🤔

zxbjushuai commented 2 months ago

Did you get this error when running it? I think it may be a problem with the trl version. The checkpoint looks a little strange after modifying state_dict(): a model.safetensors file appears in the checkpoint folder, and as we know, HQQ-quantized models cannot be saved in safetensors format for the time being. So I think there is a high chance that something went wrong. The error occurs in site-packages/torch/nn/modules/module.py, line 1939:

        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)  # <- error raised here

My PyTorch version is 2.4.0. Thanks for your help👍!
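On the impact question: in PyTorch's `Module.state_dict`, a parent calls each child's `state_dict(destination=..., prefix=..., keep_vars=...)` and ignores the return value (see the quoted line above). A child is therefore expected to write its entries into `destination` under `prefix`. An override that only returns a new dict, like the patched `HQQLinearLoRA.state_dict`, would therefore appear to leave the LoRA weights out of the checkpoint. The stdlib-only model below mimics that protocol with hypothetical classes (it is not torch code) to show the difference.

```python
from collections import OrderedDict


class Parent:
    """Mimics torch.nn.Module.state_dict: children write into `destination`."""

    def __init__(self, **children):
        self.children = children

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        for name, child in self.children.items():
            # As in torch, the child's return value is ignored here.
            child.state_dict(destination=destination, prefix=prefix + name + ".", keep_vars=keep_vars)
        return destination


class ReturnsDictLoRA:
    """Like the patched override: builds and returns a fresh dict."""

    def state_dict(self, destination, prefix, keep_vars, *args):
        return {"lora_A": "A", "lora_B": "B"}


class WritesDestinationLoRA:
    """Follows the protocol: writes prefixed keys into `destination`."""

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        destination[prefix + "lora_A"] = "A"
        destination[prefix + "lora_B"] = "B"
        return destination


print(list(Parent(q_proj=ReturnsDictLoRA()).state_dict()))  # []
print(list(Parent(q_proj=WritesDestinationLoRA()).state_dict()))  # ['q_proj.lora_A', 'q_proj.lora_B']
```

In a real override, the written values would be the tensors (for example `self.lora_A` detached unless `keep_vars` is set). It is worth listing the keys of a saved checkpoint to confirm the LoRA weights are actually in it.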

mobicham commented 2 months ago

So I just tried the example I shared with you, https://github.com/mobiusml/hqq/blob/master/examples/lora/hqq_plus.py, and saving/loading works just fine with PeftUtils. If it's a problem with trl, it's best to open an issue in their repo, since I don't really know what's going on with it. If you use PeftUtils from hqq as recommended, it should work without any issue:

zxbjushuai commented 2 months ago

OK, I see. I'm so sorry for taking up so much of your time and energy. To be clear, this is not a problem with the hqq library; I think it is an incompatibility with my trl version. Can you tell me your trl library version?🫣

Thank you very much! You are really helpful!

mobicham commented 2 months ago

I just did pip install trl and used the latest version !

zxbjushuai commented 2 months ago

Oh, there must be something wrong with my environment. I'll find a solution myself. At least I can fine-tune successfully using the Trainer. Thank you anyway🫡!