open-mmlab / mmaction2

OpenMMLab's Next Generation Video Understanding Toolbox and Benchmark
https://mmaction2.readthedocs.io
Apache License 2.0

[Feature] How to get flops of the MMRecognizer3D? #2794

Open wukurua opened 4 months ago

wukurua commented 4 months ago

What is the problem this feature will solve?

I want to measure the GFLOPs of RGBPoseConv3D (configs/skeleton/posec3d/rgbpose_conv3d/rgbpose_conv3d.py), but the current code doesn't seem to support MMRecognizer3D. It would be helpful to have a branch for the multi-modal 3D recognizer framework in tools/analysis_tools/get_flops.py. By the way, if anyone knows the params and GFLOPs of RGBPoseConv3D, or its input_shape, please share them.

def main():

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    elif len(args.shape) == 4:
        # n, c, h, w = args.shape for 2D recognizer
        input_shape = tuple(args.shape)
    elif len(args.shape) == 5:
        # n, c, t, h, w = args.shape for 3D recognizer or
        # n, m, t, v, c = args.shape for GCN-based recognizer
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')
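
For the multi-modal case, one way around this shape parsing is to hand real dummy tensors to the complexity API instead of a single input_shape. Below is a minimal sketch, assuming a recent mmengine whose get_model_complexity_info accepts explicit inputs, and assuming the model's forward takes the RGB clip and the pose-heatmap clip positionally; the helper name and example shapes are illustrative only, not the current mmaction2 API:

import torch
from mmengine.analysis import get_model_complexity_info


def multimodal_flops(model, rgb_shape, pose_shape):
    """Count FLOPs/params for a two-stream (RGB + pose) recognizer.

    Hypothetical helper: the positional (rgb, pose) forward signature
    is an assumption, not the documented MMRecognizer3D interface.
    """
    model.eval()
    rgb = torch.randn(rgb_shape)    # e.g. (1, 3, 8, 224, 224)
    pose = torch.randn(pose_shape)  # e.g. (1, 17, 32, 56, 56)
    # Passing `inputs` skips the single-tensor `input_shape` path above.
    results = get_model_complexity_info(model, inputs=(rgb, pose))
    return results['flops_str'], results['params_str']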

What is the feature?

Get the FLOPs of MMRecognizer3D.

What alternatives have you considered?

No response

Yangxinyee commented 1 month ago

the same question here...

ahmed-nady commented 5 days ago

I tried to get the FLOPs and #params for this RGBPose model. Here is what I got:

Input shape: {'rgb': [1, 3, 8, 224, 224], 'pose': [1, 17, 32, 56, 56]}
Flops: 56.98 GFLOPs
Params: 36.15 M

I used the code repository https://github.com/kennymckormick/pyskl, ran its get_flops script, and modified the flops_counter script as follows:

rgb_batch = torch.ones(()).new_empty(
    (1, *[1, 3, 8, 224, 224]),
    dtype=next(flops_model.parameters()).dtype,
    device=next(flops_model.parameters()).device)
pose_batch = torch.ones(()).new_empty(
    (1, *[1, 17, 32, 56, 56]),
    dtype=next(flops_model.parameters()).dtype,
    device=next(flops_model.parameters()).device)
...
_ = flops_model(rgb_batch, pose_batch)
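
As a cross-check outside pyskl, a similar count can be sketched with fvcore's FlopCountAnalysis. This is an assumption-laden sketch, not a verified recipe: it assumes mmaction2 1.x (for register_all_modules), counts the backbone directly to sidestep the recognizer's data-sample interface, and assumes the backbone takes the (rgb, pose) tensors positionally; the shapes follow the numbers reported above.

import torch
from fvcore.nn import FlopCountAnalysis, parameter_count
from mmengine.config import Config
from mmaction.registry import MODELS
from mmaction.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile(
    'configs/skeleton/posec3d/rgbpose_conv3d/rgbpose_conv3d.py')
model = MODELS.build(cfg.model).eval()

# Shapes follow the numbers reported above; the exact rank may differ
# depending on how the recognizer reshapes clips before the backbone.
rgb = torch.randn(1, 3, 8, 224, 224)
pose = torch.randn(1, 17, 32, 56, 56)

# Counting the backbone with an assumed (rgb, pose) positional signature.
flops = FlopCountAnalysis(model.backbone, (rgb, pose))
print(f'{flops.total() / 1e9:.2f} GFLOPs')
print(f"{parameter_count(model)[''] / 1e6:.2f} M params")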