darwinharianto opened 1 year ago
The following script traces `google/owlvit-base-patch32` with TorchScript and then tries to convert the traced model with coremltools:

```python
import torch
import numpy as np
from transformers import AutoModel, AutoProcessor, OwlViTModel, OwlViTProcessor, CLIPTokenizerFast, CLIPTokenizer
import coremltools as ct
from PIL import Image
import requests


class MyOpenDetector(torch.nn.Module):
    """Thin wrapper so torch.jit.trace sees positional tensor inputs only."""

    def __init__(self, model=None):
        super(MyOpenDetector, self).__init__()
        self.model = model

    def forward(self, input_ids, pixel_values, attention_mask):
        # inputs = {"input_ids": x[0], "attention_mask": x[1], "pixel_values": x[2]}
        outputs = self.model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        logits_per_image = outputs[0]  # this is the image-text similarity score
        probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        return probs


def save_owlvitmodel(inputs, modelname):
    # Trace the wrapped model with the processor outputs and save it to disk.
    openModel = AutoModel.from_pretrained(modelname, torchscript=True).eval()
    x = tuple([inputs['input_ids'], inputs['pixel_values'], inputs['attention_mask']])
    model = MyOpenDetector(model=openModel)
    traced_model = torch.jit.trace(model, x)
    torch.jit.save(traced_model, 'traced_owlvit.pt')
    return traced_model


modelname = "google/owlvit-base-patch32"
processor = AutoProcessor.from_pretrained(modelname, torchscript=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat", "a photo of a dog"]],
                   images=torch.Tensor(np.asarray(image)),
                   return_tensors="pt")

traced_model = save_owlvitmodel(inputs, modelname)

# Sanity check: the saved TorchScript model loads and runs in PyTorch.
loaded_model = torch.jit.load("traced_owlvit.pt")
loaded_model.eval()
x = tuple([inputs['input_ids'], inputs['pixel_values'], inputs['attention_mask']])
probs = loaded_model(*x)

# The Core ML conversion below is where it fails.
mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=(ct.RangeDim(1, 16), 16), dtype=np.int32),
        ct.TensorType(name="pixel_values", shape=(ct.RangeDim(1, 3), 3, 768, 768), dtype=np.float32),
        ct.TensorType(name="attention_mask", shape=(ct.RangeDim(1, 16), 16), dtype=np.int32),
    ],
)
mlmodel.save('coremlmodel_owlvit.mlmodel')
```
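For reference, the tensors the processor produces in this example can be inspected with a small snippet; the shapes noted in the comments are assumptions based on OWL-ViT base-patch32 defaults (text padded to 16 tokens, images resized to 768x768), and they are what the `ct.TensorType` declarations above are meant to cover:

```python
# Inspect the processor outputs that feed the traced model.
# Expected shapes (assumption from OWL-ViT base-patch32 defaults):
#   input_ids       (2, 16)
#   attention_mask  (2, 16)
#   pixel_values    (1, 3, 768, 768)
for name, tensor in inputs.items():
    print(name, tuple(tensor.shape))
```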
The conversion also fails without flexible shapes (i.e. not using `ct.RangeDim`, just declaring the concrete shapes of the tensors in `x`), although the error is different.
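For completeness, the static-shape variant that also fails looks roughly like this (a minimal sketch; it reuses the exact shapes of the traced inputs instead of flexible ranges):

```python
# Static-shape conversion attempt (sketch): declare the concrete shapes of
# the tensors in x rather than ct.RangeDim ranges.
mlmodel_static = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input_ids", shape=tuple(inputs['input_ids'].shape), dtype=np.int32),
        ct.TensorType(name="pixel_values", shape=tuple(inputs['pixel_values'].shape), dtype=np.float32),
        ct.TensorType(name="attention_mask", shape=tuple(inputs['attention_mask'].shape), dtype=np.int32),
    ],
)
```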
System environment: