kennymckormick / pyskl

A toolbox for skeleton-based action recognition.
Apache License 2.0
940 stars, 178 forks

How to evaluate the parameters and FLOPs of this model? #165

Open azelee opened 1 year ago

azelee commented 1 year ago

Could you please tell me how to evaluate the parameters and FLOPs of this model? I don't know how to use the tools in the mmcv files. Many thanks!

adripriadana commented 1 year ago

Hello. Thank you very much for your framework. I have the same issue. I have tried many ways to calculate the parameters and FLOPs of the models in this code, but none of them worked. Could you please tell me how to do it? Thank you. Best regards

fangwei814 commented 9 months ago

This is a FLOPs calculation script based on MMEngine. Note that it differs in some respects from the FLOPs calculation code used in the paper.

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmengine import Config
from mmengine.registry import init_default_scope

# from mmaction.registry import MODELS
from pyskl.models import build_model

try:
    from mmengine.analysis import get_model_complexity_info
except ImportError:
    raise ImportError('Please upgrade mmengine; mmengine.analysis is required')

def parse_args():
    parser = argparse.ArgumentParser(description='Get model flops and params')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[17, 32, 64, 64],
        help='input shape')
    args = parser.parse_args()
    return args

def main():

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    elif len(args.shape) == 4:
        # n, c, h, w = args.shape for 2D recognizer
        input_shape = tuple(args.shape)
    elif len(args.shape) == 5:
        # n, c, t, h, w = args.shape for 3D recognizer or
        # n, m, t, v, c = args.shape for GCN-based recognizer
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    init_default_scope(cfg.get('default_scope', 'mmaction'))
    model = build_model(cfg.model)
    model.eval()
    model.cuda()

    if hasattr(model, 'extract_feat'):
        model.forward = model.extract_feat
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    analysis_results = get_model_complexity_info(model, input_shape)
    flops = analysis_results['flops_str']
    params = analysis_results['params_str']
    table = analysis_results['out_table']
    print(table)
    split_line = '=' * 30
    print(f'\n{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')

if __name__ == '__main__':
    main()
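
The warning printed by the script suggests verifying the reported numbers. One way to spot-check a single row of the table is to compute a convolution layer's parameters and multiply-accumulates (MACs) by hand. The following is a minimal sketch in plain Python; the helper names `conv3d_params` and `conv3d_macs` are hypothetical and not part of pyskl or MMEngine.

```python
from math import prod


def conv3d_params(c_in, c_out, kernel, bias=True, groups=1):
    """Number of learnable parameters in a 3D convolution layer."""
    weights = c_out * (c_in // groups) * prod(kernel)
    return weights + (c_out if bias else 0)


def conv3d_macs(c_in, c_out, kernel, out_shape, groups=1):
    """Multiply-accumulates for one sample.

    out_shape is the (T, H, W) size of the layer's output, so each of the
    c_out * T * H * W output values costs (c_in / groups) * kT * kH * kW MACs.
    """
    return c_out * prod(out_shape) * (c_in // groups) * prod(kernel)


# Hypothetical layer: 3 -> 4 channels, 3x3x3 kernel, output of size 2x8x8.
layer_params = conv3d_params(3, 4, (3, 3, 3), bias=False)
layer_macs = conv3d_macs(3, 4, (3, 3, 3), out_shape=(2, 8, 8))
print(f'params: {layer_params}, MACs: {layer_macs}')
```

Tools differ on whether one multiply-accumulate counts as one FLOP or two, which is one possible source of mismatch with numbers reported elsewhere.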

ahmed-nady commented 3 months ago

Thank you @fangwei814 for sharing this code. Could you tell me how to calculate the FLOPs for the rgbposeconv3d model, given that it takes two inputs?

ahmed-nady commented 2 months ago

I tried to get the FLOPs and parameter count for the RGBPose model. Here are the results:

Input shape: {'rgb': [1, 3, 8, 224, 224], 'pose': [1, 17, 32, 56, 56]}
Flops: 56.98 GFLOPs
Params: 36.15 M

I used the code from https://github.com/kennymckormick/pyskl, ran the get_flops script, and modified the flops_counter script as follows:

rgb_batch = torch.ones(()).new_empty(
    (1, *[1, 3, 8, 224, 224]),
    dtype=next(flops_model.parameters()).dtype,
    device=next(flops_model.parameters()).device)
pose_batch = torch.ones(()).new_empty(
    (1, *[1, 17, 32, 56, 56]),
    dtype=next(flops_model.parameters()).dtype,
    device=next(flops_model.parameters()).device)
....
_ = flops_model(rgb_batch, pose_batch)
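
The same idea of feeding two tensors to the model in one forward pass can be tried on a toy model before touching the real one. This is a minimal sketch using only PyTorch forward hooks; `TwoStreamToy` and `count_conv3d_macs` are hypothetical names, the toy network is not the RGBPose model, and only `Conv3d` layers are counted (other ops are ignored).

```python
import torch
import torch.nn as nn


class TwoStreamToy(nn.Module):
    """Hypothetical stand-in for a two-input (RGB + pose) 3D CNN."""

    def __init__(self):
        super().__init__()
        self.rgb_conv = nn.Conv3d(3, 4, kernel_size=3, padding=1, bias=False)
        self.pose_conv = nn.Conv3d(17, 4, kernel_size=3, padding=1, bias=False)

    def forward(self, rgb, pose):
        # Global-average-pool each stream, then fuse by addition.
        rgb_feat = self.rgb_conv(rgb).mean(dim=(2, 3, 4))
        pose_feat = self.pose_conv(pose).mean(dim=(2, 3, 4))
        return rgb_feat + pose_feat


def count_conv3d_macs(model, *inputs):
    """Sum the multiply-accumulates of all Conv3d layers in one forward pass."""
    total = 0

    def hook(module, _inputs, output):
        nonlocal total
        # weight[0] has shape (in_channels / groups, kT, kH, kW): the cost of
        # producing a single output value.
        total += output.numel() * module.weight[0].numel()

    handles = [
        m.register_forward_hook(hook)
        for m in model.modules() if isinstance(m, nn.Conv3d)
    ]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return total


model = TwoStreamToy().eval()
rgb = torch.zeros(1, 3, 2, 8, 8)      # small shapes chosen for illustration
pose = torch.zeros(1, 17, 4, 4, 4)
macs = count_conv3d_macs(model, rgb, pose)
params = sum(p.numel() for p in model.parameters())
print(f'Conv3d MACs: {macs}, params: {params}')
```

Because the hooks fire on whatever the forward pass actually executes, the number of inputs does not matter; a full counter would also need handlers for linear, normalization, and pooling layers.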

chenglong521 commented 2 months ago

Can I talk to you? I want to add parameter-count and FLOPs computation to the pyskl framework, but I don't know how to add the FLOPs part. Could you please tell me?