CaraJ7 / CoMat

Official code for 💫CoMat: Aligning Text-to-Image Diffusion Model with Image-to-Text Concept Matching
https://caraj7.github.io/comat/

how to write L_cap in pytorch? #2

Open tian-2024 opened 5 months ago

tian-2024 commented 5 months ago

Can BLIP receive tokens $w_1 \dots w_{t-1}$ and an image, then output the probability of $w_t$?

I thought it was an image captioning model, so it could only take an image and output a description.

$$\prod_{i=1}^L p_C(w_i \mid \mathcal{I}, w_{1:i-1})$$

CaraJ7 commented 5 months ago

Hi @jaycee-tian, thanks for your interest! Actually, $L_{cap}$ is exactly the loss for training an image captioning model. A simplified code snippet looks like:


```python
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=torch.float16
).to("cuda")

text = "a photography of a polar bear"
# `generated_image_of_a_polar_bear` is the PIL image produced by the diffusion model
inputs = processor(generated_image_of_a_polar_bear, text, return_tensors="pt").to("cuda", torch.float16)

# Passing the caption tokens as labels makes the model return the
# language-modeling (captioning) loss, i.e. L_cap
l_cap = model(**inputs, labels=inputs.input_ids).loss
```
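To connect this back to the product formula in the question: the loss returned above is the mean negative log-probability of each caption token given the image and the preceding tokens, so maximizing the product is the same as minimizing the cross-entropy. A minimal sketch of that equivalence, using toy random logits in place of a real captioner's output (shapes and values here are illustrative, not from BLIP):

```python
import torch
import torch.nn.functional as F

# Toy stand-in for the captioner's output: one row of vocabulary
# scores per position i, predicting token w_i from the image and w_{1:i-1}.
torch.manual_seed(0)
seq_len, vocab_size = 4, 10
logits = torch.randn(seq_len, vocab_size)  # hypothetical per-token logits
tokens = torch.tensor([3, 1, 7, 2])        # the caption tokens w_1..w_L

# log prod_i p(w_i | I, w_{1:i-1}) = sum_i log p(w_i | ...),
# so L_cap is the mean of the per-token negative log-probabilities.
log_probs = F.log_softmax(logits, dim=-1)
manual_nll = -log_probs[torch.arange(seq_len), tokens].mean()

# cross_entropy computes exactly the same quantity in one call,
# which is what model(..., labels=...).loss does internally.
ce = F.cross_entropy(logits, tokens)
assert torch.allclose(manual_nll, ce)
```

So the model does not need a separate "next-token probability" API: feeding the full caption with teacher forcing yields all the conditional probabilities $p_C(w_i \mid \mathcal{I}, w_{1:i-1})$ at once.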