horseee / LLM-Pruner

[NeurIPS 2023] LLM-Pruner: On the Structural Pruning of Large Language Models. Support LLaMA, Llama-2, BLOOM, Vicuna, Baichuan, etc.
https://arxiv.org/abs/2305.11627
Apache License 2.0

Checking the pruned but uncompressed model #15

Open ZN1010 opened 11 months ago

ZN1010 commented 11 months ago

Hi,

Thanks a lot for this awesome work! I am wondering whether there is a way to check the pruned but uncompressed model. Currently, when I save the model, it is already compressed, so I assume the pruned weights are discarded. Is there any chance I can locate those pruned weights?

Thanks!

horseee commented 11 months ago

Does that mean that you want to use those weights that have already been pruned?

ZN1010 commented 11 months ago

Does that mean that you want to use those weights that have already been pruned?

Yes. I just want to locate those pruned weights in the original LLaMA architecture (e.g., as a mask).

horseee commented 11 months ago

You can get the position of pruned parameters by replacing Line 150 in hf_prune.py with:

for group in pruner.step(interactive=True):
    print(group.details())
    group.prune()

And the locations of the pruned weights will be displayed:

...
[7] prune_out_channels on _ElementWiseOp_713(MmBackward0) => prune_in_channels on model.layers.12.mlp.down_proj (Linear(in_features=11008, out_features=4096, bias=False)), idxs=[7792, 1264, 6723, 7242, 141, 4435, 7357, 1696, 7402, 5814, 2287, 9878, 2516, 3803, 8753, 9469, 6714, 241, 8318, 7744, 2558, 10709, 497, 9613, 2823, 8817, 10596, 7971, 2692, 6298, 4053, 7371, 22, 4769, 9138, 985, 9628, 10510, 2942, 3412, 6868, 965, 4917, 9369, 4074, 1758, 6263, 5355, 7913, 3102, 29, 8884, 3552, 10449, 6680, 10959, 3522, 3313, 3828, 5642, 5392, 3206, 6321, 8882, 1471, 6638, 1219, 849, 5302, 115, 694, 7469, 6854, 7862, 6358, 1509, 5547, 8821, 7947, 6636, 2420, 441, 8653, 224, 9631, 555, 8006, 5748, 4889, 10072, 2880, 8389, 5753, 7512, 1428, 2075, 4919, 9150, 1115, 4843, 3816, 1194, 8885, 162, 9232, 10681, 6278, 9622, 9483, 8355, 10858, 8997, 8488, 4091, 10049, 2655, 3284, 2591, 5403, 6477, 6022, 1015, 397, 5196, 10029, 1993, 8100, 9032, 2386, 9427, 4194, 6253, 4692, 396, 8426, 5419, 9151, 2114, 2836, 7287, 5565, 519, 8359, 10383, 808, 2265, 4669, 854, 6569, 98, 1221, 9818, 8652, 870, 5763, 4728, 4411, 5883, 610, 9998, 9190, 7839, 10434, 569, 2435, 8388, 9152, 3631, 6070, 5243, 9599, 7760, 4188, 3839, 5282, 6232, 9321, 10821, 10082, 605, 9563, 5442, 6578, 10154, 4123, 5, 6388, 5034, 6639, 8512, 9714, 5493, 1, 2103, 8054, 5703, 9969, 1348, 2427, 3020, 9988, 10064, 6866, 9586, 9861, 10254, 6843, 9827, 972, 206, 6178, 9288, 2387, 1073, 453, 4580, 8589, 10594, 408, 5100, 10023, 9240, 10385, 7373, 10759, 9573, 2615, 544, 4775, 7813, 4243, 5770, 7016, 8537, 1458, 10318, 9282, 8242, 6089, 5227, 6963, 10931, 892, 8193, 7025, 9277, 6266, 5070, 9019, 6007, 8713, 1707, 6845, 7337, 10672, 4588, 7315, 4288, 10158,...
...
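
If you would rather collect these indices programmatically than parse the printed text, something along the following lines may work. This is only a sketch: it assumes a group can be iterated as (dependency, idxs) pairs and that each dependency exposes its target module and pruning function; the exact attribute names may differ across torch-pruning versions.

import torch.nn as nn

pruned_idxs = {}  # e.g. {"model.layers.12.mlp.down_proj (in)": [7792, 1264, ...]}
module_names = {m: n for n, m in model.named_modules()}  # `model` and `pruner` as set up in hf_prune.py

for group in pruner.step(interactive=True):
    for dep, idxs in group:                        # assumed iteration over (dependency, indices)
        layer = dep.target.module                  # module touched by this dependency (assumed attribute)
        if not isinstance(layer, nn.Linear):       # skip weight-free ops such as _ElementWiseOp_*
            continue
        fn_name = getattr(dep.handler, "__name__", str(dep.handler))
        direction = "in" if "in_channels" in fn_name else "out"
        pruned_idxs.setdefault(f"{module_names[layer]} ({direction})", []).extend(idxs)
    group.prune()
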
ZN1010 commented 11 months ago

You can get the position of pruned parameters by replacing Line 150 in hf_prune.py with:

for group in pruner.step(interactive=True):
    print(group.details())
    group.prune()

And the locations of the pruned weights will be displayed:

...
[7] prune_out_channels on _ElementWiseOp_713(MmBackward0) => prune_in_channels on model.layers.12.mlp.down_proj (Linear(in_features=11008, out_features=4096, bias=False)), idxs=[7792, 1264, 6723, 7242, 141, ...]
...

That is very helpful! I think I just need some help interpreting the printed details. For example, in "[7]", does it mean that the rows with indices such as 7792 and 1264 of the MLP down-projection matrix in layer 12 are pruned? And row index 7792 refers to the 7793rd row of that matrix, right?

Thanks a lot!

horseee commented 11 months ago

Not exactly. There's just a minor detail that needs to be corrected.

Let's take this example: the down_proj Linear layer has in_features=11008 and out_features=4096, which, in PyTorch, creates a weight matrix with shape 4096x11008 (i.e., out_features x in_features, the reverse of what you might expect). Now, if we prune the in_channels, it means we prune the columns with indices 7792 (the 7793rd column) and 1264 (the 1265th column).

To better understand this process, you can refer to the code. In the log, you will find => prune_in_channels on xxx, which leads you to the function prune_in_channels here, where the mentioned idxs are used to cut the tensor. Using the example above, the pruning code for the weight matrix is:

layer.weight = torch.nn.Parameter(layer.weight.data[:, keep_idxs])

which prunes along the column dimension (it keeps only the columns listed in keep_idxs).
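
For reference, here is a small, self-contained illustration of the same point (the shape and the two indices are taken from the example above; this is just a sketch, not the repository's code):

import torch
import torch.nn as nn

down_proj = nn.Linear(11008, 4096, bias=False)
print(down_proj.weight.shape)                  # torch.Size([4096, 11008]), i.e. out_features x in_features

pruned = [7792, 1264]                          # indices reported under prune_in_channels
keep_idxs = [i for i in range(down_proj.in_features) if i not in pruned]

# prune_in_channels keeps only the surviving columns of the weight matrix:
down_proj.weight = nn.Parameter(down_proj.weight.data[:, keep_idxs])
down_proj.in_features = len(keep_idxs)
print(down_proj.weight.shape)                  # torch.Size([4096, 11006])

# Equivalently, as a boolean mask over the original, uncompressed weight:
mask = torch.ones(4096, 11008, dtype=torch.bool)
mask[:, pruned] = False                        # False marks the pruned columns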

ZN1010 commented 11 months ago

Not exactly. There's just a minor detail that needs to be corrected.

Let's take this example: the down_proj Linear layer has in_features=11008 and out_features=4096, which, in PyTorch, creates a weight matrix with shape 4096x11008 (i.e., out_features x in_features, the reverse of what you might expect). Now, if we prune the in_channels, it means we prune the columns with indices 7792 (the 7793rd column) and 1264 (the 1265th column).

To better understand this process, you can refer to the code. In the log, you will find => prune_in_channels on xxx, which leads you to the function prune_in_channels here, where the mentioned idxs are used to cut the tensor. Using the example above, the pruning code for the weight matrix is:

layer.weight = torch.nn.Parameter(layer.weight.data[:, keep_idxs])

which prunes along the column dimension (it keeps only the columns listed in keep_idxs).

I guess I only need to care about things after "=>", right?

I tried to print out the pruned weights and noticed some "groups" like this: "prune_out_channels on _ElementWiseOp_976(SiluBackward0) => prune_out_channels on _ElementWiseOp_975(MulBackward0), idxs=[10068, 2189, 10061, ..]". What is "_ElementWiseOp_975" in LLaMA?

horseee commented 11 months ago

I guess I only need to care about things after "=>", right?

That's correct. The left side of '=>' serves as the trigger, and the pruning process only affects the right side of the '=>'. The right side, after pruning, will also act as the trigger for another operation and you can observe this in the next line of the log.

I tried to print out the pruned weights and noticed some "groups" like this: "prune_out_channels on _ElementWiseOp_976(SiluBackward0) => prune_out_channels on _ElementWiseOp_975(MulBackward0), idxs=[10068, 2189, 10061, ..]". What is "_ElementWiseOp_975" in LLaMA?

The ElementWiseOps refer to operations that have no weights involved, such as tensor multiplication (torch.matmul); we do not need to consider pruning for these non-parameter operations. They are only used to build the dependency graph of the network. Therefore, you only need to focus on the entries that prune the model's weights (e.g., prune_out_channels / prune_in_channels on a Linear layer).
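
As a concrete picture of where those ops come from, here is a toy version of the LLaMA MLP (the sizes are made up for the illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, intermediate = 8, 16                    # toy sizes, not the real LLaMA dimensions
gate_proj = nn.Linear(hidden, intermediate, bias=False)
up_proj   = nn.Linear(hidden, intermediate, bias=False)
down_proj = nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(2, hidden)
h = F.silu(gate_proj(x)) * up_proj(x)           # SiluBackward0 / MulBackward0: element-wise, no weights
y = down_proj(h)

# Removing intermediate channel i therefore removes row i of gate_proj.weight,
# row i of up_proj.weight, and column i of down_proj.weight. The _ElementWiseOp_*
# nodes in the log are these weight-free operations; they only propagate the same
# idxs between the layers that actually have weights.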

ZN1010 commented 11 months ago

I tried to parse the string generated by group.details(). As a sanity check, I calculated the total number of pruned weights from group.details() and double-checked it against the actual number of pruned weights. Here is what I did:

# all_numbers, split_str, and extract_lists_from_string come from my own parsing of
# the group.details() string; id_list is the list of pruned channel indices.
layer_num, in_f, out_f = all_numbers[0], all_numbers[1], all_numbers[2]
id_list = extract_lists_from_string(split_str[1])[0]

if "prune_out_channels" in split_str[0]:
    total_pruned += len(id_list) * in_f      # each pruned row removes in_f weights
elif "prune_in_channels" in split_str[0]:
    total_pruned += len(id_list) * out_f     # each pruned column removes out_f weights

I extracted the layer number, in_features, and out_features into all_numbers, and id_list refers to the index list. total_pruned just keeps track of the total number of pruned weights. I got a significantly lower number than the actual number of pruned weights (only ~30% of the actual number) under 50% sparsity.

I am wondering if there is anything I missed or did incorrectly. If I understand correctly, rows are pruned for "prune_out_channels" and columns are pruned for "prune_in_channels", and I only count an entry if it involves an mlp or self_attn module.

horseee commented 11 months ago

Hi. I'm not sure how you set the 50% sparsity. If you set the pruning ratio to 50% in the command, there are several factors that cause fewer than 50% of the parameters to actually be pruned; for example, only the transformer blocks between --block_*_layer_start and --block_*_layer_end are pruned, and modules outside those blocks (such as the embeddings and the lm_head) are left untouched.

Here is an example that prunes ~50% of the parameters of LLaMA-7B:

python hf_prune.py --pruning_ratio 0.6 --device cpu  --eval_device cuda --block_wise --block_mlp_layer_start 3 --block_mlp_layer_end 31 --block_attention_layer_start 3 --block_attention_layer_end 31 --save_ckpt_log_name llama_7B_0.5 --pruner_type taylor --test_after_train --save_model
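
As a side note on the sanity check itself: instead of parsing group.details(), you can also cross-check the total by counting parameters before and after pruning, which needs no assumptions about the log format:

before = sum(p.numel() for p in model.parameters())

for group in pruner.step(interactive=True):     # the same pruning loop as above
    group.prune()

after = sum(p.numel() for p in model.parameters())
print(f"pruned {before - after} parameters ({1 - after / before:.1%} of the original)")
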
liuxiaozhu01 commented 2 months ago

You can get the position of pruned parameters by replacing Line 150 in hf_prune.py with:

for group in pruner.step(interactive=True):
    print(group.details())
    group.prune()

And the locations of the pruned weights will be displayed:

...
[7] prune_out_channels on _ElementWiseOp_713(MmBackward0) => prune_in_channels on model.layers.12.mlp.down_proj (Linear(in_features=11008, out_features=4096, bias=False)), idxs=[7792, 1264, 6723, 7242, 141, ...]
...

@horseee Hi! Sorry for bothering you. I've obtained the pruning masks for the self_attn heads and the mlp intermediate dimensions this way, and I also want to check their importance. I've located the key part that calculates the importance here, but I have no idea how to associate the importance scores with the corresponding group (attention head or MLP intermediate dimension?). Can you help me with this? Thanks!