VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Shape mismatch when pruning a ViT optimized with flash_attn #357

Open rikeLiu opened 3 months ago

rikeLiu commented 3 months ago

/usr/local/miniconda3/lib/python3.8/site-packages/torch_pruning/dependency.py:667: UserWarning: Unwrapped parameters detected: ['cls_token', 'pos_embed']. Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.
  warnings.warn(warning_str)
Traceback (most recent call last):
  File "/usr/local/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 523, in reduce
    return _apply_recipe(
  File "/usr/local/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 234, in _apply_recipe
    init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
  File "/usr/local/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 187, in _reconstruct_from_shape_uncached
    raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}")
einops.EinopsError: Shape mismatch, can't divide axis of length 2188 in chunks of 192

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "flash-attention-2.4.2/flash_attn/models/vit3.py", line 466
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch_pruning/utils/op_counter.py", line 35, in count_ops_and_params
    _ = flops_model(example_inputs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "flash-attention-2.4.2/flash_attn/models/vit3.py", line 323, in forward
    x = self.forward_features(x, all_tokens=False)
  File "flash-attention-2.4.2/flash_attn/models/vit3.py", line 284, in forward_features
    hidden_states, residual = block(hidden_states, residual)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/flash_attn/modules/block.py", line 178, in forward
    hidden_states = self.mixer(hidden_states, **mixer_kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/flash_attn/modules/mha.py", line 622, in forward
    qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 591, in rearrange
    return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/einops/einops.py", line 533, in reduce
    raise EinopsError(message + "\n {}".format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "... (three h d) -> ... three h d".
 Input tensor shape: torch.Size([1, 197, 2188]). Additional info: {'three': 3, 'd': 64}.
 Shape mismatch, can't divide axis of length 2188 in chunks of 192
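
For context: flash_attn's MHA fuses the query/key/value projections into a single Linear whose output width must stay divisible by 3 * num_heads * head_dim (3 * 64 = 192 per head here). After pruning, that width became 2188, which is not a multiple of 192, so the rearrange in flash_attn/modules/mha.py fails. A minimal sketch of one possible workaround is to keep the pruner away from the fused projection entirely via ignored_layers; the Wqkv attribute name, the head attribute, and the MagnitudePruner arguments below are assumptions based on flash_attn 2.4.2 and recent Torch-Pruning releases, so adjust them to the actual model in vit3.py:

```python
import torch
import torch_pruning as tp
from flash_attn.modules.mha import MHA

model = build_vit()                      # hypothetical: however vit3.py builds the model
example_inputs = torch.randn(1, 3, 224, 224)

# Leave the fused qkv projections (and the classifier head) unpruned so their
# output width stays divisible by 3 * num_heads * head_dim.
ignored_layers = [model.head]            # assumption: classifier layer is named `head`
for m in model.modules():
    if isinstance(m, MHA):
        ignored_layers.append(m.Wqkv)    # assumption: MHA stores the fused projection as `Wqkv`

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    pruning_ratio=0.5,                   # older Torch-Pruning versions call this ch_sparsity
    ignored_layers=ignored_layers,
)
pruner.step()

pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
```

If the attention widths themselves need to shrink, the remaining option is to prune whole heads so the fused dimension stays a multiple of 3 * head_dim; recent Torch-Pruning releases expose a num_heads argument on the pruner for this (see the repository's ViT examples), but I have not verified it against flash_attn's fused Wqkv layout.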