openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
MIT License

How to count FLOPs during CLIP inference #143

Closed: Akashcodes732 closed this issue 2 years ago

Akashcodes732 commented 2 years ago

I tried several existing FLOP counters on the CLIP model, but they don't seem to work. I need help counting the FLOPs for a single inference pass of the CLIP model.

jongwook commented 2 years ago

We used fvcore's flop_count module with the following modifications to add the operations that it doesn't support out-of-the-box:

import typing
from collections import Counter
from math import prod  # needed by the handlers below

from fvcore.nn import flop_count
from fvcore.nn.jit_handles import batchnorm_flop_jit, matmul_flop_jit, generic_activation_jit, get_shape

def generic_pooling_jit(name, multiplier=1):
    def pool_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
        # inputs[0] carries the shape of the input tensor.
        input_shape = get_shape(inputs[0])
        output_shape = get_shape(outputs[0])
        assert 2 <= len(input_shape) <= 5, input_shape
        # Sum over all input elements, plus one division per output element.
        flop = prod(input_shape) + prod(output_shape)
        flop_counter = Counter({name: flop * multiplier})
        return flop_counter

    return pool_jit

def softmax_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    input_shape = get_shape(inputs[0])
    output_shape = get_shape(outputs[0])
    # Exponentiate and sum the inputs, then one division per output element.
    flop = prod(input_shape) * 2 + prod(output_shape)
    flop_counter = Counter({"softmax": flop})
    return flop_counter

def bmm_flop_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    input1_shape = get_shape(inputs[0])
    input2_shape = get_shape(inputs[1])
    assert len(input1_shape) == len(input2_shape) == 3
    assert input1_shape[0] == input2_shape[0] and input1_shape[2] == input2_shape[1], [input1_shape, input2_shape]
    flop = prod(input1_shape) * input2_shape[-1]  # (b,n,k) x (b,k,m) -> (b,n,m): b*n*k*m multiply-adds, one FLOP each
    flop_counter = Counter({"bmm": flop})
    return flop_counter

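# ForwardWrapper is a user-defined adapter, not defined in this snippet;
# one possible implementation is sketched after the code below.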
flops, skips = flop_count(
    ForwardWrapper(model),
    inputs=(example_input,),
    supported_ops={
        "aten::batch_norm": batchnorm_flop_jit,
        "aten::group_norm": batchnorm_flop_jit,
        "aten::layer_norm": batchnorm_flop_jit,
        "aten::add": generic_activation_jit("add"),
        "aten::sub": generic_activation_jit("sub"),
        "aten::mul": generic_activation_jit("mul"),
        "aten::div": generic_activation_jit("div"),
        "aten::sqrt": generic_activation_jit("sqrt"),
        "aten::sigmoid": generic_activation_jit("sigmoid"),
        "aten::sigmoid_": generic_activation_jit("sigmoid_"),
        "aten::relu": generic_activation_jit("relu"),
        "aten::relu_": generic_activation_jit("relu_"),
        "aten::gelu": generic_activation_jit("gelu"),
        "aten::add_": generic_activation_jit("add_"),
        "aten::sub_": generic_activation_jit("sub_"),
        "aten::mul_": generic_activation_jit("mul_"),
        "aten::div_": generic_activation_jit("div_"),
        "aten::sqrt_": generic_activation_jit("sqrt_"),
        "aten::adaptive_avg_pool2d": generic_pooling_jit("adaptive_avg_pool2d"),
        "aten::adaptive_max_pool2d": generic_pooling_jit("adaptive_max_pool2d"),
        "aten::avg_pool2d": generic_pooling_jit("avg_pool2d"),
        "aten::max_pool2d": generic_pooling_jit("max_pool2d"),
        "aten::bmm": bmm_flop_jit,
        "aten::mean": generic_pooling_jit("mean"),
        "aten::var": generic_pooling_jit("var", multiplier=3),  # subtracting mean, exponentiate, summing
        "aten::var_mean": generic_pooling_jit("mean_var", multiplier=4),
        "aten::softmax": softmax_jit,
        "aten::dropout": generic_activation_jit("dropout"),
        "aten::frobenius_norm": generic_pooling_jit("frobenius_norm"),
    }
)
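
Here, ForwardWrapper is not part of fvcore or CLIP; it is a thin user-defined module that adapts the model's forward() to the traced inputs. A minimal sketch, assuming you only want to profile the image encoder (all names below are illustrative, not from the original snippet):

import torch
import torch.nn as nn
import clip

class ForwardWrapper(nn.Module):
    # Hypothetical adapter: exposes only the part of CLIP being profiled.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        # Profile the image tower; swap in encode_text or the full forward as needed.
        return self.model.encode_image(image)

model, preprocess = clip.load("ViT-B/32", device="cpu")
example_input = torch.randn(1, 3, 224, 224)  # dummy preprocessed image, batch size 1

flop_count returns the per-operator counts in GFLOPs together with a counter of skipped ops, so sum(flops.values()) gives the total GFLOPs for one forward pass.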
sandipan211 commented 2 months ago

Hi @Akashcodes732 ,

Did your issue get solved? I am stuck on the same problem and urgently need help. Kindly help me solve this problem.

sandipan211 commented 2 months ago

> We used fvcore's flop_count module with the following modifications to add the operations that it doesn't support out-of-the-box: [quoting @jongwook's full code snippet above]

Hi @jongwook, I am trying to use your solution but am unable to do so. Where is this ForwardWrapper() defined? Also, my model's forward() takes two inputs, an image and its CLIP-preprocessed version. What should the inputs to my flop_count() call be?

Kindly help - I am in urgent need of this solution.
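
For the two-input case, note that fvcore's flop_count passes the inputs tuple positionally to the module's forward(), so both tensors can be supplied directly. A minimal sketch with a stand-in two-input model (the model and names here are illustrative, not CLIP itself):

import torch
import torch.nn as nn
from fvcore.nn import flop_count

class TwoInputModel(nn.Module):
    # Stand-in for a model whose forward takes an image and its preprocessed version.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(3 * 224 * 224, 8)

    def forward(self, image, preprocessed_image):
        return self.proj(image.flatten(1)) + self.proj(preprocessed_image.flatten(1))

model = TwoInputModel()
image = torch.randn(1, 3, 224, 224)
preprocessed = torch.randn(1, 3, 224, 224)

# The tuple is unpacked as positional arguments: model(image, preprocessed)
flops, skips = flop_count(model, inputs=(image, preprocessed))
print(sum(flops.values()), "GFLOPs")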

X-funbean commented 1 month ago

Hi @sandipan211, have you figured out the best practice for counting FLOPs of the CLIP model? I have tried several tools on CLIP with ViT-B/16 (e.g. torchsummaryX, thop, and torchinfo) but got different results. Among them, the closest to the FLOPs plotted in the CLIP paper, Learning Transferable Visual Models From Natural Language Supervision (figure below), is torchinfo's 14.04 GFLOPs (mult-adds).

I also tried the code provided by @jongwook, but it reported over 161 GFLOPs. In addition, according to the model profile log provided by open_clip (https://github.com/mlfoundations/open_clip/blob/main/docs/model_profile.csv), the computational complexity of CLIP with ViT-B/16 should be 41.09 GFLOPs. Any idea on this?

[Figure: FLOPs comparison plot from the CLIP paper]
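
Part of the spread between these tools is a units question: torchinfo's number counts mult-adds (MACs), tools that count 2 FLOPs per multiply-accumulate will report roughly double, and counters like the handlers above additionally include every elementwise add/mul/activation, which raises the total further. As a cross-check, fvcore's built-in FlopCountAnalysis can be run on the image tower alone; it counts one FLOP per fused multiply-add, so its totals are on the MAC convention. A minimal sketch under those assumptions:

import torch
import clip
from fvcore.nn import FlopCountAnalysis

model, preprocess = clip.load("ViT-B/16", device="cpu")
model.eval()

image = torch.randn(1, 3, 224, 224)  # dummy preprocessed image, batch size 1
with torch.no_grad():
    analysis = FlopCountAnalysis(model.visual, image)
    # fvcore counts one FLOP per fused multiply-add; double this total to
    # compare against tools that report 2 FLOPs per MAC.
    print(f"{analysis.total() / 1e9:.2f} G mult-adds for a single image")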