jhb86253817 / PromptMRG

MIT License
50 stars 6 forks

Is there any code for extracting text features with CLIP? #7

Open afasijbfk opened 6 months ago

afasijbfk commented 6 months ago

Hello, I would like to use CLIP to extract text features for a different dataset. Is there any code for extracting text features with CLIP?

afasijbfk commented 6 months ago

Or, which CLIP did you use for the extraction? Thanks.

jhb86253817 commented 6 months ago

Hi, I used the official CLIP API (https://github.com/openai/CLIP). Below is a simple example for your reference.

import json

import torch
import clip  # official OpenAI CLIP package (https://github.com/openai/CLIP)

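# CLIP checkpoint whose weights replace the default ViT-B/32 weights below
pretrained_path = 'clip-imp-pretrained_128_6_after_4.pt'

# Annotation file: a JSON list of records, each with a 'text' field
with open('xxx.json', 'r') as f:
    annos = json.load(f)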

device = "cuda:0" if torch.cuda.is_available() else "cpu" 

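# Load the pre-trained CLIP model (ViT-B/32); jit=False returns a regular nn.Module
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Replace the weights with the checkpoint; map_location=device (instead of a
# hard-coded "cuda:0") lets this also run on CPU-only machines
state_dict = torch.load(pretrained_path, map_location=device)
model.load_state_dict(state_dict)
print("load checkpoint from {}".format(pretrained_path))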

# Collect the text of every annotation
texts = [anno['text'] for anno in annos]

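with torch.no_grad():
    # CLIP's context length is 77 tokens; truncate=True cuts longer texts
    # instead of raising a RuntimeError
    tokens = clip.tokenize(texts, truncate=True).to(device)
    text_features = model.encode_text(tokens)
    # L2-normalize each feature vector (unit length, so dot product = cosine similarity)
    text_features /= text_features.norm(dim=-1, keepdim=True)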
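For a larger dataset, encoding every text in a single `encode_text` call may run out of GPU memory. Below is a minimal sketch, assuming you want the same L2-normalized features but computed in mini-batches. The function name `extract_text_features`, its parameters, and the default batch size of 256 are hypothetical; they are not part of PromptMRG or the CLIP package. The tokenizer and encoder are passed in as callables so the helper does not depend on a particular model.

```python
from typing import Callable, List, Sequence

import torch


def extract_text_features(
    texts: Sequence[str],
    tokenize_fn: Callable[[List[str]], torch.Tensor],
    encode_fn: Callable[[torch.Tensor], torch.Tensor],
    device: str = "cpu",
    batch_size: int = 256,
) -> torch.Tensor:
    """Encode texts in mini-batches and return L2-normalized features on the CPU.

    tokenize_fn / encode_fn are placeholders (hypothetical interface); with the
    OpenAI CLIP package they could be clip.tokenize and model.encode_text.
    """
    if len(texts) == 0:
        raise ValueError("texts must be non-empty")
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            tokens = tokenize_fn(batch).to(device)
            # Cast to float32: CLIP loaded on a GPU may return float16 features
            feats = encode_fn(tokens).float()
            # Normalize each row to unit length, as in the snippet above
            feats = feats / feats.norm(dim=-1, keepdim=True)
            chunks.append(feats.cpu())
    # Concatenate batch results back into one (num_texts, dim) tensor
    return torch.cat(chunks, dim=0)
```

With the CLIP model from the snippet above, a call might look like `extract_text_features(texts, lambda t: clip.tokenize(t, truncate=True), model.encode_text, device=device)`. The result can then be stored with `torch.save` so the features do not have to be recomputed. Moving each batch to the CPU keeps GPU memory bounded by one batch, at the cost of a device-to-host copy per batch.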