VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Yolov8 pruning with Taylor criteria #373

Open · minhhotboy9x opened this issue 5 months ago

minhhotboy9x commented 5 months ago

Hi, I'm trying to experiment with some criteria for YOLOv8 pruning, specifically the Taylor criterion. However, when I set global pruning to True, the pruner doesn't prune anything; it only works with local pruning. Can somebody help me with this? Here is my summarized code:

...
def calculate_grad(model, **args):
    # save temporary model.args
    tmp_args = model.args

    args = get_cfg(DEFAULT_CFG, args['args'])
    device = select_device(device=args.device, batch=0)
    model.args = args
    model.to(device)
    def criterion(preds, batch):
        """Compute loss for YOLO prediction and ground-truth."""
        compute_loss = Loss(de_parallel(model))
        return compute_loss(preds, batch)

    def preprocess_batch(batch, device):  # preprocess each image batch
        """Preprocesses a batch of images by scaling and converting to float."""
        batch['img'] = batch['img'].to(device, non_blocking=True).float() / 255
        return batch

    data = check_det_dataset(args.data)
    trainset = data['train']
    gs = max(int(de_parallel(model).stride.max() if model else 0), 32)
    train_loader = build_dataloader(args, args.batch, img_path=trainset, stride=gs, rank=RANK, mode='train',
                             rect=False, data_info=data)[0]
    nb = len(train_loader)
    print('--------Calculate grad start--------')
    pbar = tqdm(enumerate(train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
    scaler = torch.cuda.amp.GradScaler()
    for i, batch in pbar:
        with torch.cuda.amp.autocast():
            batch = preprocess_batch(batch, device)
            preds = model(batch['img'])
            loss, _ = criterion(preds, batch)  # compute loss
        scaler.scale(loss).backward()
    print('--------Calculate grad end---------')

    # return back model.args
    model.args = tmp_args
    model.to('cpu')
...
pruner = tp.pruner.MetaPruner(
    model.model,
    example_inputs,
    global_pruning=True,  # additional test
    importance=tp.importance.GroupTaylorImportance(),
    iterative_steps=1,
    ch_sparsity=ch_sparsity,
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)
calculate_grad(model.model, args=deepcopy(pruning_cfg))
pruner.step() 

Link full code: https://github.com/minhhotboy9x/ultralytics_YOLOv8_custom/blob/main/benchmarks/prunability/yolov8_pruning_taylor.py
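
For reference, below is a minimal self-contained sketch of the Taylor + global pruning workflow on a toy torchvision ResNet-18 (an illustrative stand-in, not the YOLOv8 model above), following the usage documented in the Torch-Pruning README. The model, sparsity value, and dummy loss are assumptions for illustration; the argument name ch_sparsity mirrors the snippet above. The key requirement is that gradients are present on the model's parameters before pruner.step(), since GroupTaylorImportance reads them to score groups.

# Minimal sketch: GroupTaylorImportance + global pruning on a toy model.
# ResNet-18, the 0.5 sparsity, and the dummy loss are illustrative assumptions.
import torch
import torchvision.models as models
import torch_pruning as tp

model = models.resnet18(weights=None)
example_inputs = torch.randn(1, 3, 224, 224)

imp = tp.importance.GroupTaylorImportance()
ignored_layers = [model.fc]  # keep the final classifier un-pruned

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    global_pruning=True,
    iterative_steps=1,
    ch_sparsity=0.5,              # same argument name as in the snippet above
    ignored_layers=ignored_layers,
)

# Taylor importance needs gradients: run a forward/backward pass first.
loss = model(example_inputs).sum()   # dummy loss, stands in for the YOLO loss
loss.backward()

base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"params: {base_params/1e6:.2f}M -> {params/1e6:.2f}M, "
      f"MACs: {base_macs/1e9:.2f}G -> {macs/1e9:.2f}G")

If this toy example prunes with global_pruning=True but the YOLOv8 setup does not, a useful check is whether the gradients computed inside calculate_grad still exist on model.model at the time pruner.step() is called (e.g. print a few p.grad values right before stepping).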