microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

Can't speed up model when pruning mT5 model #5333

Open Kathrine94 opened 1 year ago

Kathrine94 commented 1 year ago

Describe the issue: I use TaylorFOWeightPruner to prune the mT5_base model, but errors occur when speeding up the model:

    pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_prunersteps)
    _, ffn_masks = pruner.compress()
    renamed_ffn_masks = {}
    # rename the mask keys, because we only speed up the bert.encoder
    for model_name, targets_mask in ffn_masks.items():
        renamed_ffn_masks[model_name] = targets_mask
    pruner._unwrap_model()
    attention_pruned_model.load_state_dict(check_point)
    m_Speedup = ModelSpeedup(attention_pruned_model, (a.to(device), b.to(device), c.to(device), d.to(device)), renamed_ffn_masks)
    m_Speedup.speedup_model()
    optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)
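
As an aside, the loop above copies every mask entry without changing its key. If only a submodule (for example an encoder) were passed to ModelSpeedup, the keys would typically need remapping so they match that submodule's named_modules(). A minimal sketch of such a remapping, using a hypothetical strip_prefix helper and an assumed 'encoder.' prefix (neither is taken from the script):

    def strip_prefix(masks, prefix='encoder.'):
        """Drop a leading prefix from mask keys so they line up with the submodule's own module names."""
        return {
            (name[len(prefix):] if name.startswith(prefix) else name): target_masks
            for name, target_masks in masks.items()
        }

    # e.g. renamed_ffn_masks = strip_prefix(ffn_masks)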

Log error:

    2023-02-03 11:29:50 - INFO: start to speedup the model
    2023-02-03 11:29:54 - INFO: {}
    2023-02-03 11:29:54 - WARNING: no multi-dimension masks found.
    2023-02-03 11:29:54 - INFO: Dectected conv prune dim" 0
    2023-02-03 11:29:55 - INFO: infer module masks...
    2023-02-03 11:29:55 - INFO: Update mask for encoder.aten::size.519
    2023-02-03 11:29:55 - INFO: Update mask for encoder.aten::slice.521
    2023-02-03 11:29:55 - INFO: Slice dim:0, Slice obj:slice(0, 9223372036854775807, 1)
    2023-02-03 11:29:55 - INFO: Model has Slice operation, and the operand size=torch.Size([8, 10]), Slice object:(slice(0, 9223372036854775807, 1),)
    2023-02-03 11:29:55 - INFO: Model has Slice operation, and the operand size=torch.Size([8, 10]), Slice object:(slice(0, 9223372036854775807, 1),)
    2023-02-03 11:29:55 - INFO: Update mask for decoder.aten::size.808
    2023-02-03 11:29:55 - INFO: Update mask for decoder.aten::size.809
    2023-02-03 11:29:55 - INFO: Update mask for decoder.aten::slice.825
    2023-02-03 11:29:55 - INFO: Slice dim:0, Slice obj:slice(0, 9223372036854775807, 1)
    2023-02-03 11:29:55 - INFO: Model has Slice operation, and the operand size=torch.Size([8, 10]), Slice object:(slice(0, 9223372036854775807, 1),)
    2023-02-03 11:29:55 - INFO: Model has Slice operation, and the operand size=torch.Size([8, 10]), Slice object:(slice(0, 9223372036854775807, 1),)
    2023-02-03 11:29:55 - INFO: Update mask for decoder.aten::slice.833
    2023-02-03 11:29:55 - INFO: Slice dim:0, Slice obj:slice(0, 9223372036854775807, 1)
    2023-02-03 11:29:55 - INFO: Model has Slice operation, and the operand size=torch.Size([8, 10]), Slice object:(slice(0, 9223372036854775807, 1),)
    2023-02-03 11:29:55 - INFO: Model has Slice operation, and the operand size=torch.Size([8, 10]), Slice object:(slice(0, 9223372036854775807, 1),)
    2023-02-03 11:29:55 - INFO: Update mask for encoder.aten::view.520
    2023-02-03 11:29:55 - WARNING: throw some args away when calling the function "view"
    2023-02-03 11:29:55 - WARNING: throw some args away when calling the function "view"
    2023-02-03 11:29:55 - INFO: Update mask for encoder.aten::unsqueeze.522
    2023-02-03 11:29:55 - INFO: Update mask for decoder.aten::view.810
    2023-02-03 11:29:55 - WARNING: throw some args away when calling the function "view"
    2023-02-03 11:29:55 - WARNING: throw some args away when calling the function "view"
    2023-02-03 11:29:55 - INFO: Update mask for decoder.aten::arange.811

    Traceback (most recent call last):
      File "attention_pruned_model.py", line 446, in <module>
        main(config)
      File "attention_pruned_model.py", line 387, in main
        m_Speedup.speedup_model()
      File "/home/dhnan/anaconda3/envs/nemo/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py", line 546, in speedup_model
        self.infer_modules_masks()
      File "/home/dhnan/anaconda3/envs/nemo/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py", line 383, in infer_modules_masks
        self.update_direct_sparsity(curnode)
      File "/home/dhnan/anaconda3/envs/nemo/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py", line 237, in update_direct_sparsity
        _auto_infer = AutoMaskInference(
      File "/home/dhnan/anaconda3/envs/nemo/lib/python3.8/site-packages/nni/compression/pytorch/speedup/infer_mask.py", line 80, in __init__
        self.output = self.module(*dummy_input)
      File "/home/dhnan/anaconda3/envs/nemo/lib/python3.8/site-packages/nni/compression/pytorch/speedup/jit_translate.py", line 245, in __call__
        result = self.func(*self.positional, **self.keyword)
    TypeError: arange() received an invalid combination of arguments - got (Tensor, pin_memory=bool, device=torch.device, layout=NoneType, dtype=NoneType), but expected one of:
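
The call that fails is torch.arange being handed a traced Tensor argument where a plain Python number is expected. A minimal, standalone illustration of the same TypeError shape (purely illustrative; the value 10 and the variable name are made up):

    import torch

    end = torch.tensor(10)   # inside the traced graph, arange's "end" argument arrives as a Tensor
    # On the PyTorch builds where this issue appears, the next line raises:
    #   TypeError: arange() received an invalid combination of arguments - got (Tensor, ...)
    # torch.arange(end)
    torch.arange(int(end))   # converting to a Python number gives arange an argument form it accepts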

Could you please tell me how to solve it? Thanks very much!

Environment:

Configuration:

Log message:

How to reproduce it?:

super-dainiu commented 1 year ago

Hello, would you please provide the details of your model?

Kathrine94 commented 1 year ago

> Hello, would you please provide the details of your model?

This is my script, and I have modified the transformer code so that the output of mT5 is a tuple. The error occurred at line 389. script.txt
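
For reference, forcing a tuple output is often done with a thin wrapper module rather than by editing the transformer source. A rough sketch under that assumption; the wrapper class name and the four inputs (mirroring the four dummy inputs in the snippet above) are illustrative, not taken from script.txt:

    import torch

    class TupleOutputMT5(torch.nn.Module):
        """Hypothetical wrapper: make the mT5 forward return a plain tuple, as tracing-based speedup expects."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
            # return_dict=False asks Hugging Face models to return tuples instead of a ModelOutput object
            return self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                return_dict=False,
            )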

Kathrine94 commented 1 year ago

@super-dainiu

super-dainiu commented 1 year ago

I have been working on this issue recently, but I encountered another bug during ModelSpeedup(). I will get back to you as soon as possible!