EricLBuehler / xlora

X-LoRA: Mixture of LoRA Experts
Apache License 2.0

Does xlora training not support Llama 2? #28

Open crossxxd opened 4 months ago

crossxxd commented 4 months ago

I've trained xlora with the Mistral 7B base model, and it works fine. However, when I switch the base model to Llama 2 7B, training fails with an error.

This is my training code:

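import torch
import xlora
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained(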
    "/root/autodl-tmp/Llama/Llama-2-7b-chat-ms",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.enable_input_require_grads()
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
model_created = xlora.add_xlora_to_model(
    model=model,
    xlora_config=xlora.xLoRAConfig(
        model.config.hidden_size,
        base_model_id="/root/autodl-tmp/Llama/Llama-2-7b-chat-ms",
        xlora_depth=4,
        device=torch.device("cuda"),
        adapters={
            "adapter_1": "/root/autodl-tmp/xlora_demo/llama_output/checkpoint-5",
            "adapter_2": "/root/autodl-tmp/xlora_demo/llama_output/checkpoint-10",
        },
    ),
    verbose=True,
)
tokenizer = AutoTokenizer.from_pretrained("/root/autodl-tmp/Llama/Llama-2-7b-chat-ms", trust_remote_code=True)
args = TrainingArguments(
    output_dir="/root/autodl-tmp/xlora_test/output/Llama-2-7b-chat-mix-lora",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=1,
    save_steps=10,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True
)
max_seq_length = 4096
data_name = "mlabonne/guanaco-llama2-1k" # Dataset
dataset = load_dataset(data_name, cache_dir="/root/autodl-tmp/dataset_cache", split="train")
trainer = SFTTrainer(
    model=model_created,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=False,
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
)
trainer.train()

And the error is:

  File "/root/miniconda3/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 360, in train
    output = super().train(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 2118, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py", line 3045, in training_step
    self.accelerator.backward(loss)
  File "/root/miniconda3/lib/python3.10/site-packages/accelerate/accelerator.py", line 2001, in backward
    loss.backward(**kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Thanks.

leonardxie commented 3 months ago

I'm hitting the same error. +1

maximus-21 commented 2 months ago

@crossxxd Can you share your training code for the Mistral 7B base model? I was able to get training running with the Llama model, but it is very slow and the training loss does not decrease. Your Mistral 7B code might help.

maohaos2 commented 2 months ago

@maximus-21 Hey, could you let me know how you fixed the code to get the Llama model working? Thanks a lot!

maximus-21 commented 2 months ago

@maohaos2 Can you share your code?