AILab-CVC / SEED

Official implementation of SEED-LLaMA (ICLR 2024).
https://ailab-cvc.github.io/seed

How to force the model to generate an image? #17

Open haochuan-li opened 6 months ago

haochuan-li commented 6 months ago

Hi! Great work.

[Image not preserved]

I see there's a "force image generation" option in the Gradio demo. How can I implement this in code? Could anyone point me in the right direction?

Thanks.

sijeh commented 5 months ago

Sorry for the late reply. Force image generation can be achieved by manually adding the BOI (begin-of-image) token to the input, so that the model continues by generating image tokens. The code can be found at the following link: https://github.com/AILab-CVC/SEED/blob/369b3ebd9727932d3f67f730bc86bcdf02312a77/gradio_demo/seed_llama_flask.py#L159
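A minimal sketch of the idea, working on plain text. It assumes SEED-LLaMA's special tokens are the strings `<img>` (BOI), `</img>` (EOI) and `<img_{:05d}>` (one token per discrete image code), and it assumes a `[INST] ... [/INST]` instruction template; please check the linked `seed_llama_flask.py` for the exact strings and template. The helpers `build_prompt` and `extract_image_codes` are hypothetical names. Appending BOI to the prompt means the model's very next tokens are image tokens; the extracted codes would then go to the SEED tokenizer's diffusion decoder (not shown).

```python
import re
from typing import List

# Assumed special-token strings; verify against the SEED-LLaMA repo before use.
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN_PATTERN = re.compile(r'<img_(\d{5})>')


def build_prompt(user_text: str, force_image_gen: bool) -> str:
    """Wrap the user text in an (assumed) instruction template.

    When force_image_gen is True, the BOI token is appended so that the
    model's continuation starts inside an image span.
    """
    prompt = f'[INST] {user_text} [/INST]\n'
    if force_image_gen:
        prompt += BOI_TOKEN
    return prompt


def extract_image_codes(generated_text: str) -> List[int]:
    """Return the discrete image codes from the first image span in the output.

    If BOI was already part of the prompt, the generated text may not contain
    it, so the span then runs from the start of the text up to EOI.
    """
    start = generated_text.find(BOI_TOKEN)
    start = 0 if start == -1 else start + len(BOI_TOKEN)
    end = generated_text.find(EOI_TOKEN, start)
    span = generated_text[start:] if end == -1 else generated_text[start:end]
    return [int(code) for code in IMG_TOKEN_PATTERN.findall(span)]
```

For example, `build_prompt('Draw a cat.', force_image_gen=True)` ends with `<img>`, and `extract_image_codes('<img_00012><img_00345></img>')` gives `[12, 345]`.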

haochuan-li commented 5 months ago

Thanks for the reply!

@sijeh, I have another question, about the zero-shot retrieval evaluation. I cannot reproduce the Table 1 results in the SEED-LLaMA paper.

Here's my code for preparing the text and image embeddings for Flickr30k:

"""
Setting: SEED-LLaMA Tokenizer 2
`captions` (5000 test captions), `imgs_gt` (1000 test images), `transform`,
`TextDataset` and `args` are defined elsewhere in my script.
"""
import hydra
import torch
from omegaconf import OmegaConf
from lavis.models import load_model
from torch.utils.data import DataLoader
from tqdm import tqdm

device = 'cuda'

tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer_hf.yaml'
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
seed_tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=False)

"""Prepare Flickr30k text embeddings, following BLIP-2 retrieval"""
blip2_model = load_model("blip2", "pretrain")
blip2_model.eval().to(device)

text_emb = []
blip_text = blip2_model.tokenizer(captions, padding='max_length', truncation=True, max_length=32, return_tensors='pt')

blip_dataset = TextDataset(blip_text)
blip_dataloader = DataLoader(blip_dataset,
                             shuffle=False,
                             drop_last=False,
                             num_workers=8,
                             pin_memory=True,
                             batch_size=args.batch_size)

with torch.no_grad():
    for (input_ids, attention_mask) in tqdm(blip_dataloader, desc='text', unit='text'):
        # [CLS] hidden state of the Q-Former text encoder
        qformer_output = blip2_model.Qformer.bert(input_ids.to(device), attention_mask=attention_mask.to(device), return_dict=True).last_hidden_state[:, 0, :]
        text_emb.append(qformer_output.cpu())
text_emb = torch.cat(text_emb)  # text embeddings for retrieval, shape=[5000, 768]

"""Prepare Flickr30k image embeddings"""
causal_code_pt = []
causal_emb_pt = []
for im in tqdm(imgs_gt, desc="tokenizing img", unit='img'):
    _, causal_code, causal_emb = seed_tokenizer.encode_image(image_torch=transform(im).to(device))
    causal_code_pt.append(causal_code[0][-1].squeeze())  # output of the last query
    causal_emb_pt.append(causal_emb[0][-1].squeeze())  # output of the last query

causal_code_pt = torch.stack(causal_code_pt)  # causal code for retrieval, shape=[1000, 768]
causal_emb_pt = torch.stack(causal_emb_pt)  # causal embedding for retrieval, shape=[1000, 768]

"""
How the causal code and causal embedding are obtained:
I modified get_codebook_indices in models/seed_qformer/qformer_quantizer.py as follows.
"""

def get_codebook_indices(self, image):
    with torch.no_grad():
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))  # [1, 257, 1408]
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        # Causal Q-Former: the 32 learned queries cross-attend to the ViT features
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        # query_output.last_hidden_state: [1, 32, 768]  (causal embedding)
        # query_output_down:              [1, 32, 32]   (quantizer input)
        # query_output_up:                [1, 32, 768]  (causal code, decoded from the quantized vectors)

        query_output_down = self.encode_task_layer(query_output.last_hidden_state)
        quant, loss_embed, embed_ind = self.quantize(query_output_down)
        embed_ind = embed_ind.reshape(quant.shape[0], -1)  # discrete codes, [1, 32]

        query_output_up = self.decode_task_layer(quant)
    # Modified to also return the causal code and the causal embedding
    return embed_ind, query_output_up, query_output.last_hidden_state

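"""Compute similarity matrices (cosine similarity)"""
# Use the stacked tensors (causal_code_pt / causal_emb_pt), not the loop variables,
# move them to the CPU where text_emb lives, and L2-normalize both sides.
text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
causal_code_pt = causal_code_pt.float().cpu()
causal_emb_pt = causal_emb_pt.float().cpu()
causal_code_pt /= causal_code_pt.norm(dim=-1, keepdim=True)
causal_emb_pt /= causal_emb_pt.norm(dim=-1, keepdim=True)

blip_causal_code_sim = text_emb @ causal_code_pt.T  # [5000, 1000]
blip_causal_emb_sim = text_emb @ causal_emb_pt.T  # [5000, 1000]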
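To compare against the paper, the similarity matrices still have to be turned into retrieval metrics; Flickr30K retrieval is usually reported as recall@K (R@1/5/10) in both directions, and I am assuming Table 1 uses the same metrics. The sketch below assumes the standard 1K test split with 5 captions per image, stored in image order so that caption `t` belongs to image `t // 5`, and a `[num_texts, num_images]` matrix such as `blip_causal_code_sim`. The function name `retrieval_recall` is hypothetical.

```python
import numpy as np


def retrieval_recall(sim, caps_per_image=5, ks=(1, 5, 10)):
    """Recall@K (in %) for text->image and image->text retrieval.

    sim: [num_texts, num_images] similarity matrix. Caption t is assumed
    to describe image t // caps_per_image (captions stored in image order).
    """
    sim = np.asarray(sim)
    num_texts, num_images = sim.shape
    if num_texts != num_images * caps_per_image:
        raise ValueError('expected caps_per_image captions per image')
    gt_image = np.arange(num_texts) // caps_per_image       # [T]

    # Text -> image: position of the single ground-truth image in each ranking.
    order_t2i = np.argsort(-sim, axis=1)                    # [T, I], best first
    t2i_rank = np.argmax(order_t2i == gt_image[:, None], axis=1)

    # Image -> text: position of the best-ranked of the image's 5 captions.
    order_i2t = np.argsort(-sim.T, axis=1)                  # [I, T], best first
    hit = (order_i2t // caps_per_image) == np.arange(num_images)[:, None]
    i2t_rank = np.argmax(hit, axis=1)                       # first True per row

    recalls = {}
    for k in ks:
        recalls[f't2i_r{k}'] = 100.0 * float(np.mean(t2i_rank < k))
        recalls[f'i2t_r{k}'] = 100.0 * float(np.mean(i2t_rank < k))
    return recalls
```

For example, `retrieval_recall(blip_causal_code_sim.cpu().numpy())` returns `t2i_r1` ... `i2t_r10` as percentages. Ties in `sim` are broken arbitrarily by `argsort`, which only matters for degenerate inputs.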

Results in paper

[Image not preserved: Table 1 results from the SEED-LLaMA paper]

Reproduced Results

[Image not preserved: reproduced retrieval results]

Question

[Image not preserved]

I'm not sure whether this is the right way to obtain the text and image embeddings described in the SEED-LLaMA paper. Please correct me if I'm wrong.
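For comparison: as far as I can tell from LAVIS's `Blip2Qformer`, BLIP-2's own retrieval does not compare raw 768-d hidden states. It projects the text [CLS] state with `text_proj` and every query output with `vision_proj` into a shared 256-d space, L2-normalizes both, and scores an image-text pair by the maximum similarity over the 32 queries. Whether SEED's evaluation uses that recipe or only the final causal query is not confirmed here. The numpy sketch below only contrasts the two scoring rules on features that are assumed to be already projected; it does not call LAVIS, and the function names are hypothetical.

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Scale vectors along `axis` to unit L2 norm."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def sim_last_query(text_feat: np.ndarray, query_feat: np.ndarray) -> np.ndarray:
    """Cosine similarity against only the last query output of each image.

    text_feat: [T, D], query_feat: [I, Q, D] -> returns [T, I].
    """
    return l2_normalize(text_feat) @ l2_normalize(query_feat[:, -1, :]).T


def sim_max_over_queries(text_feat: np.ndarray, query_feat: np.ndarray) -> np.ndarray:
    """BLIP-2-style scoring: max cosine similarity over all Q query outputs.

    text_feat: [T, D], query_feat: [I, Q, D] -> returns [T, I].
    """
    t = l2_normalize(text_feat)                   # [T, D]
    q = l2_normalize(query_feat)                  # [I, Q, D]
    per_query = np.einsum('td,iqd->tiq', t, q)    # [T, I, Q]
    return per_query.max(axis=-1)
```

By construction the max-over-queries score is never lower than the last-query score, and the two coincide when there is a single query.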

Looking forward to your reply.

Thanks