alibaba / AliceMind

ALIbaba's Collection of Encoder-decoders from MinD (Machine IntelligeNce of Damo) Lab
Apache License 2.0

Inference of image captioning on single image #78

Closed ChandanVerma closed 1 year ago

ChandanVerma commented 1 year ago

How do I use mPLUG to generate image captions for a single image?

huuquan1994 commented 1 year ago

@ChandanVerma Here is an example of how to load the mPLUG-Large-V2 model

Load mPLUG from a checkpoint

import torch
import torch.nn as nn
import yaml
from models.model_caption_mplug import MPLUG
from models.tokenization_bert import BertTokenizer
from models.vit import interpolate_pos_embed, resize_pos_embed

# Load the config file
config_path = 'AliceMind/mPLUG/configs/caption_mplug_large.yaml'
config = yaml.load(open(config_path, 'r'), Loader=yaml.Loader)

config["min_length"] = 8
config["max_length"] = 25
config["add_object"] = True
config["beam_size"] = 5
config['text_encoder'] = 'bert-base-uncased'
config['text_decoder'] = 'bert-base-uncased'

# initialize the mPLUG-Large-V2 model
mplug_tokenizer = BertTokenizer.from_pretrained(config['text_encoder'])
mPLUG_model = MPLUG(config=config, tokenizer=mplug_tokenizer)

mPLUG_ckpt_path = '/path/to/mplug_large_v2.pth' # replace with the path to your checkpoint
checkpoint = torch.load(mPLUG_ckpt_path, map_location='cpu')
state_dict = checkpoint['model']

# reshape the positional embedding to accommodate the change in image resolution
# ref: https://github.com/alibaba/AliceMind/blob/main/mPLUG/caption_mplug.py#L227
if config["clip_name"] == "ViT-B-16":
    num_patches = int(config["image_res"] * config["image_res"]/(16*16))
elif config["clip_name"] == "ViT-L-14":
    num_patches = int(config["image_res"] * config["image_res"]/(14*14))

pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float())

pos_embed = resize_pos_embed(state_dict['visual_encoder.visual.positional_embedding'].unsqueeze(0),
                                         pos_embed.unsqueeze(0))
state_dict['visual_encoder.visual.positional_embedding'] = pos_embed

# load and move the mPLUG-Large model to GPU
mPLUG_model.load_state_dict(state_dict, strict=False); # adding ; to avoid printing long messages
mPLUG_model = mPLUG_model.to('cuda').eval();
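
The patch-count arithmetic in the snippet above can be checked on its own. What follows is a minimal sketch and not part of the mPLUG code base: `PATCH_SIZES` and `num_position_tokens` are hypothetical names, and the extra row is assumed to be the class token (matching the `num_patches + 1` above). Unlike the snippet, it rejects resolutions that are not a multiple of the patch size and raises for an unrecognized `clip_name`, a case in which the snippet above would leave `num_patches` undefined.

```python
# Hypothetical helper (not from the AliceMind repo): number of positional
# embedding rows needed for a square image at a given resolution.
PATCH_SIZES = {"ViT-B-16": 16, "ViT-L-14": 14}  # patch edge length in pixels


def num_position_tokens(clip_name: str, image_res: int) -> int:
    """Return the number of patches per image plus one class token."""
    try:
        patch = PATCH_SIZES[clip_name]
    except KeyError:
        raise ValueError(f"unsupported clip_name: {clip_name!r}") from None
    if image_res % patch != 0:
        # Assumption: only resolutions that tile evenly are accepted here.
        raise ValueError(f"image_res {image_res} is not a multiple of {patch}")
    grid = image_res // patch  # patches along one side of the image
    return grid * grid + 1     # +1 for the class token


print(num_position_tokens("ViT-B-16", 224))  # prints 197 (14 * 14 + 1)
print(num_position_tokens("ViT-L-14", 336))  # prints 577 (24 * 24 + 1)
```

The returned value corresponds to the first dimension of the zero tensor passed to `resize_pos_embed` in the snippet above.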

Inference on a single image

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# define preprocess function
mPLUG_transform = transforms.Compose([
    transforms.Resize((config['image_res'], config['image_res']), 
                      interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                         (0.26862954, 0.26130258, 0.27577711)),
])

# load image
img_path = 'dog.jpg' # path to your image here
image = Image.open(img_path).convert('RGB')

# run inference
mplug_image = mPLUG_transform(image).unsqueeze(0) # [1, 3, image_res, image_res]
mplug_image = mplug_image.to('cuda')

with torch.no_grad():  # gradients are not needed at inference time
    topk_ids, topk_probs = mPLUG_model(mplug_image, None, train=False)

output_caption = mplug_tokenizer.decode(topk_ids[0][0]).replace("[SEP]", "").replace("[CLS]", "").replace("[PAD]", "").strip()

print(output_caption)
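
The chained `.replace()` calls used to clean the decoded caption can be folded into a small helper. This is only a sketch: `clean_caption` is a hypothetical name, not a function from the repository, and it assumes the only markers to remove are `[CLS]`, `[SEP]`, and `[PAD]`. Unlike the chained replaces, it also collapses runs of whitespace left behind by the removed tokens.

```python
import re

# Special tokens of the BERT vocabulary that the snippet above removes by hand.
_SPECIAL = re.compile(r"\[(?:CLS|SEP|PAD)\]")


def clean_caption(decoded: str) -> str:
    """Remove [CLS]/[SEP]/[PAD] markers and collapse leftover whitespace."""
    text = _SPECIAL.sub(" ", decoded)
    return " ".join(text.split())


print(clean_caption("[CLS] a dog running on the grass [SEP] [PAD] [PAD]"))
# prints: a dog running on the grass
```

With this helper, the decoding step would read `clean_caption(mplug_tokenizer.decode(topk_ids[0][0]))`.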

ChandanVerma commented 1 year ago

Hi @huuquan1994, is it possible to generate video captions from a single video file as well? Thanks in advance.

huuquan1994 commented 1 year ago

@ChandanVerma I don't have time to test on videos, but you can refer to https://github.com/alibaba/AliceMind/blob/main/mPLUG/videocap_mplug.py to see how to load both the videos and the pre-trained model. Also take a look at the video dataset class that is used to load video inputs (here: https://github.com/alibaba/AliceMind/blob/main/mPLUG/dataset/video_dataset.py).
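
For orientation only: video captioning pipelines commonly start by picking a fixed number of frames spread evenly over the clip. The sketch below illustrates that idea. It is untested against mPLUG, `uniform_frame_indices` and its parameters are hypothetical names, and the sampling strategy in `video_dataset.py` may well differ, so check that file before relying on this.

```python
from typing import List


def uniform_frame_indices(total_frames: int, num_samples: int) -> List[int]:
    """Pick num_samples frame indices spread evenly over a clip.

    The clip is split into num_samples equal segments and the frame at the
    midpoint of each segment is selected.
    """
    if total_frames <= 0 or num_samples <= 0:
        raise ValueError("total_frames and num_samples must be positive")
    segment = total_frames / num_samples
    return [
        min(total_frames - 1, int(segment * (i + 0.5)))
        for i in range(num_samples)
    ]


print(uniform_frame_indices(100, 4))  # prints [12, 37, 62, 87]
```

Each selected frame would then need the same kind of preprocessing as a single image; how the frames are batched and passed to the model is defined in `videocap_mplug.py`.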

ChandanVerma commented 1 year ago

Sure, thanks for the heads-up :+1: