bilibili / Index-1.9B

A SOTA lightweight multilingual LLM
Apache License 2.0

Colab fine-tuning cannot save the model: AttributeError: 'IndexForCausalLM' object has no attribute 'save_checkpoint' #22

Open duanyu opened 3 months ago

duanyu commented 3 months ago

Fine-tuning with trl's SFTTrainer + LoRA, the model cannot be saved. The relevant training-configuration code is below:

from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# DeepSpeed ZeRO stage 2 with the optimizer state offloaded to CPU
deepspeed_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "allgather_partitions": True,
        "allgather_bucket_size": 5e8,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": True,
        "round_robin_gradients": True
    }
}

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

sft_config = SFTConfig(output_dir='models/index-1.9b-ft',
                       per_device_train_batch_size=4,
                       gradient_accumulation_steps=4,
                       per_device_eval_batch_size=4,
                       num_train_epochs=3,
                       learning_rate=1e-4,
                       report_to='tensorboard',
                       bf16=True,
                       max_seq_length=1024,
                       deepspeed=deepspeed_config,
                       logging_steps=10,
                       eval_steps=10,
                       save_steps=10,
                       save_on_each_node=True,
                       )

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    args=sft_config,
    tokenizer=tokenizer,
    peft_config=peft_config,
)

Error message: AttributeError: 'IndexForCausalLM' object has no attribute 'save_checkpoint'

Environment: free Colab T4; transformers==4.41.2; trl==0.9.4; peft==0.11.1

asirgogogo commented 3 months ago

Try using 'save_model' to save the checkpoint.
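
A minimal sketch of one way to apply that suggestion (my reading of it, not the commenter's exact code), assuming the trainer, tokenizer, and sft_config objects built in the snippet above. save_model is the standard Trainer method that SFTTrainer inherits; it writes the LoRA adapter via save_pretrained rather than through DeepSpeed's save_checkpoint:

# Note: the periodic saves every `save_steps` would still go through the same
# failing code path, so one option is to pass save_strategy="no" to SFTConfig
# and only save once at the end.
trainer.train()
trainer.save_model(sft_config.output_dir)         # writes the LoRA adapter files
tokenizer.save_pretrained(sft_config.output_dir)  # keep the tokenizer alongside them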

zhyang2226 commented 3 months ago

Run the script with the deepspeed or torchrun launcher instead of invoking it directly with python.
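
For context (my reading of the suggestion, not something stated in the thread): save_checkpoint is a method of the DeepSpeed engine that wraps the model once DeepSpeed is actually initialized, which normally happens when the script is started through a distributed launcher; under a plain python run the Trainer can end up calling it on the bare IndexForCausalLM. A rough sketch, with the file name train.py purely illustrative:

# train.py -- the issue's snippet saved as a standalone script (file name is
# illustrative). Start it through a launcher so that DeepSpeed initializes and
# wraps the model before the first checkpoint is written:
#
#   deepspeed train.py
#
# or, for a single GPU:
#
#   torchrun --nproc_per_node=1 train.py
#
# A plain `python train.py` (or a notebook cell) can leave the model unwrapped
# while the Trainer still follows the DeepSpeed save path, which would be
# consistent with the AttributeError reported above.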