WDYT @NielsRogge?
@chrisdt1998 do you have an example of how big a change this would be in the code?
Yes, the change would be about 10 lines of code added to the prune_heads method of the ViTAttention class in modeling_vit.py. The same change could be extended to other transformer models in the corresponding functions; for example, in modeling_bert.py it would go in the prune_heads method of the BertAttention class.
The ViTAttention class would change from:
```python
class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
```
To:
```python
class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        if self.attention is None:
            return
        all_pruned = self.pruned_heads.union(heads)
        if len(all_pruned) == self.attention.num_attention_heads:
            # Every head in this layer is (or will be) pruned: drop the whole block
            self.attention = None
            self.output.dense = None
            # Update hyper params and store pruned heads
            self.pruned_heads = all_pruned
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
```
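For this to work end to end, ViTAttention.forward would also need to short-circuit once the block has been pruned away. Below is a minimal sketch of such a guard; it assumes the current forward signature of ViTAttention in modeling_vit.py, and returns zeros so that the residual connection in ViTLayer passes the hidden states through unchanged:

```python
from typing import Optional, Tuple, Union

import torch

def forward(
    self,
    hidden_states: torch.Tensor,
    head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    # Hypothetical guard: the whole attention block was pruned, so emit zeros;
    # ViTLayer adds this output to the residual, leaving hidden_states intact.
    if self.attention is None:
        return (torch.zeros_like(hidden_states),)

    self_outputs = self.attention(hidden_states, head_mask, output_attentions)
    attention_output = self.output(self_outputs[0], hidden_states)
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    return outputs
```

One caveat: with output_attentions=True, a fully pruned layer would return no attention tensor, so any downstream code that stacks per-layer attentions may need to handle that case.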
Please note that all credit goes to Sai Prasanna; the approach comes from https://github.com/sai-prasanna/bert-experiments/blob/master/src/model_bert.py and the corresponding paper "When BERT Plays the Lottery, All Tickets Are Winning".
Please let me know if you'd like me to clarify further.
Feature request
Dear HuggingFace team,
In the ViT model code (namely modeling_vit.py), there is an option to prune the attention heads of a model. However, at the moment, if I try to prune a whole layer (i.e. every head in it), I get an error from the dense layer: its in_features becomes 0, which breaks the 1/sqrt(in_features) term used during weight initialization. Would it be possible to do something similar to https://github.com/sai-prasanna/bert-experiments/blob/master/src/model_bert.py, where they simply check whether the number of heads to prune equals the number of heads in that layer and, if so, set the attention and dense layers to None?
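To make the failure concrete, here is a minimal reproduction sketch (the checkpoint name is just an example; any ViT checkpoint would do):

```python
from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224")

# Pruning every head of layer 0 currently errors out: the pruned dense layer
# ends up with in_features == 0, and weight initialization divides by
# sqrt(in_features).
model.prune_heads({0: list(range(model.config.num_attention_heads))})
```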
Motivation
The motivation for this is that I want my pruning algorithm to be able to prune whole layers when it determines that doing so gives the best performance when compressing a model. I imagine that other researchers would appreciate this feature as well.
Your contribution
I would be happy to take inspiration from Sai Prasanna's implementation and add it to the ViT model if you would like. Please let me know.