huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Inference with different LoRA adapters in the same batch does not use the correct module_to_save classifier #1960

Open saeid93 opened 1 month ago

saeid93 commented 1 month ago

System Info

Python 3.11.9 transformers==4.40.2 peft==0.11.2

Who can help?

@BenjaminBossan I'm interested in using the "Inference with different LoRA adapters in the same batch" feature, with a separate last-layer classifier for each LoRA adapter. However, during inference, when a batch contains requests destined for different adapters, the peft library uses the active adapter's classifier for every request rather than the classifier of the requested adapter. I should note that this problem only occurs for modules_to_save layers; the other layers (e.g. the LoRA layers of the base model) use the correct LoRA weights for each request.

This is because the per-request adapter names are not passed down to the ModulesToSaveWrapper class, so inference is always done with the active_adapter's module to save:

https://github.com/huggingface/peft/blob/f2b6d13f1dbc971c7653aa65e82822ea2d84bb38/src/peft/utils/other.py#L264

The solution would be to pass adapter_names all the way down to the forward function of ModulesToSaveWrapper via **kwargs and to implement logic similar to https://github.com/huggingface/peft/blob/f2b6d13f1dbc971c7653aa65e82822ea2d84bb38/src/peft/tuners/lora/layer.py#L327

for sending sub-batches of requests that share the same adapter to the appropriate classifier, along the lines of the sketch below.
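
Roughly, the dispatch I have in mind inside the wrapper's forward would look something like this (just a sketch of the idea with an imagined signature, not actual peft code):

def forward(self, x, *args, adapter_names=None, **kwargs):
    # Sketch of a possible ModulesToSaveWrapper.forward (assumes torch is in scope):
    # dispatch each sub-batch to the modules_to_save copy of its adapter,
    # falling back to the current behaviour when no adapter names are given.
    if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
        return self.original_module(x, *args, **kwargs)
    if adapter_names is None:
        return self.modules_to_save[self.active_adapter](x, *args, **kwargs)

    results = [None] * len(x)
    for adapter in set(adapter_names):
        indices = [i for i, name in enumerate(adapter_names) if name == adapter]
        output = self.modules_to_save[adapter](x[indices], *args, **kwargs)
        for out_pos, batch_idx in enumerate(indices):
            results[batch_idx] = output[out_pos]
    return torch.stack(results)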

However, I see that you are excluding the special_peft_forward_args: https://github.com/huggingface/peft/blob/f2b6d13f1dbc971c7653aa65e82822ea2d84bb38/src/peft/peft_model.py#L761

possibly to avoid interfering with the base model's forward function, e.g. https://github.com/huggingface/transformers/blob/5f841c74b62754f186a8c06a684d491524b7bc03/src/transformers/models/vit/modeling_vit.py#L813

I was able to solve this by modifying the mentioned functions, but since it is a bug I think it could also be fixed upstream, or at least be mentioned in the documentation as another caveat. @stevhliu

Information

Tasks

Reproduction

The case is very similar to the documentation example, but with a classifier module on top. The code raises no error; I only noticed the problem by observing a difference in accuracy and tracing the root cause in the peft library, as described above. A minimal sketch of the setup is given below; please let me know if more information is needed.
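
Here is roughly the shape of my setup (checkpoint, adapter names, and label count are placeholders, not my exact code):

import torch
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model

# Placeholder checkpoint and hyperparameters; each adapter gets its own LoRA
# weights plus its own copy of the classifier head via modules_to_save.
base_model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=10
)
lora_kwargs = dict(target_modules=["query", "value"], modules_to_save=["classifier"])
model = get_peft_model(base_model, LoraConfig(**lora_kwargs), adapter_name="adapter_a")
model.add_adapter("adapter_b", LoraConfig(**lora_kwargs))

# ... train both adapters, then run inference on a mixed batch:
model.eval()
pixel_values = torch.randn(4, 3, 224, 224)
adapter_names = ["adapter_a", "adapter_a", "adapter_b", "adapter_b"]
with torch.no_grad():
    logits = model(pixel_values=pixel_values, adapter_names=adapter_names).logits
# The LoRA layers respect adapter_names per sample, but all four samples go
# through the classifier of the currently active adapter.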

Expected behavior

Each adapter should use its own classifier rather than the currently active adapter's classifier.

BenjaminBossan commented 1 month ago

Thanks a lot for reporting this error and your great investigation. Indeed, this should ideally work out of the box. You mentioned:

I was able to solve this by modifying the mentioned functions

Could you please share the code to achieve this?

saeid93 commented 1 month ago

Glad to be of any help! Please find the code below; it is just a hack that dynamically patches the modifications into the library. The rest of my code simply uses the class and functions below instead of the peft and transformers classes. I have marked the changes with HERE comments in the code.

BTW, I'm happy to investigate further and fix this with a pull request; however, it will take some time. If there is a timeline for fixing it, I leave that to you.

from typing import Any, Optional, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from peft.peft_model import PeftModel
from transformers.modeling_outputs import ImageClassifierOutput
from transformers import ViTForImageClassification, MobileViTForImageClassification
from functools import partial

class PeftModelFixed(PeftModel):
    def forward(self, *args: Any, **kwargs: Any):
        """
        Forward pass of the model.
        """
        with self._enable_peft_forward_hooks(*args, **kwargs):
            # HERE: removed the filtering below so that adapter_names stays in kwargs
            # and reaches the base model's forward
            # kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
            return self.get_base_model()(*args, **kwargs)

class ViTForImageClassificationFixed(ViTForImageClassification):
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs  # HERE: added **kwargs so adapter_names can be forwarded to the classifier
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # HERE: changed to pass the per-sample adapter names down to the classifier wrapper
        logits = self.classifier(sequence_output[:, 0, :], adapter_names=kwargs["adapter_names"])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

def peftforward(self, *args, **kwargs):
    if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
        return self.original_module(*args, **kwargs)

    # HERE: changed to dispatch each sub-batch to the modules_to_save copy of its
    # adapter, mirroring the mixed-batch logic of the LoRA layers
    adapter_names = kwargs["adapter_names"]
    kwargs = {}
    batch = args[0]
    unique_adapters = set(adapter_names)
    sub_batch_indices_list = []
    for adapter in unique_adapters:
        sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

    results = [0 for i in range(len(batch))]
    for i, active_adapter in enumerate(unique_adapters):
        sub_batch = batch[sub_batch_indices_list[i]]
        output = self.modules_to_save[active_adapter](*(sub_batch,), **kwargs)
        for index, j in enumerate(sub_batch_indices_list[i]):
            results[j] = output[index]
    return torch.stack(results)

def change_forward_dynamically(model: PeftModel):
    # HERE: the peft model is passed in so that the forward of its last-layer
    # classifier wrapper can be patched dynamically
    model.classifier.forward = partial(peftforward, model.classifier)
    return model
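
And this is roughly how I wire it together (checkpoint paths and adapter names are placeholders; it also assumes the saved adapter configs have no task_type, so that from_pretrained instantiates PeftModelFixed directly):

import torch

base = ViTForImageClassificationFixed.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=10
)
model = PeftModelFixed.from_pretrained(base, "path/to/adapter_a", adapter_name="adapter_a")
model.load_adapter("path/to/adapter_b", adapter_name="adapter_b")
model = change_forward_dynamically(model)
model.eval()

adapter_names = ["adapter_a", "adapter_b"]
with torch.no_grad():
    logits = model(
        pixel_values=torch.randn(2, 3, 224, 224), adapter_names=adapter_names
    ).logits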

BenjaminBossan commented 1 month ago

Thanks a lot. I would gladly accept a PR for this fix, that would be fantastic. It should also be much simpler than your workaround, since the logic can be added directly where it is needed instead of patching.

There is no strict timeline, we just had a release, so the next one would still be a bit in the future.

saeid93 commented 1 month ago

Awesome, I'll work on it when I get a chance.

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

saeid93 commented 1 week ago

This is pending review of #1990. Commenting here to remove the automatic stale mark.