OFA-Sys / Chinese-CLIP

Chinese version of CLIP which achieves Chinese cross-modal retrieval and representation generation.

Error when converting to TorchScript #267

Open istoneyou opened 5 months ago

istoneyou commented 5 months ago
with open("./clip_cn_vit-l-14.pt", 'rb') as opened_file:
    checkpoint = torch.load(opened_file, map_location="cpu")
model_arch = "ViT-L-14"
context_length = 52
# prepare the PyTorch implemented model and restore weights
model = create_model(
    _MODEL_INFO[model_arch]['struct'], checkpoint).float().eval()
print(type(model.bert))
text_model = model.bert
text_model.eval()

# Prepare text input
resolution = _MODEL_INFO[model_arch]['input_resolution']
preprocess = image_transform(resolution)
image = preprocess(Image.new('RGB', (resolution, resolution))).unsqueeze(0)
text = clip.tokenize([""], context_length).int()
print(type(text))

# Trace the model for text input (this is the call that raises the error below)
traced_text_model = torch.jit.trace(text_model, text)
torch.jit.save(traced_text_model, './cn_traced_text_model.pt')

image_model = model.visual
image_model.eval()
dummy_image_input = torch.rand(1, 3, resolution, resolution)

# Trace the model for image input (this step works)
traced_image_model = torch.jit.trace(image_model, dummy_image_input)
torch.jit.save(traced_image_model, './cn_traced_image_model.pt')

With the code above I split the .pt checkpoint into separate bert and visual TorchScript models. The visual model converts fine, but tracing the bert model fails. How should I convert it? The error is as follows:

Traceback (most recent call last):
  File "clip.py", line 29, in <module>
    traced_text_model = torch.jit.trace(text_model, text)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 759, in trace
    return trace_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 976, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions
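
For context: torch.jit.trace raises this RuntimeError whenever the traced module's forward returns something that is not a tensor (or a flat list/tuple/dict of tensors), and model.bert here returns a richer structure than a single tensor. A common workaround is to trace a thin wrapper whose forward returns only the tensor you need, for example by going through model.encode_text, which reduces the bert output to the projected text features. The sketch below is only an illustration under that assumption; the TextEncoderWrapper name is my own, and the exact return value of encode_text should be checked against the installed cn_clip version.

import torch
import torch.nn as nn

class TextEncoderWrapper(nn.Module):
    # Hypothetical helper: makes the traced callable output a plain tensor,
    # which is what torch.jit.trace requires.
    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, text):
        # Assumption: cn_clip's CLIP.encode_text(text) returns the projected
        # text features as a single tensor, so nothing non-tensor leaves the
        # traced graph.
        return self.clip_model.encode_text(text)

# Usage sketch, reusing `model` and `text` from the snippet above:
# text_model = TextEncoderWrapper(model).eval()
# traced_text_model = torch.jit.trace(text_model, text)
# torch.jit.save(traced_text_model, './cn_traced_text_model.pt')

Note that this traces the full text branch (bert plus the text projection) rather than the raw bert module, so the traced model outputs the same features model.encode_text would produce in eager mode, which is usually what retrieval needs.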