kathrinse / be_great

A novel approach for synthesizing tabular data using pretrained large language models
MIT License

Support for LoRA? #34

Open TimS-ml opened 1 year ago

TimS-ml commented 1 year ago

Hi,

Thank you very much for open-sourcing this project! I noticed that be_great/great.py has a self.efficient_finetuning option that supports LoRA. However, I've run into a few bugs that I could use some help with.
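
For context, this is roughly how I set things up (a sketch only; the exact constructor arguments are my best guess from great.py and may differ between versions):

import pandas as pd
from be_great import GReaT

# Assumption: efficient_finetuning="lora" is the value great.py checks for;
# the dataset path and the epochs/batch_size values are placeholders.
data = pd.read_csv("my_table.csv")
great = GReaT("distilgpt2", efficient_finetuning="lora", epochs=50, batch_size=32)
great.fit(data)
great.save("trained_model")  # should write model.pt and config.json into trained_model/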

[1] GReaT.load_from_dir() leads to a state dict mismatch:

Missing key(s) in state_dict: "transformer.wte.weight" ...
Unexpected key(s) in state_dict: "base_model.model.transformer.wte.weight" ...

[2] net.sample(n_samples, k=50) raises:

AttributeError: 'GPT2LMHeadModel' object has no attribute 'generation_config'
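
If this comes from a version mismatch between transformers and peft (generation_config was only attached to models in newer transformers releases), upgrading transformers may resolve it; attaching a generation config manually before sampling might also work as a stopgap. An untested sketch, where net is the GReaT instance from above:

from transformers import GenerationConfig

# Sketch only: give the loaded model the generation_config attribute that
# generate() expects, derived from its existing model config.
net.model.generation_config = GenerationConfig.from_model_config(net.model.config)
samples = net.sample(n_samples=100, k=50)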

Thanks

sebffischer commented 1 year ago

I just ran into the same problem. The cause is that load_from_dir() does not recreate the LoRA-wrapped (PEFT) model, so the keys in the saved state dict (prefixed with base_model.model.) do not match the plain model it builds.

This is a workaround:

import json
import torch

from be_great import GReaT
from peft import LoraConfig, TaskType, get_peft_model

great = GReaT('distilgpt2')

# Define LoRA Config
lora_config = LoraConfig(
    r=16,  # only training 0.16% of the parameters of the model
    lora_alpha=32,
    target_modules=[
        "c_attn"
    ],  # this is specific for gpt2 model, to be adapted
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,  # causal language modeling (GPT-2 is a causal LM)
)
# add the LoRA adapter so the checkpoint keys (base_model.model.*) match
great.model = get_peft_model(great.model, lora_config)
great.model.print_trainable_parameters()

# load the fine-tuned weights (model.pt as written by GReaT.save())
great.model.load_state_dict(torch.load("model.pt"))

# Load the remaining GReaT attributes stored in config.json
with open("config.json", "r") as f:
    attributes = json.load(f)

# Set all attributes
for k, v in attributes.items():
    setattr(great, k, v)
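
After the attributes are restored, sampling should work again, assuming the generation_config issue from the original report does not apply to your transformers version (untested on my end):

# Generate synthetic rows with the restored model; the values are just examples.
synthetic = great.sample(n_samples=100, k=50)
print(synthetic.head())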