VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License

When I try to prune YOLOv9, the Concat channel indices in the groups built by DG.get_all_groups are incorrect #415

Open EzcodingSen opened 2 months ago

EzcodingSen commented 2 months ago

My pruner setup:

    example_inputs = torch.randn(1, 3, 640, 640).to(device)
    ignored_layers = []
    unwrapped_parameters = []
    importance = tp.importance.GroupNormImportance(p=2)
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs,
        importance,
        iterative_steps=1,
        pruning_ratio=0.35,
        ignored_layers=ignored_layers,
        unwrapped_parameters=unwrapped_parameters,
    )

After building the pruner, I tried to print each group and its importance, following the _prune function in MetaPruner:

    for group in pruner.DG.get_all_groups(ignored_layers=pruner.ignored_layers, root_module_types=pruner.root_module_types):
        if pruner._check_pruning_ratio(group):
            group = pruner._downstream_node_as_root_if_attention(group)
            ch_groups = pruner._get_channel_groups(group)
            print(group)
            imp = pruner.estimate_importance(group)
            print(imp)

Then it raised an error:

    Pruning Group
    [0] prune_out_channels on model.9.cv3.1.conv (Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => prune_out_channels on model.9.cv3.1.conv (Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), len(idxs)=256
    [1] prune_out_channels on model.9.cv3.1.conv (Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => prune_out_channels on model.9.cv3.1.bn (BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)), len(idxs)=256
    [2] prune_out_channels on model.9.cv3.1.bn (BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_21(SiluBackward0), len(idxs)=256
    [3] prune_out_channels on _ElementWiseOp_21(SiluBackward0) => prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]), len(idxs)=256
    [4] prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]) => prune_in_channels on model.9.cv4.conv (Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)), len(idxs)=256

    ........
    ........
    /opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [60,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
    /opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [61,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
    /opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [62,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
    /opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [63,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
    Traceback (most recent call last):
      File "/home/xs/yolov9/tp-prune.py", line 234, in <module>
        prune(model,save_path,device)
      File "/home/xs/yolov9/tp-prune.py", line 76, in prune
        imp = pruner.estimate_importance(group)
      File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 279, in estimate_importance
        return self.importance(group)
      File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
        return func(*args, **kwargs)
      File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 269, in __call__
        group_imp = self._reduce(group_imp, group_idxs)
      File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 149, in _reduce
        reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp)  # accumulated importance
    RuntimeError: CUDA error: device-side assert triggered
    CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1
    Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

In short, the error is an index mismatch. In this group:

    [4] prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]) => prune_in_channels on model.9.cv4.conv (Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)), len(idxs)=256

the Concat offsets are wrong: they should be [0, 256, 512, 768, 1024], but here they are [0, 1024, 2048, 2304, 2560].
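To make the failure mode concrete, here is a minimal, self-contained sketch (toy numbers, not YOLOv9 code) of why those offsets blow up: the last Concat branch is mapped to indices 2304..2559, which cannot address the 1024 input channels of model.9.cv4.conv, so the scatter_add_ in _reduce trips the out-of-bounds assert:

    import torch

    in_channels = 1024                      # cv4.conv only has 1024 input channels
    imp = torch.rand(256)                   # importance scores of one 256-channel branch
    reduced_imp = torch.zeros(in_channels)

    # Indices produced from the wrong offsets [0, 1024, 2048, 2304, 2560]
    bad_idxs = torch.arange(2304, 2560)
    try:
        reduced_imp.scatter_add_(0, bad_idxs, imp)   # out of bounds -> RuntimeError (CUDA assert on GPU)
    except RuntimeError as e:
        print(e)

    # Indices from the correct offsets [0, 256, 512, 768, 1024] stay in range
    good_idxs = torch.arange(768, 1024)
    reduced_imp.scatter_add_(0, good_idxs, imp)      # works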

So I went into GroupNormImportance and printed the indices:

    @torch.no_grad()
    def __call__(self, group: Group):
        group_imp = []
        group_idxs = []
        # Iterate over all groups and estimate group importance
        for i, (dep, idxs) in enumerate(group):
            layer = dep.layer
            prune_fn = dep.pruning_fn
            root_idxs = group[i].root_idxs
            if not isinstance(layer, tuple(self.target_types)):
                continue

            print(dep)
            print(layer)
            print(root_idxs)
            print(idxs)
            input()

Output:

    prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]) => prune_in_channels on model.9.cv4.conv (Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))
    Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    root_idxs: [0, 1, 2, ..., 255]            (256 consecutive indices)
    idxs:      [2304, 2305, 2306, ..., 2559]  (256 consecutive indices)

This confirms that the indices are indeed wrong at dependency-graph construction time.

How can I solve this problem?

EzcodingSen commented 2 months ago

Follow-up: the chunk op is skipped during dependency-graph construction, which is what causes the index problem.
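For reference, a small self-contained sketch (a hypothetical toy module, not YOLOv9) that mixes chunk with concat; it can be used to inspect what offsets the _ConcatOp node ends up recording in the dependency graph:

    import torch
    import torch.nn as nn
    import torch_pruning as tp

    # Hypothetical toy model: a conv output is chunked in half and the pieces are
    # concatenated back with the original tensor, loosely mimicking RepNCSPELAN4.
    class ChunkConcat(nn.Module):
        def __init__(self):
            super().__init__()
            self.cv1 = nn.Conv2d(3, 64, 1)
            self.cv2 = nn.Conv2d(32, 32, 3, padding=1)
            self.cv3 = nn.Conv2d(128, 64, 1)     # 64 + 32 + 32 concatenated channels

        def forward(self, x):
            y = self.cv1(x)
            a, b = y.chunk(2, dim=1)             # the op that the dependency graph skips
            return self.cv3(torch.cat([y, a, self.cv2(b)], dim=1))

    model = ChunkConcat()
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 32, 32))
    for group in DG.get_all_groups():
        print(group)  # check whether the _ConcatOp offsets match [0, 64, 96, 128]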

EzcodingSen commented 2 months ago

Follow-up: I replaced the chunk operations inside the modules with split. A new error appeared: when a branch produced by split is fed into multiple layers, the offset of the torch_pruning._helpers._SplitIndexMapping object is incorrect.
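For context, chunk and split return the same tensors when the channel count divides evenly; the difference is only in how the section sizes are declared, which is the information the split index mapping is built from. A small sketch with generic tensors (not the actual module code):

    import torch

    x = torch.randn(1, 64, 8, 8)

    # chunk: only the number of pieces is given; the sizes are inferred at runtime
    a, b = x.chunk(2, dim=1)

    # split: the section size (or an explicit list of sizes) is given, so the
    # per-section offsets are declared up front
    c, d = x.split(x.shape[1] // 2, dim=1)

    assert torch.equal(a, c) and torch.equal(b, d)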

EzcodingSen commented 2 months ago

Workaround: when a split output feeds multiple layers, any idx that falls beyond the available offsets is mapped to the last split section. In torch-pruning/dependency.py, change

    dep.index_mapping[0] = _helpers._SplitIndexMapping(
        offset=offsets[i: i + 2], reverse=False
    )

to:

    # If i exceeds the number of available offset sections, use the last section
    if i < num_offsets:
        dep.index_mapping[0] = _helpers._SplitIndexMapping(
            offset=offsets[i: i + 2], reverse=False
        )
    else:
        # Indices beyond the last section reuse the last section's offsets
        dep.index_mapping[0] = _helpers._SplitIndexMapping(
            offset=offsets[-2:], reverse=False
        )

After replacing chunk with split in the RepNCSPELAN4 and ADown modules, skipping cv1 and cv4 of RepNCSPELAN4, and skipping the layers associated with CBFuse (still to be optimized), all pruning algorithms supported by torch-pruning run through. My ability is limited and this is only a stopgap; if there is a better way to handle this, please advise.
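In case it helps others, a rough sketch of how the ignored_layers list described above could be assembled, reusing the pruner setup from the first comment; the class and attribute names (RepNCSPELAN4 with cv1/cv4, CBFuse) follow the description above and may need adjusting for a specific YOLOv9 variant:

    # Rough sketch: names follow the description above, adjust to your model
    ignored_layers = []
    for m in model.modules():
        if type(m).__name__ == 'RepNCSPELAN4':
            ignored_layers += [m.cv1, m.cv4]   # skip cv1 and cv4 of RepNCSPELAN4
        # layers associated with CBFuse would also be appended here (model-specific)

    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs,
        importance,
        iterative_steps=1,
        pruning_ratio=0.35,
        ignored_layers=ignored_layers,
        unwrapped_parameters=unwrapped_parameters,
    )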