microsoft / LoRA

Code for loralib, an implementation of "LoRA: Low-Rank Adaptation of Large Language Models"
https://arxiv.org/abs/2106.09685
MIT License

Cannot use lora for a pre-trained model #123

Open HelloWorldLTY opened 1 year ago

HelloWorldLTY commented 1 year ago

Hi, I tried to replace our model's linear layers with lora.Linear. However, it seems that none of the components in this module can be used for fine-tuning.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 26
     17 valid_loader = prepare_dataloader(
     18     valid_data_pt,
     19     batch_size=config.batch_size,
   (...)
     22     drop_last=False,
     23 )
     25 if config.do_train:
---> 26     train(
     27         model,
     28         loader=train_loader,
     29     )
     30 val_loss, val_mre = evaluate(
     31     model,
     32     loader=valid_loader,
     33 )
     34 elapsed = time.time() - epoch_start_time

Cell In[16], line 74, in train(model, loader)
     71     metrics_to_log.update({"train/dab": loss_dab.item()})
     73 model.zero_grad()
---> 74 scaler.scale(loss).backward()
     75 scaler.unscale_(optimizer)
     76 with warnings.catch_warnings(record=True) as w:

File /notebooks/lib/python3.9/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File /notebooks/lib/python3.9/site-packages/torch/autograd/__init__.py:197, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    192     retain_graph = create_graph
    194 # The reason we repeat same the comment below is that
    195 # some Python versions print out the first line of a multi-line function
    196 # calls in the traceback and some print out the last line
--> 197 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    198     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    199     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

How to address this problem? Thanks a lot.
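
For reference, a minimal sketch of the kind of replacement described above, following the loralib README; the toy model, layer names, and rank r=16 are illustrative, not the actual model from this issue:

import torch
import torch.nn as nn
import loralib as lora

# Hypothetical toy model; the real model's layers and sizes will differ.
class TinyModel(nn.Module):
    def __init__(self, d_in=512, d_hidden=512, d_out=10):
        super().__init__()
        # Swap nn.Linear for lora.Linear so the low-rank lora_A/lora_B matrices exist.
        self.fc1 = lora.Linear(d_in, d_hidden, r=16)
        self.fc2 = lora.Linear(d_hidden, d_out, r=16)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyModel()
# Freeze everything except the LoRA parameters; if no lora.Linear layers were
# actually created, this leaves zero trainable parameters and backward() raises
# exactly the RuntimeError shown above.
lora.mark_only_lora_as_trainable(model)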

iopwsy commented 1 year ago

I have the same problem, but I fixed it. It seems you need to set loss.requires_grad = True, and then training runs normally. However, the results are not as good. I don't know why. Maybe the loss here is computed for the base model rather than the LoRA parameters?
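
A quick sanity check, assuming model here is the LoRA-patched model: list the parameters that still require gradients. If nothing shows up, the loss is computed from a fully frozen graph, and setting loss.requires_grad = True only silences the error without updating any weights, which would explain the worse results:

# List trainable tensors after lora.mark_only_lora_as_trainable(model).
# An empty list means backward() has nothing to update, matching the RuntimeError above.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(len(trainable), "trainable tensors")
print([n for n in trainable if "lora_" in n][:10])  # expect lora_A / lora_B entries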

HelloWorldLTY commented 12 months ago

Hi, thanks for your explanation. I used another approach to address it.

But my training time is longer.

With LoRA: [screenshot of training time with LoRA]

Without LoRA: [screenshot of training time without LoRA]

I have no idea why.
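
One thing worth checking here, as a rough sketch (assuming model is the patched model): LoRA mainly shrinks the set of parameters the optimizer updates, while the forward pass still runs the frozen dense layers plus the extra low-rank path, so a somewhat longer per-step time is not unusual. Counting trainable parameters shows what is actually being saved:

# Compare trainable vs. total parameter counts for the LoRA-patched model.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")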

qizhaoaoe commented 5 months ago

> Hi, thanks for your explanation. I used another approach to address it. But my training time is longer. With LoRA: [screenshot]. Without LoRA: [screenshot]. I have no idea why.

Could you share your method to solve this problem? Thanks!

HelloWorldLTY commented 5 months ago

Hi, my approach is to first replace all linear modules with LoRA modules, then run:

import loralib as lora

# Freeze everything except the LoRA parameters.
lora.mark_only_lora_as_trainable(model)

# Substrings that mark parameters to keep trainable inside the backbone.
trainable_params = ['lora']
# Optionally resume from an existing LoRA checkpoint:
#     lora_state_dict = torch.load(model_args.lora_path)
#     model.load_state_dict(lora_state_dict, strict=False)

if len(trainable_params) > 0:
    for name, param in model.named_parameters():
        if name.startswith('deberta') or name.startswith('roberta'):
            # Freeze backbone weights, but re-enable anything containing 'lora'.
            param.requires_grad = False
            for trainable_param in trainable_params:
                if trainable_param in name:
                    param.requires_grad = True
                    break
        else:
            # Parameters outside the backbone (e.g. a task head) stay trainable.
            param.requires_grad = True
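
As a follow-up on the commented-out loading step, the checkpointing pattern documented in the loralib README saves only the LoRA weights and restores them with strict=False; the file names below are placeholders:

import torch
import loralib as lora

# Save only the LoRA parameters (a small file compared to the full model).
torch.save(lora.lora_state_dict(model), 'ckpt_lora.pt')

# Later: load the pretrained backbone first, then the LoRA weights on top.
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)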