mlfoundations / open_clip

An open source implementation of CLIP.

Triton error in int8-support #835

Closed: aleablu closed this issue 6 months ago

aleablu commented 7 months ago

I am following the int8 tutorial at https://github.com/mlfoundations/open_clip?tab=readme-ov-file#int8-support but I cannot make it work with the latest version of open_clip.

Installing the required triton version (2.0.0.post1) forces pip to also install torch 1.13, which breaks the code that loads open_clip, because I'm using the latest version of open_clip_torch (2.24.0).

I tried installing the latest triton version (2.1.0) but the code fails with:

```
AttributeError: module 'triton.language' has no attribute 'libdevice'
```

Is there a way to solve this without reverting to a version of open_clip that supported torch 1.13?

Here is the code I'm using for reference, even though it's a mere copy-paste of the notebook:

```python
import open_clip
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained='frozen_laion5b_s13b_b90k')

tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')

# This replaces linear layers with int8_linear_layer
import bitsandbytes as bnb

model = model.cpu()
int8_linear_layer = bnb.nn.triton_based_modules.SwitchBackLinear
# replace linear layers; for now just replace FFN - more coming in a later PR
int8_model = open_clip.utils.replace_linear(model, int8_linear_layer, include_modules=['c_fc', 'c_proj']).cuda()
# may take a sec to run this because it has to compile the kernels
int8_model.set_grad_checkpointing()
int8_model.eval()

# If you only care about inference you can make things faster by precomputing the
# int8 quantized weights and deleting the original weights. This is what you should
# do if you're not training, as it also uses much less memory.
from open_clip.utils import convert_int8_model_to_inference_mode

convert_int8_model_to_inference_mode(int8_model)
```
rwightman commented 6 months ago

@aleablu that 'was' a necessary install when the feature was added; I'm sure newer triton installs work just fine now and wouldn't cause issues with newer torch. Try using the default triton that ships with the latest torch versions...
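For background on the AttributeError: newer Triton releases (2.1+) reportedly moved the libdevice functions from `triton.language.libdevice` to `triton.language.math`, so kernels written against the old name fail on import. The sketch below is a hypothetical, unofficial shim (an assumption, not a fix endorsed by the maintainers) that aliases the old attribute when only the new one exists; it is demonstrated here on a stand-in module so it runs without Triton installed:

```python
# Hypothetical compatibility shim: if a triton.language-like module exposes
# `math` but not `libdevice` (the Triton 2.1+ layout), alias the old name so
# kernels written for Triton 2.0 can still resolve tl.libdevice.*.
import types


def alias_libdevice(tl_module):
    """Alias `libdevice` to `math` on a triton.language-like module if needed."""
    if not hasattr(tl_module, "libdevice") and hasattr(tl_module, "math"):
        tl_module.libdevice = tl_module.math
    return tl_module


# Demonstration with a stand-in namespace (Triton itself is not required here):
fake_tl = types.SimpleNamespace(math=types.SimpleNamespace(sqrt=lambda x: x ** 0.5))
alias_libdevice(fake_tl)
print(fake_tl.libdevice.sqrt(9.0))  # -> 3.0
```

With the real library one would call `alias_libdevice(triton.language)` before importing the bitsandbytes triton kernels; whether that is sufficient for the SwitchBack kernels is untested here.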

EdenChen233 commented 2 months ago

I'm using a newer torch==2.4.0, a newer triton==3.0.0, and open_clip_torch==2.26.1 on Python 3.8, and I hit the same problem (`AttributeError: module 'triton.language' has no attribute 'libdevice'`) when I try to run convert_int8_model_to_inference_mode. Could you please share more details about the environment used in the int8-support notebook?
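Since the thread hinges on exact package versions, a generic environment probe like the one below (a sketch, not the notebook's official environment spec) makes it easy to report the versions that matter for the int8 path in one go:

```python
# Print the Python version and the versions of the packages relevant to the
# int8 path. Packages that are missing or expose no __version__ are reported
# as such instead of raising.
import importlib
import sys


def probe(names):
    """Return a {package: version-or-status} report for the given modules."""
    report = {}
    for name in names:
        try:
            mod = importlib.import_module(name)
            report[name] = getattr(mod, "__version__", "unknown")
        except ImportError:
            report[name] = "not installed"
    return report


print(sys.version.split()[0])
print(probe(["torch", "triton", "open_clip", "bitsandbytes"]))
```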