huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add support for pruning whole layers in transformer models. #17475

Closed: chrisdt1998 closed this issue 2 years ago

chrisdt1998 commented 2 years ago

Feature request

Dear HuggingFace team,

In the ViT model file (modeling_vit.py), there is an option to prune the attention heads of a model. However, at the moment, if I try to prune all of the heads in a layer, I get an error from the output dense layer: its number of input features becomes 0, which breaks the 1/sqrt(in_features) weight initialization. Would it be possible to do something similar to https://github.com/sai-prasanna/bert-experiments/blob/master/src/model_bert.py, where they simply check whether the number of heads to prune equals the number of heads in that layer and, if so, set the attention and dense layers to None?
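
For reference, here is a minimal sketch of how the failure shows up (the checkpoint name is only for illustration; any ViT checkpoint should reproduce it):

from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")  # example checkpoint
num_heads = model.config.num_attention_heads  # 12 for ViT-Base

# Pruning every head of layer 0 currently fails: for output.dense,
# prune_linear_layer ends up building an nn.Linear with in_features == 0,
# whose initialization divides by sqrt(in_features).
model.prune_heads({0: list(range(num_heads))})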

Motivation

The motivation for this is that I want my pruning algorithm to be able to prune whole layers if it thinks that this will give the best performance when compressing a model. I imagine that other researchers would appreciate this feature as well.

Your contribution

I am happy to adapt Sai Prasanna's approach and add it to the ViT model if you would like. Please let me know.

LysandreJik commented 2 years ago

WDYT @NielsRogge?

@chrisdt1998 do you have an example of how big of a change it would result in the code?

chrisdt1998 commented 2 years ago

Yes, the change would be about 10 lines of code added to the prune_heads method of the ViTAttention class in modeling_vit.py. The same change could be extended to other transformer models in the corresponding places; for example, in modeling_bert.py it would go in the prune_heads method of the BertAttention class.

The ViTAttention class would change from:

class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

To:

class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        if self.attention is None:
            # The attention block of this layer has already been fully pruned.
            return

        all_pruned = self.pruned_heads.union(heads)
        if len(all_pruned) == self.attention.num_attention_heads:
            # Every head of this layer is being pruned: drop the whole block
            # instead of creating linear layers with 0 input features.
            self.attention = None
            self.output.dense = None
            # Store pruned heads
            self.pruned_heads = all_pruned
            return

        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
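
For the pruned block to actually be skipped at inference time, ViTAttention.forward would also need a small guard. This is only a rough sketch of one possible way to handle it (it assumes the current forward signature in modeling_vit.py and returns zeros, so that the residual connection in ViTLayer passes hidden_states through unchanged):

    # inside ViTAttention; torch and the typing imports are already available in modeling_vit.py
    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        if self.attention is None:
            # The whole attention block was pruned: contribute nothing, so the
            # residual connection in ViTLayer leaves hidden_states unchanged.
            # (Callers that request output_attentions get no attention tensor
            # for this layer.)
            return (torch.zeros_like(hidden_states),)

        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

With both changes, pruning every head of a layer (for example model.prune_heads({0: list(range(model.config.num_attention_heads))})) should no longer raise and would simply drop that layer's attention block.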

Please note that all credit goes to Sai Prasanna (https://github.com/sai-prasanna/bert-experiments/blob/master/src/model_bert.py) and the corresponding paper "When BERT Plays the Lottery, All Tickets Are Winning".

Please let me know if you'd like me to clarify further.