huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Request for adding a LoRA implementation for torch.nn.Conv1d rather than transformers.pytorch_utils.Conv1D #2241

Open HelloWorldLTY opened 5 days ago

HelloWorldLTY commented 5 days ago

Feature request

Hi, I found that LoRA does not support models that use torch.nn.Conv1d as their convolution layers, which limits the use case for models pre-trained with this class (for example, Enformer). I wonder if it is possible to add an implementation based on this class.

Motivation

To fine-tune Enformer.

Your contribution

If needed, I can open a PR.

BenjaminBossan commented 4 days ago

Thanks for opening this feature request. We cannot drop support for transformers Conv1D as it is required for certain models like gpt2. However, we can consider adding support for torch Conv1d on top. If that layer can re-use the same LoRA Linear implementation that transformers Conv1D does, it should be fairly easy. If you have some code to enable this, feel free to open a (draft) PR.
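
For context, the reason the Linear LoRA implementation can be reused for transformers Conv1D is that, despite its name, Conv1D is a dense layer with a transposed weight rather than a real convolution. Below is a minimal sketch illustrating the equivalence (assuming transformers.pytorch_utils.Conv1D, which is where the class lives):

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

x = torch.randn(2, 10, 16)                  # (batch, seq, features)
conv1d = Conv1D(nf=32, nx=16)               # transformers Conv1D: weight shape (16, 32)
linear = nn.Linear(16, 32)                  # torch Linear: weight shape (32, 16)
linear.weight.data = conv1d.weight.data.T   # same parameters, just transposed
linear.bias.data = conv1d.bias.data

print(torch.allclose(conv1d(x), linear(x), atol=1e-6))  # True: identical mapping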

HelloWorldLTY commented 1 day ago

Thanks, could you please provide some hints on how to implement it? For example, do you think it is OK if I directly replace every transformers Conv1D with torch.nn.Conv1d? Thanks a lot.

BenjaminBossan commented 1 day ago

First of all, let me preface this by saying that I'm not sure whether we can just use torch Conv1d or whether it won't work at all. This would need to be tested. Second, here is the crucial part of the code:

https://github.com/huggingface/peft/blob/3f9ce553e21569e21269b5ba91d7390f7229199a/src/peft/tuners/lora/layer.py#L1262-L1269

This is the logic where we check the base layer type and decide that we want to apply a LoraLayer (in this case the Linear LoRA layer).

Theoretically, we can just replace the isinstance check with:

elif isinstance(target_base_layer, (Conv1D, nn.Conv1d)):

to match torch Conv1d. Maybe you can give this a try and check if it works for your use case before opening a PR.

However, we should not replace transformers Conv1D because this is needed for some models to work.
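
As a quick way to try this out before touching Enformer, a toy module containing only a torch nn.Conv1d can be wrapped with get_peft_model. This is a hedged sketch: TinyConvModel is made up for testing, and how far the script gets depends on which of the changes above are applied.

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class TinyConvModel(nn.Module):
    """Made-up toy model whose only target layer is a torch nn.Conv1d."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(4, 8, kernel_size=5, padding=2)

    def forward(self, x):
        return self.conv(x)

model = TinyConvModel()
# Without a patch, PEFT does not recognize nn.Conv1d as a supported target module;
# with the patched isinstance check, this shows whether adapter creation and the
# forward pass go through.
peft_model = get_peft_model(model, LoraConfig(target_modules=["conv"], r=8))
print(peft_model)
out = peft_model(torch.randn(1, 4, 32))
print(out.shape)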

HelloWorldLTY commented 1 day ago

Hi, I tried your recommended method, but I received a new error:

File /home/tl688/.conda/envs/evo/lib/python3.11/site-packages/torch/nn/modules/linear.py:98, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
     96 self.in_features = in_features
     97 self.out_features = out_features
---> 98 self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
     99 if bias:
    100     self.bias = Parameter(torch.empty(out_features, **factory_kwargs))

TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got NoneType"

Since my Conv1d has a kernel size larger than 1, it is not trivial to make the transformation. I will try other software to see if it works.
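
One explanation consistent with this traceback (a hedged sketch, not a reading of the PEFT source): a plain nn.Conv1d exposes in_channels/out_channels and a 3-D weight, not the in_features/out_features a Linear LoRA layer needs, so those values end up as None and reach torch.empty():

import torch
from torch import nn

conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5)
print(conv.weight.shape)                    # torch.Size([32, 16, 5]) -- 3-D, unlike a Linear weight
print(getattr(conv, "in_features", None))   # None: the attribute does not exist on Conv1d
print(getattr(conv, "out_features", None))  # None
# Building nn.Linear(None, r) then fails inside torch.empty(), matching the TypeError above.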

BenjaminBossan commented 1 day ago

If you provide the code to reproduce the error, I can take a look.

HelloWorldLTY commented 1 day ago

Hi, thanks a lot. I am trying to apply LoRA to Enformer:

https://github.com/lucidrains/enformer-pytorch

Here is my code to build the LoRA model:

from peft import LoraConfig, get_peft_model

def get_lora(model, lora_config=None, train=False):
    """
    Applies Low-Rank Adaptation (LoRA) to the model.

    This function integrates LoRA modules into specified layers of the model, enabling
    parameter-efficient fine-tuning. If `train` is True, it sets the LoRA parameters and
    specific layers in the base model to be trainable. Otherwise, it freezes all parameters.

    Args:
        model: The base model to wrap with LoRA adapters.
        lora_config (LoraConfig, optional): Configuration for LoRA. If None, uses a default configuration.
        train (bool): Whether the model is being prepared for training.
    """
    if lora_config is None:
        # lora_config = LoraConfig(
        #     target_modules=r"(?!separable\d+).*ConvBlock|.*to_q|.*to_v|EnformerTransformerBlock\.\d+\.1\.fn\.1|EnformerTransformerBlock\.\d+\.1\.fn\.4",
        # )
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["linear", "to_q", "to_k", "to_v", "conv"],
            lora_dropout=0.01,
        )
    model = get_peft_model(model, lora_config)  # get the LoRA model
    print(model)
    if train:
        for params in model.base_model.model.model.embedding.conv_tower.parameters():
            params.requires_grad = True
        if model.base_model.model.model.embedding.transformer_tower:
            for params in model.base_model.model.model.embedding.transformer_tower.parameters():
                params.requires_grad = True
        model.print_trainable_parameters()
    else:
        for params in model.parameters():
            params.requires_grad = False
    return model

LoRA works well for linear, to_q, to_k, and to_v, but conv refers to an nn.Conv1d module, and this is where I hit the error above. The conv layer has a kernel size of 5.
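
For reference, a quick way to confirm which modules in the pretrained model are plain torch nn.Conv1d and what kernel sizes they use (a small sketch; the checkpoint name follows the enformer-pytorch README):

import torch.nn as nn
from enformer_pytorch import Enformer

model = Enformer.from_pretrained("EleutherAI/enformer-official-rough")
conv_layers = {
    name: module.kernel_size
    for name, module in model.named_modules()
    if isinstance(module, nn.Conv1d)
}
print(conv_layers)  # expected to include kernel sizes larger than 1, e.g. (5,)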

BenjaminBossan commented 11 hours ago

I was able to make a bit more progress:

import torch
from peft import LoraConfig, get_peft_model
from enformer_pytorch import Enformer

model = Enformer.from_pretrained("EleutherAI/enformer-official-rough", device_map=0)
model = get_peft_model(model, LoraConfig(target_modules=["linear", "to_q", "to_k", "to_v", "conv"]))
seq = torch.randint(0, 5, (1, 196_608)).to(0) # for ACGTN, in that order (-1 for padding)
output = model(seq)

The only changes I had to make were to this line:

-    elif isinstance(target_base_layer, Conv1D):
+    elif isinstance(target_base_layer, (Conv1D, nn.Conv1d)):

and this line:

-        elif isinstance(base_layer, nn.Conv2d):
+        elif isinstance(base_layer, (nn.Conv2d, nn.Conv1d)):

However, the forward pass will fail because of mismatched shapes. I think the nn.Conv1d module cannot be simply replaced by a Linear layer, unlike transformers Conv1D. It probably needs its own LoRA layer type.
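
To make that last point concrete, here is a rough, standalone sketch (not PEFT code) of what a dedicated LoRA layer for nn.Conv1d could look like, mirroring the usual low-rank decomposition for conv layers: lora_A keeps the original kernel size and lora_B is a kernel-size-1 convolution (groups=1 assumed):

import torch
from torch import nn

class LoraConv1d(nn.Module):
    """Standalone sketch of a LoRA wrapper around a frozen nn.Conv1d (groups=1)."""
    def __init__(self, base_layer: nn.Conv1d, r: int = 8, lora_alpha: int = 32):
        super().__init__()
        self.base_layer = base_layer
        self.scaling = lora_alpha / r
        # lora_A uses the original kernel/stride/padding so shapes line up with the base conv
        self.lora_A = nn.Conv1d(
            base_layer.in_channels, r,
            kernel_size=base_layer.kernel_size,
            stride=base_layer.stride,
            padding=base_layer.padding,
            dilation=base_layer.dilation,
            bias=False,
        )
        # lora_B projects back up to the output channels with a kernel size of 1
        self.lora_B = nn.Conv1d(r, base_layer.out_channels, kernel_size=1, bias=False)
        nn.init.zeros_(self.lora_B.weight)        # adapter starts as a no-op
        for p in self.base_layer.parameters():    # freeze the pretrained conv
            p.requires_grad_(False)

    def forward(self, x):
        return self.base_layer(x) + self.scaling * self.lora_B(self.lora_A(x))

# Quick shape check with kernel size 5, as in the Enformer conv layers:
layer = LoraConv1d(nn.Conv1d(16, 32, kernel_size=5, padding=2))
print(layer(torch.randn(1, 16, 100)).shape)  # torch.Size([1, 32, 100])

A real layer type in PEFT would additionally need merging/unmerging, dropout, and multi-adapter handling, which is why a dedicated implementation rather than the Linear path seems like the right direction.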