pliang279 / HEMM

Holistic evaluation of multimodal foundation models
MIT License

Tiny issues #13

Closed TsuTikgiau closed 4 months ago

TsuTikgiau commented 1 year ago

Hello Team, below are some issues I found so far:

  1. The evaluation of MiniGPT4 on the newyorkercartoon dataset is very slow (more than 90 sec/it) on my A100. Is this expected?
  2. The evaluation currently runs with a batch size of 1. We should make it work for batched evaluation.
  3. Tiny bugs:
    • L48 in hemm/data/hateful_memes_dataset.py should be image_dir = os.path.join(self.dataset_dir, 'data')
    • The package openpyxl should be included in requirements.txt
    • L69 in hemm/data/scienceQA_dataset.py raises "UnboundLocalError: local variable 'random_item' referenced before assignment" (a minimal illustration of this failure mode is sketched below)
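
For the third bug, this is the classic pattern where a local is only assigned inside a branch that may never run. The snippet below is a hypothetical illustration of the failure mode and one way to avoid it; it is not the actual scienceQA_dataset.py code:

# Hypothetical illustration only (not the actual scienceQA_dataset.py code):
# 'random_item' is assigned only inside the loop body, so the final
# reference raises UnboundLocalError whenever no element matches.
import random

def pick_item_buggy(items):
    for item in items:
        if item.get("image") is not None:
            random_item = item
            break
    return random_item  # UnboundLocalError if the loop never assigned it

# One fix: initialize the variable before the conditional logic.
def pick_item_fixed(items):
    random_item = None
    candidates = [it for it in items if it.get("image") is not None]
    if candidates:
        random_item = random.choice(candidates)
    return random_item
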
talha1503 commented 1 year ago
  1. Yes, it is showing me between 70 s/it and 110 s/it on Google Colab Pro+ with an A100 GPU.
  2. I have fixed those bugs in the latest commit.
TsuTikgiau commented 1 year ago

Hello Team, thanks for the update! I noticed that the current version doesn't evaluate the data in a batched style yet. Here is an example previously written by a friend of mine for this; we could include something similar:

import torch
from PIL import Image
from transformers import StoppingCriteriaList

# CONV_VISION and StoppingCriteriaSub are assumed to come from the MiniGPT-4
# codebase (e.g. minigpt4.conversation.conversation).

class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor

        self.conv = CONV_VISION.copy()
        self.img_list = []
        self.raw_answers = []

        # stop_words_ids = [torch.tensor([835]).to(self.device),
        #                   torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        stop_words_ids = [torch.tensor([2]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def reset(self):
        self.conv.messages = []
        self.img_list = []
        # self.img_list = [img for img in self.conv.system_img]
        self.raw_answers = []

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)
            # conv.append_message(None, text)

    def answer(self, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1):
        self.conv.append_message(self.conv.roles[1], None)
        embs = self.get_context_emb(self.conv, self.img_list)
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            do_sample=False,
        )
        output_token = outputs[0]
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        self.raw_answers.append(output_text)
        output_text = output_text.split('</s>')[0]  # remove the stop token '</s>'
        output_text = output_text.replace("<s>","")
        output_text = output_text.split(r'[/INST]')[-1].strip()
        self.conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image):
        if isinstance(image, str):  # is a image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        self.img_list.append(image_emb)
        self.conv.append_message(self.conv.roles[0], "<Img><ImageHere></Img>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        try:
            # path when llama_model is wrapped (e.g. by PEFT/LoRA)
            seg_embs = [self.model.llama_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        except AttributeError:
            seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

chat = Chat(model, vis_processor)
chat.conv.system = ""

def BatchGeneration(images, texts, max_new_tokens=10, num_beams=1):
        convs = [CONV_VISION.copy() for _ in range(len(texts))]
        [chat.ask('<Img><ImageHere></Img> {} '.format(text), conv) for conv, text in zip(convs, texts)]
        [conv.append_message(conv.roles[1], None) for conv in convs]
        # [conv.append_message(None, None) for conv in convs]

        with torch.no_grad():
            image_embs, _ = chat.model.encode_img(images.to(chat.device).half())
        image_lists = [[image_emb[None]] for image_emb in image_embs]

        batch_embs = [chat.get_context_emb(conv, img_list) for conv, img_list in zip(convs, image_lists)]    

        batch_size = len(batch_embs)
        max_len = max([emb.shape[1] for emb in batch_embs])
        emb_dim = batch_embs[0].shape[2]
        dtype = batch_embs[0].dtype
        device = batch_embs[0].device

        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
        for i, emb in enumerate(batch_embs):
            emb_len = emb.shape[1]
            embs[i, -emb_len:] = emb[0]
            attn_mask[i, -emb_len:] = 1

    #     outputs = chat.emb_generate(embs, max_new_tokens=20, attention_mask=attn_mask)
        with torch.no_grad():
            outputs = chat.model.llama_model.generate(
                        inputs_embeds=embs,
                        max_new_tokens=max_new_tokens,
                        attention_mask=attn_mask,
                        num_beams=num_beams,
                        do_sample=False,
                )
        answers = []
        for output_token in outputs:
            if output_token[0] == 0:
                output_token = output_token[1:]
            output_texts = chat.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
            output_texts = output_texts.split('</s>')[0]  # remove the stop token '</s>'
            output_texts = output_texts.replace("<s>","")
            output_texts = output_texts.split(r'[/INST]')[-1].strip()
            answers.append(output_texts)

        return answers
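
A hypothetical driver loop for the function above; eval_dataloader and its (images, texts) batch format are assumptions for illustration, not HEMM's actual loader interface:

# Illustrative driver only: eval_dataloader and its (images, texts) batch
# format are assumptions, not HEMM's actual loader interface.
all_answers = []
for images, texts in eval_dataloader:  # images: [B, 3, H, W] tensor, texts: list of prompts
    answers = BatchGeneration(images, texts, max_new_tokens=10)
    all_answers.extend(answers)
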
akshayg08 commented 1 year ago

@TsuTikgiau Thank you for the example, we will try to incorporate this into our evaluation datasets.

akshayg08 commented 1 year ago

Hi @TsuTikgiau,

In the current setup, we provide the models with raw images and the corresponding prompts. Since the prompts have different lengths, creating a batch is difficult. One solution is to preprocess the images and prompts using the model's image and text feature extractors. However, some of the models might not provide these extractors directly. For example, I am working on the BLIP-2 T5 model, and using their extractors is very tricky. I have gone through their code, and it seems like their model handles the batching internally given preprocessed images and a list of prompts. However, this may not be true for all the models that we are evaluating.
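
For reference, a minimal sketch of batched generation through the LAVIS blip2_t5 interface, assuming the standard load_model_and_preprocess entry point; the checkpoint name, image paths, and prompts are illustrative, not HEMM's actual configuration:

# Minimal sketch, assuming the LAVIS blip2_t5 interface; checkpoint name,
# image paths, and prompts are illustrative only.
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

device = "cuda" if torch.cuda.is_available() else "cpu"
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_t5", model_type="pretrain_flant5xl", is_eval=True, device=device
)

raw_images = [Image.open(p).convert("RGB") for p in ["a.jpg", "b.jpg"]]
prompts = ["Question: what is shown? Answer:",
           "Question: what color is it? Answer:"]

# Stack the preprocessed images into one tensor; LAVIS batches the prompts internally.
images = torch.stack([vis_processors["eval"](img) for img in raw_images]).to(device)
answers = model.generate({"image": images, "prompt": prompts}, max_length=30)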

TsuTikgiau commented 1 year ago

Hello team, the evaluation of some big datasets can take 30+ hours on an A100, but this could be reduced to around 30 minutes with a batch size of 64. Thus, I suggest including batched inference for the models where it is not too hard to add.

akshayg08 commented 1 year ago

@TsuTikgiau Hi, I have tried batched inference for the BLIP2-based models. Since the BLIP2 models handle the batching internally, I have kept the previous dataset definition and loaded the data with batch size 1, because with a bigger batch size, batching the prompts would have been difficult. After collecting all the data through the loader, I create the batches and pass them to the model's generate function. I have done this for the Slake dataset. You can find the dataset and the corresponding loader here - https://github.com/pliang279/HEMM/blob/main/hemm/data/slake_dataset.py https://github.com/pliang279/HEMM/blob/main/hemm/data/slake_loader.py

On Colab Pro+ with an A100 GPU, the runtime with batch size 1 is around 4 minutes, but with a batch size of 16 it is about 1.5 minutes. A bigger batch size was causing an out-of-memory error on the GPU.
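
A rough sketch of that collect-then-chunk pattern; the batch-size-1 loader, its item format, and the {"image": ..., "prompt": ...} batch dict are assumptions for illustration, not HEMM's exact interfaces:

# Rough sketch of the collect-then-chunk approach described above.
# 'loader' (batch size 1), its (image, prompt, label) items, and the batch
# dict passed to generate are assumptions for illustration.
import torch

images, prompts = [], []
for image, prompt, _label in loader:          # batch-size-1 loader
    images.append(image.squeeze(0))
    prompts.append(prompt[0])

batch_size = 16
predictions = []
for i in range(0, len(images), batch_size):
    image_batch = torch.stack(images[i:i + batch_size]).to(device)
    prompt_batch = prompts[i:i + batch_size]
    with torch.no_grad():
        predictions.extend(model.generate({"image": image_batch, "prompt": prompt_batch}))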

I believe that if the prompts could be batched, then loading all the samples with batch size 1 would be unnecessary. But in order to create a batch of prompts, I would need to do some tokenization on my end and define padding tokens so that all the prompts in a batch have the same length. Since we want the model to handle all the tokenization, I don't think this is a good option.
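
For completeness, manually batching the prompts would look roughly like the sketch below with a standard Hugging Face tokenizer; this is the option argued against above, and the checkpoint name is illustrative only:

# Sketch of manual prompt batching with a Hugging Face tokenizer, shown only
# to illustrate the padding work that would be needed on our side.
from transformers import AutoTokenizer

# "google/flan-t5-xl" is an illustrative checkpoint, not HEMM's actual model.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")

prompts = ["Question: what is shown? Answer:",
           "Question: is the lesion benign or malignant? Answer:"]
batch = tokenizer(prompts, padding=True, return_tensors="pt")
# batch.input_ids and batch.attention_mask now share one padded length;
# decoder-only models would additionally need tokenizer.padding_side = "left".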

For other models, I will have to check how this can be done.