csxmli2016 / w-plus-adapter

[CVPR 2024] When StyleGAN Meets Stable Diffusion: a W+ Adapter for Personalized Image Generation
Other
114 stars 8 forks source link

'utilizing BLIP2 to obtain captions' #6

Open gaoyixuan111 opened 4 months ago

gaoyixuan111 commented 4 months ago

"Thank you very much for your excellent work. Could you please share the code or readme for 'utilizing BLIP2 to obtain captions'?"

csxmli2016 commented 4 months ago

"Thank you very much for your excellent work. Could you please share the code or readme for 'utilizing BLIP2 to obtain captions'?"

A simple example:

from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import os
import os.path as osp
from tqdm import tqdm

# Use the GPU when one is visible; the processor outputs are moved to this device later.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the model from the SAME checkpoint so that the
# tokenizer / image preprocessing configuration matches the weights.
blip2_checkpoint = "Salesforce/blip2-opt-2.7b"
processor = Blip2Processor.from_pretrained(blip2_checkpoint)
# float16 matches the dtype the inputs are cast to before generation.
# NOTE(review): load_in_8bit with device_map={"": 0} pins the model to GPU 0,
# so the "cpu" fallback above presumably cannot work with this call - confirm.
model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_checkpoint, load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
)

# Input folder holding the FFHQ in-the-wild images.
wild_path = './in-the-wild'
# Output folder; one caption text file is written here per image.
caption_path = './in-the-wild-Caption-BLIP2'
os.makedirs(caption_path, exist_ok=True)
# Visit the directory entries in sorted name order.
img_lists = sorted(os.listdir(wild_path))

# Caption every entry of wild_path with BLIP-2 and save each caption to
# "<image stem>.txt" inside caption_path. Entries that fail (unreadable or
# non-image files, generation errors) are reported and skipped.
for img_name in tqdm(img_lists):
    img_path = osp.join(wild_path, img_name)
    try:
        # Close the file handle promptly; convert() returns an independent
        # RGB copy, so grayscale / RGBA / palette images are handled as well.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(images=image, text="Question: Describe the image in one sentence, including facial attributes such as hairstyle, hair color, expression, dress style, dress color, as well as background details. Answer:", return_tensors="pt").to(device, torch.float16)

        generated_ids = model.generate(**inputs)
        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        print(caption)

        # splitext handles extensions of any length (".jpeg", ".webp", ...),
        # unlike slicing off a fixed four characters.
        stem = osp.splitext(img_name)[0]
        with open(osp.join(caption_path, stem + '.txt'), 'w', encoding='utf-8') as f:
            f.write(caption + '\n')
    except Exception as err:
        # Best effort: report which file failed and why, then move on.
        # KeyboardInterrupt / SystemExit are deliberately not caught here.
        print(f"{img_name}: {err}")