huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Ability to switch between different adapter types and mix different adapter types #1006

Closed: kovalexal closed this issue 9 months ago

kovalexal commented 11 months ago

Feature request

Hi! As far as I know, PEFT currently only allows loading and running inference with adapters of a single type.

So we cannot, for example, load a LoRA and a LoHa for Stable Diffusion and switch between them during inference: we are forced to unload all adapters of the first type, and in doing so we may lose state we care about, such as a mixture of several LoRA adapters.

Also, in the Stable Diffusion ecosystem it is quite common to load and mix several adapters together to obtain a unique style, and those adapters may be of different types. Mixing adapters is partially addressed by the new API for enabling multiple adapters during inference, but it is currently limited to a single adapter type.

To sum up, it would be great to have these features:

  1. Ability to load different adapter types and switch between them without the need to unload anything
  2. Ability to mix different adapter types into new adapters (probably for adapters that allow us to calculate get_delta_weight)

Motivation

These features would be very useful for those who run the diffusers implementation of Stable Diffusion with PEFT adapters (LoRA, LoHa, etc.) in production, where losing state or having to reload checkpoints is undesirable.

As far as I know, Hugging Face diffusers does not allow these kinds of manipulations with checkpoints, and PEFT has broader support for different adapter types for SD and SDXL.

From my perspective, these features may also be useful for LLMs: the ability to switch between or mix different adapters could be beneficial for some downstream tasks.

Your contribution

I would be happy to help implement these features, but it is not clear to me right now how they could be achieved with the current library architecture (so we should probably discuss your view on how this should be implemented).

BenjaminBossan commented 11 months ago

Thanks for bringing up this feature request. Indeed, it's also something that we have discussed internally and which we agree would be good to have but will be hard to implement.

Partly, the effect can be achieved already by merging and unloading one adapter type before adding the next, but of course this makes it impossible to unmerge the former, so we should find a better solution.
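For reference, a rough sketch of that workaround with the current API might look like the following (the base model, target modules, and config values are placeholders, not a recommendation; the point is that after merge_and_unload the first adapter is baked in and can no longer be unmerged):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, LoHaConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model

# First adapter type: LoRA
lora_model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))
# ... train or load the LoRA weights here ...

# Workaround: bake the LoRA deltas into the base weights ...
merged = lora_model.merge_and_unload()

# ... then wrap the merged model with a different adapter type (LoHa).
# The LoRA contribution can no longer be deactivated or unmerged at this point.
loha_model = get_peft_model(merged, LoHaConfig(r=8, target_modules=["c_attn"]))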

I think that full support for this feature will be very difficult to achieve because of various assumptions that are made (implicitly) throughout the code base. Some types of adapters may even not be possible to combine. However, if we restrict this feature to a subset of adapters (e.g. LoRA, LoHa, LoKr, maybe IA³), it could be possible.

One very big issue I see right now is the way we handle the control of which adapters are active. Just as an example, here is the meat of the forward method of LoRA Linear:

https://github.com/huggingface/peft/blob/eced2edff8aa1c7c0ca50b86df4603b4586bedc2/src/peft/tuners/lora/layer.py#L275-L283

As is, this layer "controls" which adapters are active by applying them one after the other. However, if we want the active adapters to be lora1, loha1, lora2, this cannot work, because loha1 needs to be applied before lora2 (unless all adapters are commutative, which I don't think we can assume). So to implement this, we would require a completely different approach.

One example of how this may work is how I prototyped it here. The general idea is to implement all adapters purely through forward hooks (pre and post). There is a wrapper class (like PeftModel) which is responsible for handling the adapters, instead of having the control on the adapters themselves. Activating adapters becomes a matter of registering their hooks, and deactivating them only requires removing the hooks. This makes handling of active adapters simpler and more transparent IMO.
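As an illustration of that hook-based idea (not the prototype itself; the class and names below are made up for the sketch), a LoRA-style delta can be attached to an unmodified nn.Linear purely through a forward hook, and activating/deactivating the adapter is just registering/removing the hook:

import torch
import torch.nn as nn

class LoraHookAdapter:
    """Hypothetical sketch: a LoRA delta applied through a forward hook."""

    def __init__(self, layer: nn.Linear, r: int = 8, scaling: float = 1.0):
        self.lora_A = nn.Linear(layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, layer.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # start as a no-op
        self.scaling = scaling
        self._layer = layer
        self._handle = None

    def _hook(self, module, args, output):
        # Post-forward hook: add the low-rank delta to the base layer's output.
        x = args[0]
        return output + self.lora_B(self.lora_A(x)) * self.scaling

    def activate(self):
        if self._handle is None:
            self._handle = self._layer.register_forward_hook(self._hook)

    def deactivate(self):
        if self._handle is not None:
            self._handle.remove()
            self._handle = None

base = nn.Linear(16, 16)
adapter = LoraHookAdapter(base)
adapter.activate()                 # the adapter now participates in base(x)
y = base(torch.randn(2, 16))
adapter.deactivate()               # the base layer is back to its original behavior

A real implementation would of course also need to manage parameters, devices, and state dicts, which this sketch ignores.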

Of course, making such a switch would be an enormous task for a code base such as PEFT, so it may already be too late for that at this point. Maybe there are other ideas for how to achieve this; I'm very open to suggestions.

kovalexal commented 10 months ago

Thanks for sharing your view on this problem.

Your prototype looks neat! I am wondering how it compares to the current PEFT approach in terms of performance and reliability. I thought that attaching hooks to torch modules was mainly intended for debugging/profiling purposes, but it does provide a lot of flexibility for these adapter approaches.

Some types of adapters may even not be possible to combine. However, if we restrict this feature to a subset of adapters (e.g. LoRA, LoHa, LoKr, maybe IA³), it could be possible.

In my experience, the most used adapters for Stable Diffusion right now are LoRA, LoHa, and LoKr, so this subset seems to cover most of what is needed for those models.

Do you know by chance whether mixing different adapters is a common thing for LLMs? I suppose there might be language-specific adapters that could be combined with domain-specific knowledge adapters, but I am not sure what the most common adapter types are for those kinds of tasks.

As is, this layer "controls" which adapters are active by applying them one after the other. However, if we want the active adapters to be lora1, loha1, lora2, this cannot work, because loha1 needs to be applied before lora2 (unless all adapters are commutative, which I don't think we can assume). So to implement this, we would require a completely different approach.

Hmmm, if I understand it correctly, if we restrict a mixture to the subset of LoRA, LoHa, and LoKr, these adapters are fully commutative (when applied at the same time), since each of them in general just contributes a $\Delta W$ and a $b$ to the base layer weights, am I right?

$Layer(x, \{{lora}_1, {loha}_1, {lora}_2\}) = (Wx + b) + (\Delta W_{lora_1}x + b_{lora_1}) + (\Delta W_{loha_1}x + b_{loha_1}) + (\Delta W_{lora_2}x + b_{lora_2}) = (W + \Delta W_{lora_1} + \Delta W_{loha_1} + \Delta W_{lora_2})x + (b + b_{lora_1} + b_{loha_1} + b_{lora_2})$

Also, we can see that webui takes a similar approach: it just accumulates the total $\Delta W$ and $b$ over all the requested adapters.
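In code, that accumulation could look roughly like this (a simplified sketch, not webui's actual implementation; get_delta_weight()/get_delta_bias() are assumed accessors on the adapter objects, and per-adapter weights are optional scaling factors):

import torch
import torch.nn as nn

def merge_adapters_into(base: nn.Linear, adapters, weights=None):
    """Sketch: accumulate the total delta W (and delta b) of several adapters into a base layer."""
    weights = weights if weights is not None else [1.0] * len(adapters)
    with torch.no_grad():
        for adapter, w in zip(adapters, weights):
            base.weight += w * adapter.get_delta_weight()        # total delta W
            if base.bias is not None and getattr(adapter, "get_delta_bias", None):
                base.bias += w * adapter.get_delta_bias()        # total delta b, if the adapter has one
    return base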

Maybe there are other ideas for how to achieve this; I'm very open to suggestions.

Your approach leads me to an idea: why don't we try to mimic this behavior in the existing code base? The key thing that needs to be done is to separate the actual adapter delta modules from the layers that we are trying to modify.

We might transition to something like this (pseudocode):


from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

class DeltaProviderProtocol(nn.Module):
    """Common interface for anything that contributes a weight delta to a base layer."""

    def __init__(self):
        ...

    def get_delta_weight(self):
        ...

    def forward(self, x):
        ...

class LinearLoraDeltaProvider(DeltaProviderProtocol):
    def __init__(self):
        self.lora_A = nn.Linear(...)
        self.lora_B = nn.Linear(...)
        ...

    def get_delta_weight(self):
        return self.lora_B.weight @ self.lora_A.weight

    def forward(self, x):
        return self.lora_B(self.lora_A(x))

class LinearLohaDeltaProvider(DeltaProviderProtocol):
    def __init__(self):
        self.hada_w1 = nn.Parameter(...)
        self.hada_w2 = nn.Parameter(...)
        ...

    def get_delta_weight(self):
        # Hadamard product of the factors
        return self.hada_w1 * self.hada_w2

    def forward(self, x):
        return F.linear(x, self.get_delta_weight())

class LinearLokrDeltaProvider(DeltaProviderProtocol):
    def __init__(self):
        self.lokr_w1 = nn.Parameter(...)
        self.lokr_w2 = nn.Parameter(...)
        ...

    def get_delta_weight(self):
        # Kronecker product of the factors
        return torch.kron(self.lokr_w1, self.lokr_w2)

    def forward(self, x):
        return F.linear(x, self.get_delta_weight())

class LinearAdapterLayer(nn.Linear):
    def __init__(self, ...):
        # Delta providers of any type, keyed by adapter name
        self.adapters: Dict[str, DeltaProviderProtocol] = nn.ModuleDict({})
        ...

    def merge(self):
        for adapter in self.active_adapters:
            self.weight.data += self.adapters[adapter].get_delta_weight()

    def forward(self, x):
        result = F.linear(...)  # base layer output
        for adapter in self.active_adapters:
            result += self.adapters[adapter](x)
        return result

So, instead of replacing base model layers with layers that support only a single adapter type (peft.tuners.lora.Linear, peft.tuners.lora.Conv2d, peft.tuners.loha.Linear, peft.tuners.loha.Conv2d, etc.), we could replace them with an abstraction that can incorporate any delta layer implementing the DeltaProviderProtocol. That should allow us to easily switch between different adapters on the fly, as sketched below. Of course, we might end up with something like a DeltaModel, which would handle enabling/disabling, adding/removing, and storing adapter layers independently of the base model (which might also make it easy to swap out the base model if needed).
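Continuing in the same pseudocode style (everything here is hypothetical and builds on the classes sketched above), switching and mixing would then look roughly like this:

# Hypothetical usage of the sketched API above
layer = LinearAdapterLayer(...)                        # replaces the original nn.Linear
layer.adapters["style_lora"] = LinearLoraDeltaProvider()
layer.adapters["style_loha"] = LinearLohaDeltaProvider()

layer.active_adapters = ["style_lora"]                 # LoRA only
y1 = layer(x)

layer.active_adapters = ["style_lora", "style_loha"]   # mix LoRA + LoHa, nothing is unloaded
y2 = layer(x)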

In my opinion, this can be done without disturbing the existing code base (the current adapters could exist alongside this new implementation, which would just use those DeltaProviderProtocol subclasses). The only downside I see is that we would need to remap existing checkpoint keys to be able to load them into these new adapter layers.

In terms of enabling/mixing different adapters, this delta-provider design would also make things pretty easy. To get a mixture, we could just sum up the $\Delta W$ of several adapters and store the resulting parameters in a new full-diff delta layer. We could also apply an SVD decomposition to convert that full-diff delta layer back into something like a LoRA delta layer. The only downside I see with this approach is that storing full-diff delta layers could be quite expensive in terms of GPU memory.
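A rough sketch of that mixing-plus-SVD step (assuming providers that expose get_delta_weight() as in the pseudocode above; the function name and rank are arbitrary):

import torch

def mix_to_lora(providers, rank=16):
    """Sum several adapter deltas and compress the result into low-rank LoRA-style factors."""
    delta = sum(p.get_delta_weight() for p in providers)     # full-diff delta, shape (out, in)
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    # Keep the top-`rank` singular components so that delta ≈ lora_B @ lora_A
    lora_B = U[:, :rank] * S[:rank]    # (out, rank)
    lora_A = Vh[:rank, :]              # (rank, in)
    return lora_A, lora_B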

So, in general, this approach would work pretty much the same way as your prototype, but it would allow us to reuse the existing code base and stay flexible. What do you think?

BenjaminBossan commented 10 months ago

I am wondering how it compares to the current PEFT approach in terms of performance and reliability.

I haven't tested that, since it was more of a proof of concept for me, but in general, I don't think that the performance characteristics should be different, as the same amount of computation is being carried out during forward/backward. It might be faster when initializing and switching adapters, but I'm not sure how noticeable that is.

I thought that attaching hooks to torch modules was mainly intended for debugging/profiling purposes

I think at one point the PyTorch docs said so, but not anymore, so I think hooks are a good way of implementing this type of feature. (Note that register_module_forward_hook still has an explicit warning that it should only be used for profiling/debugging, but that shouldn't concern us.)

Do you know by chance whether mixing different adapters is a common thing for LLMs?

I'm not aware of this being a common pattern, but the field is moving so fast, so who knows what's true tomorrow.

Hmmm, if I understand it correctly, if we restrict a mixture to the subset of LoRA, LoHa, and LoKr, these adapters are fully commutative (when applied at the same time)

Yes, but I'd be hesitant to "lock" the feature in a way that it only works with commutative modifications.

Another solution that I discussed with @pacman100 is that we could refactor LoRA Linear etc. to use the "base_layer pattern", similar to the change in #994. This would allow us to nest multiple LoRA/LoHa/LoKr layers and they would be applied correctly. Right now, this wouldn't work because the adapter layer at the highest level would just call:

https://github.com/huggingface/peft/blob/56556faa17263be8ef1802c172141705b71c28dc/src/peft/tuners/lora/layer.py#L300

which ignores the forward calls of all the lower nested layers. I think that approach has a chance of being the easiest to implement, but we need to try it out in practice before we can be really sure. It would also require some extra work to match the state dicts in that case.
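To illustrate the base_layer idea (a simplified sketch, not the actual refactoring from #994): if every adapter layer wraps its parent and delegates to self.base_layer(x) instead of shortcutting to F.linear on the root weight, then nesting different adapter types composes naturally:

import torch
import torch.nn as nn

class SketchAdapterLayer(nn.Module):
    """Hypothetical wrapper: each adapter delegates to the layer it wraps."""

    def __init__(self, base_layer: nn.Module, delta: nn.Module):
        super().__init__()
        self.base_layer = base_layer   # may itself be another adapter layer
        self.delta = delta             # e.g. a LoRA- or LoHa-style delta module

    def forward(self, x):
        # Delegating keeps every nested adapter in the chain active,
        # instead of ignoring the lower layers' forward calls.
        return self.base_layer(x) + self.delta(x)

linear = nn.Linear(16, 16)
lora_delta = nn.Sequential(nn.Linear(16, 4, bias=False), nn.Linear(4, 16, bias=False))
loha_delta = nn.Linear(16, 16, bias=False)   # stand-in for a LoHa delta

wrapped = SketchAdapterLayer(SketchAdapterLayer(linear, lora_delta), loha_delta)
y = wrapped(torch.randn(2, 16))   # base output + LoRA delta + LoHa delta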

If, instead, we went ahead with the idea of adding a completely new adapter type that allows combining multiple types of adapters, I think I'd like to try working with the hooks approach, which is a bit more flexible than requiring the adapter to provide a delta_weight. For instance, we could add IA³ to the mix, which otherwise wouldn't work.

Note, however, that even with hooks, not all cases are covered. If we need to make a change in the middle of forward, instead of at the beginning or end, hooks cannot cover that. Similarly, if we want to avoid calling forward of the adapted layer entirely, hooks unfortunately don't support that either. The latter is necessary, for instance, for modules_to_save, where we want to call forward on a copy of the original weights (the way I "solved" this in my POC is that forward is called twice and the result using the original weights is discarded, which of course is wasteful).

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

BenjaminBossan commented 9 months ago

Note: This will be implemented via #1069 or a spin-off of that PR.