Lednik7 / CLIP-ONNX

It is a simple library to speed up CLIP inference up to 3x (K80 GPU)

Matching OpenAI CLIP output by adding attention mask #14

Closed blahBlahhhJ closed 1 year ago

blahBlahhhJ commented 1 year ago

Fixes #3

Thanks for the great work, Gerasimov!

I'm sure you've already noticed that your monkey-patched model doesn't produce the same outputs as OpenAI's original implementation. When using your model, I observed that it usually outputs a near-uniform distribution over the labels (e.g. 50% vs 50%), while OpenAI's original model outputs a very skewed and accurate distribution (e.g. 99% vs 1%).
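For reference, here is a minimal way to inspect the distributions using the standard openai/CLIP API (the image path and labels are placeholders); running the same inputs through the patched model should give matching probabilities once the mask is restored:

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# placeholder inputs; swap in your own image and labels
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print(probs)  # without the attn_mask fix, the patched model's probs look near-uniform
```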

After investigating the issue, I found that the difference comes from the missing attn_mask in the text encoder's transformer: the causal attention mask is never passed through the monkey-patched attention. This PR restores it, which fixes the performance gap, and you can verify it by running the forward pass of the PyTorch model.
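For context, the mask in question is the additive causal mask that OpenAI's CLIP builds for its text transformer. A minimal sketch (the function name and the call site in the final comment are illustrative, not this PR's exact code):

```python
import torch

def build_causal_attn_mask(context_length: int) -> torch.Tensor:
    # Additive mask: 0 on and below the diagonal, -inf strictly above it,
    # so each token can only attend to itself and earlier tokens.
    mask = torch.full((context_length, context_length), float("-inf"))
    mask.triu_(1)  # keep -inf above the diagonal, zero out the rest
    return mask

# The mask then has to reach every attention block in the text transformer, e.g.
# attn_mask = build_causal_attn_mask(seq_len).to(dtype=x.dtype, device=x.device)
```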

After adding the attention mask back, the torch model works fine, but the ONNX conversion throws an error along the lines of "found tensors on different devices (cpu and cuda)". This turns out to be an issue with PyTorch's fused attention kernel: ONNX export breaks when an attn_mask is supplied (since you never passed attn_mask, the error never surfaced for you). I therefore rewrote the attention function with native PyTorch ops instead of calling F.scaled_dot_product_attention, which dispatches to the underlying C++ attention kernel. Note that this may slow things down slightly.
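A mask-aware attention written with basic tensor ops looks roughly like this (a sketch, not the exact code in this PR; the function name is illustrative):

```python
import math
import torch

def manual_scaled_dot_product_attention(q, k, v, attn_mask=None):
    # q, k, v: (..., seq_len, head_dim)
    # attn_mask: additive mask broadcastable to (..., q_len, k_len),
    #            e.g. the causal mask with -inf above the diagonal
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        # keep the mask on the same device/dtype as the scores so ONNX export
        # doesn't trip over mixed-device tensors
        scores = scores + attn_mask.to(dtype=scores.dtype, device=scores.device)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)
```

Because this is plain matmul/softmax, the ONNX exporter can trace it without hitting the fused-kernel device check.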

These two fixes should be enough for this repo to both match OpenAI's output and remain convertible to ONNX, so inference can still be accelerated. From there, converting to TensorRT format gives even more speedup.

By the way, a better way to monkey patch is a.f = f_patch.__get__(a, A) instead of a.f = f_patch. The latter (your current approach) stores an unbound function on the instance, so calling a.f() directly would raise an error because self is never passed. For now I think it's fine, since you're not calling attention() directly. A toy illustration of the difference follows (class and function names are made up).
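```python
class A:
    def f(self):
        return "original"

def f_patch(self):
    return "patched"

a = A()

# Unbound assignment: a.f is just a plain function stored on the instance,
# so a.f() raises TypeError: f_patch() missing 1 required positional argument: 'self'
a.f = f_patch

# Bound assignment: __get__ turns f_patch into a method bound to `a`,
# so a.f() behaves like a normal method call.
a.f = f_patch.__get__(a, A)
print(a.f())  # "patched"
```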