OFA-Sys / Chinese-CLIP

Chinese version of CLIP which achieves Chinese cross-modal retrieval and representation generation.
MIT License
4.48k stars 462 forks source link

加载训练好的模型的问题 #63

Closed JeffMony closed 1 year ago

JeffMony commented 1 year ago

你好,我请教一下,我训练好了新的模型: image

请问我怎么应用这个模型啊,例如项目中给了一个例子:

import torch 
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
print("Available models:", available_models())  
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
model.eval()
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # 对特征进行归一化,请使用归一化后的图文特征用于下游任务
    image_features /= image_features.norm(dim=-1, keepdim=True) 
    text_features /= text_features.norm(dim=-1, keepdim=True)    

    logits_per_image, logits_per_text = model.get_similarity(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]

怎么让这儿加载的是我训练好的模型啊?多多指教,感谢。

DtYXs commented 1 year ago

您好,假设您训练的是Base规模的模型,您可以尝试将model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')这一行修改为model, preprocess = load_from_name({model_path}, device=device, vision_model_name="ViT-B-16", text_model_name="RoBERTa-wwm-ext-base-chinese", input_resolution=224)。其中{model_path}替换为您训练好的模型文件的路径即可。

JeffMony commented 1 year ago

你好,我使用的是run_scripts/muge_finetune_vit-b-16_rbt-base.sh 来训练数据的。

image

train_texts.jsonl如下:

{"text_id": 1000, "text": "户外运动", "image_ids": [100000]}
{"text_id": 1001, "text": "潮人", "image_ids": [100004]}
{"text_id": 1002, "text": "自拍", "image_ids": [100008]}
{"text_id": 1003, "text": "快乐阳光", "image_ids": [100012]}
{"text_id": 1004, "text": "激情时尚", "image_ids": [100016]}
{"text_id": 1005, "text": "满满的书桌", "image_ids": [100020]}
{"text_id": 1006, "text": "可爱的玩具", "image_ids": [100024]}
{"text_id": 1007, "text": "明星的大头贴", "image_ids": [100028]}
{"text_id": 1008, "text": "毛绒玩具", "image_ids": [100032]}
{"text_id": 1009, "text": "卡通版的杯子", "image_ids": [100036]}
{"text_id": 1010, "text": "卡通版的相册", "image_ids": [100040]}
{"text_id": 1011, "text": "手机套", "image_ids": [100044]}
{"text_id": 1012, "text": "明星照片徽章", "image_ids": [100048]}
{"text_id": 1013, "text": "卡通抱枕", "image_ids": [100052]}
{"text_id": 1014, "text": "情侣相拥", "image_ids": [100056]}
{"text_id": 1015, "text": "时尚的情侣", "image_ids": [100060]}
{"text_id": 1016, "text": "帅气的男人", "image_ids": [100064]}
{"text_id": 1017, "text": "侧身的男人", "image_ids": [100068]}
{"text_id": 1018, "text": "时尚二人组", "image_ids": [100072]}
{"text_id": 1019, "text": "靠在沙发上的女人", "image_ids": [100076]}
{"text_id": 1020, "text": "烈焰红唇", "image_ids": [100080]}
{"text_id": 1021, "text": "甜蜜相视", "image_ids": [100084]}
{"text_id": 1022, "text": "性感女孩的脸庞", "image_ids": [100088]}
{"text_id": 1023, "text": "知性的女人", "image_ids": [100092]}
{"text_id": 1024, "text": "剪刀手", "image_ids": [100096]}
{"text_id": 1025, "text": "随性的生活照", "image_ids": [100100]}
{"text_id": 1026, "text": "可爱的扮相", "image_ids": [100104]}
{"text_id": 1027, "text": "撩起头发", "image_ids": [100108]}
{"text_id": 1028, "text": "清纯的女孩", "image_ids": [100112]}
{"text_id": 1029, "text": "激情时尚的女孩", "image_ids": [100116]}
{"text_id": 1030, "text": "撑着雨伞", "image_ids": [100120]}
{"text_id": 1031, "text": "靓丽背影", "image_ids": [100124]}
{"text_id": 1032, "text": "时尚达人", "image_ids": [100128]}

valid_texts.jsonl如下:

{"text_id": 10000, "text": "户外运动", "image_ids": [200000]}
{"text_id": 10001, "text": "潮人", "image_ids": [200004]}
{"text_id": 10002, "text": "自拍", "image_ids": [200008]}
{"text_id": 10003, "text": "快乐阳光", "image_ids": [200012]}
{"text_id": 10004, "text": "激情时尚", "image_ids": [200016]}
{"text_id": 10005, "text": "满满的书桌", "image_ids": [200020]}

我使用最终训练出来的模型来测试第一张图片,发现非常不准。

image

![Uploading image.png…]()

请问我是做错了什么了吗?

DtYXs commented 1 year ago

您好,建议您先根据预测及评估这部分所介绍的内容走一下流程,进而可以得到在您的valid数据集上的准确率来判断训练的结果。 根据您截图中的脚本来看您是在单卡上训练的,默认的一些超参数不适配单卡小batch size的训练,需要进行调整,由于对比学习的训练收敛和稳定性和总batch size相关,如果您的训练数据集足够大,建议您尽可能设置更大的batch size,适当减小一下学习率,在训练的时候关注一下log中输出的train和valid的loss及准确率的变化。 对于提升batch size,您可以在训练脚本中加入--grad-checkpointing来启动重计算策略,可以节省显存来进一步提升batch size。另外,如果您的硬件和环境支持的话,在配置好环境后,您可以启动我们最新适配的FlashAttention,在训练脚本中加入--use-flash-attention可以进一步设置更大的batch size(参见FlashAttention.md)。在尽可能在单卡上提升batch size后,您可以适当调整一下其余超参数,例如学习率lr和训练epoch数max_epochs,由于batch size较小,可以设置更小的lr比如1e-6这样,适当增大一些max_epochs,之后将训练得到的模型通过预测及评估流程来得到准确率。 不知道您列出的train_texts.jsonl的内容是否是全部的训练集数据,如果是的话感觉训练数据有点少,如果训练太久可能会产生过拟合而达不到较好的效果,需要针对数据精细调整一下学习率和训练step数等参数。您也可以尝试扩充一下训练数据的量来提升模型在您的数据集上的效果,或者在和您提供的数据中文本内容和长度较为相似的MUGE数据集上引入部分数据训练一下看看效果。

JeffMony commented 1 year ago

你好,我对CLIP这块算法非常感兴趣,我们可以加个好友吗? 我的微信是:LOVE_BigLi