zeo233 closed this issue 5 months ago.
This is supported now; see #326 for details. There is one more thing to watch out for:
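- For transformer-style models such as the Whisper model you are using, most of the model definition is fixed from the config in `__init__`. Pruning changes some of the model's dimensions, so you need to update the definitions of those dimensions manually. For example, pruning the `WhisperAttention` you provided reduces `num_heads` and `embed_dim`, and both must be corrected by hand after pruning.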
```python
import torch
from tinynn.prune.oneshot_pruner import OneShotChannelPruner

# `model` and `args` come from the surrounding script
pruner = OneShotChannelPruner(model, torch.ones(1, 80, 3000), args.config)

# (Optional) A new config file with layer-level sparsity will be generated in place.
# If you want to customize the generated content, do that before calling `.prune`.
pruner.generate_config(args.config)

# Get the pruned, untrained model
pruner.prune()

# Manually update the predefined dimension variables of each attention block
for layer in model.audio_encoder.model.layers:
    attn = layer.self_attn
    attn.num_heads = attn.q_proj.out_features // attn.head_dim
    attn.embed_dim = attn.num_heads * attn.head_dim

# Sanity check: a forward pass with the dummy input should now succeed
model(torch.ones(1, 80, 3000))
exit()
```
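The per-layer update in that loop boils down to one piece of arithmetic. The helper below is a minimal, framework-free sketch of it; `pruned_attention_dims` is a hypothetical name, not part of tinynn, and it assumes that whole heads are removed (as the reply further down the thread describes), so the pruned `q_proj.out_features` must remain a multiple of `head_dim`.

```python
def pruned_attention_dims(q_out_features: int, head_dim: int) -> tuple[int, int]:
    """Recompute (num_heads, embed_dim) from a pruned q_proj width.

    Assumes head-level pruning: whole heads are removed, so the pruned
    q_proj.out_features must still be a multiple of head_dim.
    """
    if head_dim <= 0:
        raise ValueError("head_dim must be positive")
    if q_out_features % head_dim != 0:
        # A partial head would break the (bsz, seq, num_heads, head_dim) reshape
        # that the attention forward pass performs.
        raise ValueError(
            f"q_proj.out_features={q_out_features} is not a multiple of head_dim={head_dim}"
        )
    num_heads = q_out_features // head_dim
    embed_dim = num_heads * head_dim  # equals q_out_features once divisibility holds
    return num_heads, embed_dim


# Example: an attention block with 64-dim heads pruned from 8 heads down to 6.
print(pruned_attention_dims(384, 64))  # (6, 384)
```

Raising on a non-divisible width surfaces a bad pruning result at the point of the fix, rather than as a reshape error deep inside the forward pass.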
It runs now, thanks! When pruning a transformer, is it the attention heads that get pruned?
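Generally, three dimensions can be pruned: attention heads, the MLP intermediate size, and hidden_size. Currently the hidden_size dimension is masked out, so only the first two are pruned. You can print the model to see the reduction in the corresponding dimensions.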
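To make "the corresponding dimensions" concrete, here is a sketch of which `nn.Linear` shapes, as `(in_features, out_features)`, would change in one encoder layer when heads and the MLP intermediate size are pruned while hidden_size stays fixed. The layer names follow the Hugging Face Whisper layout, and both the original sizes (512 hidden, 8 heads of 64, 2048 intermediate) and the pruned sizes are illustrative assumptions, not output from tinynn.

```python
from dataclasses import dataclass


@dataclass
class LayerDims:
    hidden_size: int        # d_model; kept fixed (masked) by the pruner
    num_heads: int
    head_dim: int
    intermediate_size: int  # MLP width (fc1 out / fc2 in)


def linear_shapes(d: LayerDims) -> dict[str, tuple[int, int]]:
    """(in_features, out_features) of each Linear in one encoder layer.

    Names follow the Hugging Face Whisper layout (an assumption for illustration).
    """
    attn = d.num_heads * d.head_dim
    return {
        "q_proj": (d.hidden_size, attn),
        "k_proj": (d.hidden_size, attn),
        "v_proj": (d.hidden_size, attn),
        "out_proj": (attn, d.hidden_size),
        "fc1": (d.hidden_size, d.intermediate_size),
        "fc2": (d.intermediate_size, d.hidden_size),
    }


before = LayerDims(hidden_size=512, num_heads=8, head_dim=64, intermediate_size=2048)
# Hypothetical pruning result: 2 heads removed, half of the MLP channels removed.
after = LayerDims(hidden_size=512, num_heads=6, head_dim=64, intermediate_size=1024)

after_shapes = linear_shapes(after)
for name, shape in linear_shapes(before).items():
    print(f"{name:8s} {shape} -> {after_shapes[name]}")
```

Every layer keeps 512 on one side, which is what holding hidden_size fixed means in practice: only the attention width and the MLP width shrink.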
Got it, thanks a lot!
Help! I get an error when I replace the model in examples/pruner/oneshot/oneshot_prune.py with my own model. The code is here: oneshot_prune_wt.zip. The error is:

```
Running on cuda:0
INFO (tinynn.graph.modifier) Start tracking tensor dimension changes...
ERROR (tinynn.graph.modifier) error modifier = bmm_0_f, type = <class 'tinynn.graph.tracer.TraceFunction'>, kind = bmm
Traceback (most recent call last):
  File "oneshot_prune_wt.py", line 107, in <module>
    main_worker(args)
  File "oneshot_prune_wt.py", line 58, in main_worker
    pruner = OneShotChannelPruner(model, torch.ones(1, 80, 3000), args.config)
  File "oneshot_pruner.py", line 71, in __init__
    self.graph_modifier = modifier.GraphChannelModifier(self.graph, self.center_nodes, self.bn_compensation)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 3315, in __init__
    self.sub_graphs = SubGraphDivider(self.graph, self.modifiers).divide()
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 3280, in divide
    self.change_dimension()
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 3227, in change_dimension
    dim_changed = m.change_dimension()
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 1745, in change_dimension
    m.dim_change_forward(self, self.next_tensors()[0], dim_changes_o, None, tensor_constraint)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 867, in dim_change_forward
    m.dim_change_forward(center, tensor_o, dim_changes_i, dim_transform, tensor_constraint)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 1259, in dim_change_forward
    m.dim_change_forward(center, tensor_o, dim_change_o, dim_transform, None)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 1259, in dim_change_forward
    m.dim_change_forward(center, tensor_o, dim_change_o, dim_transform, None)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 867, in dim_change_forward
    m.dim_change_forward(center, tensor_o, dim_changes_i, dim_transform, tensor_constraint)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 1259, in dim_change_forward
    m.dim_change_forward(center, tensor_o, dim_change_o, dim_transform, None)
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 859, in dim_change_forward
    raise e
  File "/conda/envs/qt/lib/python3.10/site-packages/tinynn/graph/modifier.py", line 853, in dim_change_forward
    tensor_o.copy_(tensor.clone())
RuntimeError: The size of tensor a (1500) must match the size of tensor b (64) at non-singleton dimension 2
```
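The two sizes in this error can be traced back to the input shape. Assuming the model is still the Whisper-style encoder from earlier in the thread (two `Conv1d` layers with kernel size 3 and padding 1, the second with stride 2, and 64-dimensional heads), 3000 input frames become 1500 positions, and the attention `bmm` calls produce tensors whose dimension 2 is either 1500 (attention scores) or 64 (per-head output). One plausible reading, not a confirmed diagnosis, is that the dimension tracker tried to align two such tensors along dimension 2. The sketch below only does the shape arithmetic; `conv1d_out_len`, `bmm_shape`, and the head count of 8 are illustrative assumptions.

```python
def conv1d_out_len(length: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a 1-D convolution with dilation 1 (same formula as torch.nn.Conv1d)."""
    return (length + 2 * padding - kernel_size) // stride + 1


def bmm_shape(a: tuple, b: tuple) -> tuple:
    """Result shape of a batched matmul (B, n, m) @ (B, m, p) -> (B, n, p)."""
    if a[0] != b[0] or a[2] != b[1]:
        raise ValueError(f"incompatible bmm operands: {a} @ {b}")
    return (a[0], a[1], b[2])


# Assumed Whisper-style front end: conv1 (k=3, s=1, p=1), then conv2 (k=3, s=2, p=1).
frames = 3000                                    # from torch.ones(1, 80, 3000)
seq_len = conv1d_out_len(conv1d_out_len(frames, 3, 1, 1), 3, 2, 1)

head_dim = 64                                    # assumed; matches the 64 in the error
batch, num_heads = 1, 8                          # num_heads is arbitrary here
bh = batch * num_heads

q = (bh, seq_len, head_dim)
k_t = (bh, head_dim, seq_len)                    # keys, transposed
v = (bh, seq_len, head_dim)

scores = bmm_shape(q, k_t)                       # first bmm: q @ k^T
attn_out = bmm_shape(scores, v)                  # second bmm: softmax(scores) @ v

print(seq_len)              # 1500
print(scores, attn_out)     # (8, 1500, 1500) (8, 1500, 64)
```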