Closed TsuTikgiau closed 4 months ago
Hello team, thanks for the update! I noticed that the current version doesn't evaluate the data in batches yet. Here is an example function that a friend of mine wrote previously for this; we could include something similar:
class Chat:
    """Single-conversation chat wrapper around a vision-language model.

    Keeps one conversation (``self.conv``), the image embeddings uploaded into
    it (``self.img_list``) and the decoded model outputs before clean-up
    (``self.raw_answers``). ``ask`` and ``get_context_emb`` also take an
    explicit conversation so they can be reused for batched evaluation.
    """

    def __init__(self, model, vis_processor, device='cuda:0'):
        """Store the model/processor and start an empty conversation.

        Args:
            model: Wrapper exposing ``llama_model``, ``llama_tokenizer`` and
                ``encode_img`` (the attributes the methods below use).
            vis_processor: Callable applied to a PIL image; its result must
                support ``unsqueeze`` and ``to``.
            device: Torch device that model inputs are moved to.
        """
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        self.conv = CONV_VISION.copy()
        self.img_list = []
        self.raw_answers = []
        # Stop generation at token id 2 -- presumably the LLaMA EOS token
        # ``</s>`` (``answer`` cuts the text at '</s>'); TODO confirm for the
        # tokenizer in use.
        stop_words_ids = [torch.tensor([2]).to(self.device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def reset(self):
        """Clear the conversation messages, uploaded images and raw answers."""
        self.conv.messages = []
        self.img_list = []
        self.raw_answers = []

    def ask(self, text, conv):
        """Add a user turn containing ``text`` to ``conv``.

        If the last message is a user turn ending with ``'</Img>'`` (an image
        placeholder), ``text`` is appended to it, space-separated, so the image
        and the question form a single turn. Otherwise a new user message is
        added.
        """
        last = conv.messages[-1] if conv.messages else None
        if last is not None and last[0] == conv.roles[0] and last[1].endswith('</Img>'):
            last[1] = ' '.join([last[1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1):
        """Generate the assistant reply for ``self.conv`` and record it.

        The decoded text before clean-up is appended to ``self.raw_answers``.
        The cleaned text is stored as the content of the new assistant turn.

        Returns:
            ``(text, tokens)``: the cleaned reply and the generated token ids
            as a NumPy array (with a leading id 0 removed, if present).
        """
        self.conv.append_message(self.conv.roles[1], None)
        # get_context_emb needs the conversation and image list explicitly;
        # calling it without arguments raised TypeError.
        embs = self.get_context_emb(self.conv, self.img_list)
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            do_sample=False,
        )
        output_token = outputs[0]
        # Drop a leading id 0 (presumably <unk>/pad in the LLaMA vocab --
        # TODO confirm) before decoding.
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        self.raw_answers.append(output_text)
        output_text = self._clean_output(output_text)
        self.conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    @staticmethod
    def _clean_output(text):
        """Keep text before the first '</s>', drop '<s>', and keep only what follows the last '[/INST]'."""
        text = text.split('</s>')[0]
        text = text.replace("<s>", "")
        return text.split('[/INST]')[-1].strip()

    def upload_img(self, image):
        """Encode ``image`` and add it to the conversation as a user turn.

        Args:
            image: An image file path (opened and converted to RGB), a
                ``PIL.Image.Image``, or a ``torch.Tensor`` (a 3-D tensor gets a
                leading batch dimension).

        Returns:
            The acknowledgement string ``"Received."``.

        Raises:
            TypeError: If ``image`` is none of the supported types.
        """
        if isinstance(image, str):  # an image path
            image = Image.open(image).convert('RGB')
        if isinstance(image, Image.Image):
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)
        else:
            raise TypeError(
                "image must be a path, PIL.Image.Image or torch.Tensor, "
                "got {}".format(type(image).__name__))
        image_emb, _ = self.model.encode_img(image)
        self.img_list.append(image_emb)
        self.conv.append_message(self.conv.roles[0], "<Img><ImageHere></Img>")
        return "Received."

    def get_context_emb(self, conv, img_list):
        """Build input embeddings for ``conv`` with image embeddings spliced in.

        The prompt from ``conv.get_prompt()`` is split on ``'<ImageHere>'``.
        Each text segment is tokenized and embedded, and the i-th entry of
        ``img_list`` is inserted between segments i and i+1. Everything is
        concatenated along dim 1.

        Raises:
            ValueError: If the number of placeholders and images differ.
        """
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        if len(prompt_segs) != len(img_list) + 1:
            raise ValueError("Unmatched numbers of image placeholders and images.")
        seg_tokens = [
            # Only the first segment gets the tokenizer's special (BOS) tokens.
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        # The embedding layer lives under a deeper attribute path when the
        # model is wrapped (presumably by PEFT/LoRA -- TODO confirm). Only the
        # lookup is guarded, so real errors from the embedding call propagate.
        try:
            embed_tokens = self.model.llama_model.base_model.model.model.embed_tokens
        except AttributeError:
            embed_tokens = self.model.llama_model.model.embed_tokens
        seg_embs = [embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        return torch.cat(mixed_embs, dim=1)
# Shared module-level Chat instance; BatchGeneration below uses it for
# ask/encode_img/get_context_emb/generate.
chat = Chat(model, vis_processor)
# Blank the `system` field of chat.conv (presumably the system prompt).
# NOTE(review): BatchGeneration builds fresh CONV_VISION.copy() conversations,
# so this reset does not apply to batched prompts -- confirm that is intended.
chat.conv.system = ""
def BatchGeneration(images, texts, max_new_tokens=10, num_beams=None):
    """Answer a batch of image/question pairs with a single ``generate`` call.

    Relies on the module-level ``chat`` instance and ``CONV_VISION`` template,
    and on ``args`` when ``num_beams`` is not given.

    Args:
        images: Batch of preprocessed images, indexed along dim 0 in the same
            order as ``texts``. It is moved to ``chat.device`` and cast to half
            precision before encoding.
        texts: Question strings, one per image.
        max_new_tokens: Maximum number of new tokens per answer. This value was
            previously ignored in favor of ``args.max_new_tokens``.
        num_beams: Beam count for generation. Defaults to ``args.num_beams``.

    Returns:
        List of cleaned answer strings, one per generated sequence.
    """
    if num_beams is None:
        num_beams = args.num_beams

    # One fresh conversation per sample. Sizing from the actual batch (not the
    # dataloader's nominal batch size) keeps a short final batch correct.
    # NOTE(review): these copies do not inherit chat.conv.system -- confirm.
    convs = [CONV_VISION.copy() for _ in texts]
    for conv, text in zip(convs, texts):
        chat.ask('<Img><ImageHere></Img> {} '.format(text), conv)
        # Empty assistant turn (content None), as Chat.answer does.
        conv.append_message(conv.roles[1], None)

    with torch.no_grad():
        image_embs, _ = chat.model.encode_img(images.to(chat.device).half())
    # Restore a leading batch dim on each sample's image embedding.
    image_lists = [[image_emb[None]] for image_emb in image_embs]
    batch_embs = [chat.get_context_emb(conv, img_list)
                  for conv, img_list in zip(convs, image_lists)]

    # Left-pad: each sequence is right-aligned in a zero tensor, and the
    # attention mask is 1 only over the real positions.
    batch_size = len(batch_embs)
    max_len = max(emb.shape[1] for emb in batch_embs)
    emb_dim = batch_embs[0].shape[2]
    dtype = batch_embs[0].dtype
    device = batch_embs[0].device
    embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
    attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
    for i, emb in enumerate(batch_embs):
        emb_len = emb.shape[1]
        embs[i, -emb_len:] = emb[0]
        attn_mask[i, -emb_len:] = 1

    with torch.no_grad():
        outputs = chat.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            attention_mask=attn_mask,
            num_beams=num_beams,
            do_sample=False,
        )

    answers = []
    for output_token in outputs:
        # Drop a leading id 0 (presumably <unk>/pad -- TODO confirm).
        if output_token[0] == 0:
            output_token = output_token[1:]
        output_text = chat.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
        # Keep text before the first '</s>', drop '<s>', and keep only what
        # follows the last '[/INST]'.
        output_text = output_text.split('</s>')[0]
        output_text = output_text.replace("<s>", "")
        output_text = output_text.split('[/INST]')[-1].strip()
        answers.append(output_text)
    return answers
@TsuTikgiau Thank you for the example; we will try to incorporate this into our evaluation of the datasets.
Hi @TsuTikgiau,
In the current setup, we provide the models with raw images and their corresponding prompts. Since the prompts have different lengths, creating a batch is difficult. One solution is to preprocess the images and prompts with each model's image and text feature extractors. However, some models do not expose these extractors directly. For example, I am working on the BLIP-2 T5 model, and using its extractors is very tricky. Having gone through its code, it seems that the model does the batching itself when given preprocessed images and a list of prompts. However, this may not be true for all the models we are evaluating.
Hello team, evaluating some of the large datasets can take 30+ hours on an A100, but this can drop to about 30 minutes with a batch size of 64. I therefore suggest adding batched inference for the models where it is not too hard to support.
@TsuTikgiau Hi, I have tried batch inference for the BLIP-2-based models. Since the BLIP-2 models handle batching internally, I kept the previous dataset definition and loaded it with batch size 1, because batching the prompts with a larger batch size would have been difficult. After collecting all the data through the loader, I built the batches myself and passed them to the model for generation. I have done this for the Slake dataset; you can find the dataset and the corresponding loader here: https://github.com/pliang279/HEMM/blob/main/hemm/data/slake_dataset.py and https://github.com/pliang279/HEMM/blob/main/hemm/data/slake_loader.py
On Colab Pro+ with an A100 GPU, the runtime is around 4 minutes with a batch size of 1 but 1.5 minutes with a batch size of 16. Larger batch sizes caused out-of-memory errors on the GPU.
I believe that if the prompts could be batched, loading all the samples with batch size 1 would be unnecessary. However, to batch the prompts I would need to do some tokenization on my end and define padding tokens so that all prompts in a batch have the same length. Since we want the model to handle all tokenization, I don't think this is a good option.
For the other models, I still need to check how this can be done.
Hello team, below are some issues I have found so far:
image_dir = os.path.join(self.dataset_dir, 'data')