OFA-Sys / Chinese-CLIP

Chinese version of CLIP which achieves Chinese cross-modal retrieval and representation generation.
MIT License
4.21k stars, 439 forks

Question about running inference with a model trained on my own dataset #239

Open ipeaking opened 8 months ago

ipeaking commented 8 months ago

I trained my own model starting from the clip_cn_vit-h-14.pt model and saved it with the code below:

import torch
from torch import nn

for epoch in range(num_epochs):
    total_loss = 0   # running loss over the whole epoch
    total_loss2 = 0  # running loss since the last progress print

    model.train()
    textAdapter.train()
    imageAdapter.train()

    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()

    for i, (images, texts, text_rows, eos_index) in enumerate(dataloader):
        # print(text_rows)
        optimizer.zero_grad()
        images = images.to(device)
        texts = texts.to(device)
        # get_similarity3 is a helper that is not shown in this post
        logits_per_image, logits_per_text = similarity.get_similarity3(model, images, texts, imageAdapter)
        # logits_per_image, logits_per_text = get_similarity(model, images, texts, textAdapter, imageAdapter)
        # The i-th image in the batch is paired with the i-th text
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_loss2 += loss.item()

        if (i + 1) % print_every == 0:
            avg_loss = total_loss2 / print_every
            print(f"Epoch {epoch + 1}/{num_epochs}, Step {i + 1}/{len(dataloader)}, Average Loss: {avg_loss}")
            total_loss2 = 0

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss}")

# Pickles the entire model object, not just its state_dict
torch.save(model, 'trained_model.pth')

However, I don't know how to run inference with it. Could anyone tell me how? Thanks.
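For readers unfamiliar with the loss used in the training loop above, here is a minimal pure-Python sketch of a symmetric contrastive loss with `arange`-style targets. It assumes that `logits_per_text` is the transpose of `logits_per_image`, as in the standard CLIP formulation; the post does not show `get_similarity3`, so that is an assumption. The function names and the toy logits are hypothetical.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy over the rows of a list-of-lists logits matrix."""
    total = 0.0
    for row, target in zip(logits, targets):
        row_max = max(row)  # subtract the max for numerical stability
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_sum_exp - row[target]
    return total / len(logits)


def symmetric_contrastive_loss(logits_per_image):
    """Average of the image-to-text and text-to-image cross-entropies."""
    n = len(logits_per_image)
    # Assumption: text logits are the transpose of the image logits.
    logits_per_text = [list(col) for col in zip(*logits_per_image)]
    ground_truth = list(range(n))  # the i-th image matches the i-th text
    loss_img = cross_entropy(logits_per_image, ground_truth)
    loss_txt = cross_entropy(logits_per_text, ground_truth)
    return (loss_img + loss_txt) / 2


# Toy batch of 3 pairs: high scores on the diagonal mean correct matches.
matched = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
# Same scores, but every image prefers the wrong text.
shuffled = [[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [10.0, 0.0, 0.0]]

loss_matched = symmetric_contrastive_loss(matched)
loss_shuffled = symmetric_contrastive_loss(shuffled)
print(f"matched: {loss_matched:.6f}, shuffled: {loss_shuffled:.6f}")
```

The loss is close to zero when each image scores highest against its own text and large otherwise, which is why the targets are simply the batch indices.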

DtYXs commented 4 months ago

You can follow the workflow in the prediction and evaluation section, replacing resume with the path to your saved model.
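One caveat may be worth checking before following this advice: `torch.save(model, ...)` pickles the whole model object, whereas checkpoint-resuming scripts often expect a dictionary that holds the weights under a key such as `state_dict`. Whether the repository's scripts require exactly that layout (or additional keys) is an assumption here and should be verified against the script that consumes `resume`. The sketch below shows one way to convert the file; `convert_to_state_dict_checkpoint`, the output file name, and the `nn.Linear` stand-in model are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


def convert_to_state_dict_checkpoint(full_model_path, checkpoint_path):
    """Re-save a whole pickled model as a {"state_dict": ...} checkpoint."""
    # weights_only=False is needed to unpickle a whole nn.Module on recent
    # PyTorch versions; only load files you trust.
    model = torch.load(full_model_path, map_location="cpu", weights_only=False)
    # Assumption: the consuming script reads the weights from "state_dict".
    torch.save({"state_dict": model.state_dict()}, checkpoint_path)


# Stand-in for the fine-tuned model (the real one would be the CLIP model).
stand_in = nn.Linear(4, 2)
workdir = tempfile.mkdtemp()
full_path = os.path.join(workdir, "trained_model.pth")
ckpt_path = os.path.join(workdir, "trained_model_state.pt")

torch.save(stand_in, full_path)  # same saving style as in the question
convert_to_state_dict_checkpoint(full_path, ckpt_path)

# Load the converted checkpoint into a freshly constructed model.
checkpoint = torch.load(ckpt_path, map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["state_dict"])
```

If `textAdapter` and `imageAdapter` are separate modules rather than submodules of `model`, their weights are not contained in `trained_model.pth` and would need to be saved and restored separately.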