Can I convert a SigLIP only and not a SigLIP-based LLM? (dusty-nv/NanoLLM issue #16, closed)

Based on supported models, conversion of SigLIP to TRT is already done, but can I use it standalone for a SigLIP model only?

Yeah, you could instantiate the CLIP class directly with the name or path of a SigLIP model. I think you will need to make your own script for calling it, though.
Do I have to convert SigLIP to TRT myself, or can NanoLLM handle the conversion if I supply a PyTorch model from HF?
It should handle it internally for you, and it will save the TRT engine.
Thanks!
Sorry, another question: where is the code that does the inference on the TRT engine? I am trying to write my own inference script but have faced some issues. I figured you are experienced with TRT inference, so I want to try my luck and see how you implemented it.
This is the script I have, for context:
from collections import namedtuple, OrderedDict

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates the CUDA context)
import pycuda.driver as cuda
import tensorrt as trt


class TRTInference:
    def __init__(self, engine_path, output_names_mapping: dict = None, fp16=True, verbose=False):
        self.engine_path = engine_path
        self.output_names_mapping = output_names_mapping or {}
        self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
        self.engine = None
        self.load_engine()
        assert self.engine is not None, 'Failed to load TensorRT engine.'
        self.context = self.engine.create_execution_context()
        self.input_names = self.get_input_names()
        self.output_names = self.get_output_names()
        self.dtype = np.float16 if fp16 else np.float32

    def load_engine(self):
        with open(self.engine_path, 'rb') as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

    def get_input_names(self):
        names = []
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                names.append(name)
        return names

    def get_output_names(self):
        names = []
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                names.append(name)
        return names

    def get_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        bindings = OrderedDict()
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            shape = tuple(self.engine.get_tensor_shape(name))
            if any(s < 0 for s in shape):  # set dynamic axes to 1
                shape = tuple(1 if s < 0 else s for s in shape)
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                data = np.random.randn(*shape).astype(dtype)
                ptr = cuda.mem_alloc(data.nbytes)
            else:
                data = cuda.pagelocked_empty(trt.volume(shape), dtype)
                ptr = cuda.mem_alloc(data.nbytes)
            bindings[name] = Binding(name, dtype, shape, data, ptr)
        return bindings

    def __call__(self, blob):
        # blob is a dict of {input name: np.ndarray}; all inputs share the same batch size
        self.stream = cuda.Stream()
        input_bs = next(iter(blob.values())).shape[0]
        blob = {n: np.ascontiguousarray(v) for n, v in blob.items()}
        for n in self.input_names:
            input_shape = self.engine.get_tensor_shape(n)
            input_shape = (input_bs, *input_shape[1:])  # resolve the dynamic batch axis
            self.context.set_input_shape(n, input_shape)
            d_input = cuda.mem_alloc(np.empty(input_shape, dtype=self.dtype).nbytes)
            self.context.set_tensor_address(n, int(d_input))
            cuda.memcpy_htod_async(d_input, blob[n], self.stream)
        output_shape_list = []
        for n in self.output_names:
            output_shape = self.engine.get_tensor_shape(n)
            output_shape = (input_bs, *output_shape[1:])
            d_output = int(cuda.mem_alloc(np.empty(output_shape, dtype=self.dtype).nbytes))
            self.context.set_tensor_address(n, d_output)
            output_shape_list.append(output_shape)
        assert self.context.all_binding_shapes_specified
        self.context.execute_async_v3(stream_handle=self.stream.handle)
        outputs = {}
        for n, output_shape in zip(self.output_names, output_shape_list):
            output = np.empty(output_shape, dtype=self.dtype)
            cuda.memcpy_dtoh_async(output, self.context.get_tensor_address(n), self.stream)
            outputs[self.output_names_mapping.get(n, n)] = output
        self.stream.synchronize()
        return outputs
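And this is roughly how I call it (the engine path and the 'pixel_values' input name below are just placeholders for whatever your exported engine actually uses):

trt_model = TRTInference('siglip_vision.engine', fp16=True)  # placeholder engine path
dummy = np.random.randn(1, 3, 384, 384).astype(np.float16)   # placeholder SigLIP vision input shape
outputs = trt_model({'pixel_values': dummy})                 # keys must match the engine's input names
print({name: arr.shape for name, arr in outputs.items()})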
@aliencaocao I am using torch2trt from @jaybdub, which gives you a model object with the same interface as PyTorch, so you can just transparently replace your PyTorch model with the TRT version. That is what I do here: https://github.com/dusty-nv/NanoLLM/blob/b0be32710f0faf7983e63c98b433e6203c061b57/nano_llm/vision/clip.py#L176
If I wanted to make something that didn't depend on PyTorch, then yeah, I would use the TRT API directly, or onnxruntime if I wanted to be able to fall back to cuDNN.
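The basic pattern is roughly the one from the torch2trt README; the sketch below just uses a torchvision resnet18 as a stand-in model, not the SigLIP-specific code from clip.py:

import torch
from torchvision.models import resnet18
from torch2trt import torch2trt

model = resnet18().eval().cuda()       # any torch.nn.Module works as a stand-in here
x = torch.ones(1, 3, 224, 224).cuda()  # example input used to build the engine

model_trt = torch2trt(model, [x])      # builds a TRT engine and wraps it in a TRTModule
y_trt = model_trt(x)                   # same call signature as the original PyTorch module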
Thanks for pointing that out. Do you know how I can load a .engine directly using torch2trt? I have exported it separately to use my own shape configs.
Actually, I managed to port it over, but I'm getting errors like this:
[06/08/2024-12:35:15] [TRT] [E] 1: [convBaseRunner.cpp::nvinfer1::rt::task::CaskConvBaseRunner::execute::300] Error Code 1: Cask (Cask convolution execution)
This is probably not related, but have you seen this before?
@aliencaocao not specifically as related to this (being CLIP/SigLIP)
Thank you. I will try to convert it using torch2trt myself and see how it goes.
I tried to use NanoLLM, but it is missing the text model of SigLIP, which I also need; otherwise I would have used it straight away.
Thank you for your help!
OK thanks, let me know if you find that TRT or torch2trt can build/run the SigLIP text encoder, that would be good for me to add too.
I got the export for the vision model to work, but not the text part. The weird thing is that I can run inference on the PyTorch model as normal. The error below says I am missing the first argument to the forward method, which makes no sense since I did give an input...
File C:\Program Files\Python39\lib\site-packages\transformers\models\siglip\modeling_siglip.py:774, in SiglipTextTransformer.forward(self, input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict)
771 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
773 if input_ids is None:
--> 774 raise ValueError("You have to specify input_ids")
776 input_shape = input_ids.size()
777 input_ids = input_ids.view(-1, input_shape[-1])
ValueError: You have to specify input_ids
I will see if there's another way before I get into patching HF transformers.
Edit: well, I think it's a bug in torch2trt, because it makes two forward calls to the model, first with the inputs and the second time with no args. See https://github.com/NVIDIA-AI-IOT/torch2trt/issues/931
Got it to work after https://github.com/NVIDIA-AI-IOT/torch2trt/pull/932
Full conversion script; it requires pip install git+https://github.com/aliencaocao/torch2trt.git@patch-1:
import torch
from torch2trt import torch2trt
from transformers import SiglipModel

model = SiglipModel.from_pretrained('siglip ckpt path/HF id', torch_dtype=torch.float16).cuda().eval()
text_model = model.text_model
vision_model = model.vision_model

# vision encoder
dummy = torch.ones(1, 3, 384, 384, dtype=torch.float16, device='cuda')
model_trt = torch2trt(vision_model, [dummy], fp16_mode=True, min_shapes=[(1, 3, 384, 384)], opt_shapes=[(4, 3, 384, 384)], max_shapes=[(10, 3, 384, 384)], use_onnx=True)  # change the shapes here; I know my max batch size is 10, but you probably should not limit it
y = vision_model(dummy).pooler_output
y_trt = model_trt(dummy)['pooler_output']
torch.save(model_trt.state_dict(), 'vision_trt.pth')
print('Vision model exported. atol:', torch.max(torch.abs(y - y_trt)))

# text encoder
dummy = torch.ones(1, 64, dtype=torch.long, device='cuda')  # the SigLIP tokenizer should always pad to 64 tokens
model_trt = torch2trt(text_model, [dummy], fp16_mode=True, min_shapes=[(1, 64)], opt_shapes=[(1, 64)], max_shapes=[(1, 64)], use_onnx=True)
y = text_model(dummy).pooler_output
y_trt = model_trt(dummy)['pooler_output']
torch.save(model_trt.state_dict(), 'text_trt.pth')
print('Text model exported. atol:', torch.max(torch.abs(y - y_trt)))
Remove all the torch.float16 if you want to run in fp32.
Then you can just load it via TRTModule.load_state_dict(torch.load(os.path.join(clip_path, 'vision_trt.pth'))).
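In full, that is roughly (clip_path is just a placeholder for wherever the .pth files from the script above were saved):

import os
import torch
from torch2trt import TRTModule

clip_path = 'path/to/exported'  # wherever the state dicts from the conversion script were saved

clip_vision_trt = TRTModule()
clip_vision_trt.load_state_dict(torch.load(os.path.join(clip_path, 'vision_trt.pth')))

clip_text_trt = TRTModule()
clip_text_trt.load_state_dict(torch.load(os.path.join(clip_path, 'text_trt.pth')))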
To get the logits:
vision_input = self.clip_image_processor(images=boxes, return_tensors='pt').to(self.device)
text_inputs = self.clip_tokenizer(im_captions_templated, return_tensors='pt', padding='max_length', truncation=True).to(self.device)  # the processor won't work since it doesn't pad to max_length=64
vision_input = vision_input['pixel_values'].type(torch.float16)
image_feat = self.clip_vision_trt(vision_input)['pooler_output']
text_feat = self.clip_text_trt(text_inputs['input_ids'])['pooler_output']
image_feat /= image_feat.norm(p=2, dim=-1, keepdim=True)
text_feat /= text_feat.norm(p=2, dim=-1, keepdim=True)
scores = image_feat @ text_feat.T * self.clip_logit_scale_exp + self.clip_logit_bias
scores = scores.squeeze(-1).tolist()  # sigmoid is not needed, since it doesn't change the ranking
Thanks @aliencaocao, that's great! I'm going to unify the various CLIP/SigLIP implementations I have floating around between NanoLLM/NanoDB, with support for the text encoder in TRT alongside the vision encoder 👍
@aliencaocao did you get the text encoder working in TRT with real token IDs? The output delta is small when the input_ids are all 1, but when I actually tokenize a real string it doesn't work. Which version of TensorRT are you using?
Edit: I also tried using the attention_mask from the tokenizer.
TRT 10.1. Yes, with real token IDs; I am using it on over 20k samples already.
One very important thing: do not use the HF Processor, use the HF tokenizer instead. The processor does not pad the input to 64 ('max_length') tokens, which is what SigLIP has been trained on. Also double check that the padding token id is 1. This differs from other CLIPs, where the padding would be to the max length in a batch rather than always 64 (or whatever the context length is).
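For example (google/siglip-so400m-patch14-384 is used here only as an illustration; substitute your own checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google/siglip-so400m-patch14-384')  # example checkpoint

inputs = tokenizer(['a photo of a cat'], return_tensors='pt', padding='max_length', truncation=True)
print(inputs['input_ids'].shape)  # should be (1, 64)
print(tokenizer.pad_token_id)     # should be 1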
Also note that I hard-coded the text batch size to 1, so you may have to change it to a more dynamic one.
The attention mask is not needed.
Also make sure you are using the right logit_scale_exp and logit_bias from the original HF model; they change for fine-tuned models. And don't forget the .exp() on model.logit_scale. I usually just precalculate it, since it is not exported to TRT.
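That is, something along the lines of:

# precompute once from the original HF SiglipModel (these are not exported to TRT)
clip_logit_scale_exp = model.logit_scale.exp().item()
clip_logit_bias = model.logit_bias.item()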
And to get the same output as the HF pipeline (Python/JS), you need to add a torch.sigmoid(scores) before the last line. For the purpose of image reranking that is not needed, since sigmoid won't change the order.