Syencil / mobile-yolov5-pruning-distillation

Pruning and distillation for mobilev2-yolov5s, with support for ncnn and TensorRT deployment. Ultra-light but better performance!
MIT License
821 stars, 163 forks

L1 sparsity training details #33

Closed. xuanyuyt closed this 3 years ago

xuanyuyt commented 3 years ago
import torch

def compute_pruning_loss(p, prunable_modules, model, loss):
    """Add an L1 sparsity penalty on the BN scale factors (gamma) to `loss`."""
    ll1 = torch.zeros(1, device=p[0].device)  # accumulator on the same device as the predictions
    h = model.hyp  # hyperparameters
    if prunable_modules:  # skip when None or empty (avoids division by zero)
        for m in prunable_modules:
            ll1 += m.weight.norm(1)  # L1 norm of the BN layer's gamma values
        ll1 /= len(prunable_modules)  # average over the prunable layers
    ll1 *= h['sl']  # weight the penalty by the 'sl' hyperparameter
    bs = p[0].shape[0]  # batch size
    loss += ll1 * bs
    return loss

Could someone explain why the averaged norm in this function is multiplied by batch_size?
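
To make the question concrete, here is a plain-Python sketch of the arithmetic the function performs. The helper name `pruning_penalty`, the gamma values, the `sl` value, and the batch sizes are all made up for illustration and do not come from the repo:

```python
def pruning_penalty(gammas_per_layer, sl, batch_size):
    """Mirror the arithmetic of compute_pruning_loss on plain lists (illustrative only)."""
    # L1 norm of each BN layer's gamma vector
    l1_norms = [sum(abs(g) for g in gammas) for gammas in gammas_per_layer]
    # Average over the prunable layers
    mean_l1 = sum(l1_norms) / len(l1_norms)
    # Weight by the sparsity coefficient, then scale by batch size
    return mean_l1 * sl * batch_size


# Hypothetical gamma values for two BN layers
gammas = [[0.5, -1.0, 0.25], [2.0, -0.5]]
sl = 0.01  # hypothetical value for hyp['sl']

p8 = pruning_penalty(gammas, sl, batch_size=8)
p16 = pruning_penalty(gammas, sl, batch_size=16)
print(p8, p16)  # the penalty doubles when the batch size doubles
```

So the penalty itself does not depend on the data, yet it grows linearly with the batch size. One possible reading, not confirmed anywhere in this thread: upstream YOLOv5's `compute_loss` returns `loss * bs`, so if the `loss` passed in here is already scaled that way, multiplying the L1 term by `bs` would keep the two terms in the same ratio regardless of batch size.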