Luodian / Otter

🦦 Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following and in-context learning ability.
https://otter-ntu.github.io/
MIT License

Comment about MPT-7B log-likelihood evaluation #253

Open · vishaal27 opened this issue 1 year ago

vishaal27 commented 1 year ago

Hey, I just observed something interesting while running some log-likelihood based evals on the MPT-7B based Otter, and wanted to share it with you to confirm whether you have observed something similar.

I adopted your implementation of the log-likelihood evals for classification tasks from here: https://github.com/Luodian/Otter/blob/93a0d5e2d6f114e70d6e327b975bce4671cea9bb/pipeline/eval/evaluate.py#L934

However, I noticed that this wasn't giving me the expected performance: the post-softmax probabilities were being scaled incorrectly due to some odd batching effects, so I was getting inconsistent predictions for different evaluation batch sizes. I also noticed that you mentioned in the eval readme that the MPT-7B evals were perhaps not what you were expecting.

The interesting observation was that MPT-7B handles the token probabilities oddly under batching: it assigns a very low probability to the last token before the padding tokens. To counter this, I simply filtered out that token's probability, and this simple fix produced consistent predictions across different batch sizes.
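To make the observation concrete, here is a minimal standalone sketch (not taken from the repo; `logits`, `input_ids`, and `attention_mask` are just placeholders for a batched, right-padded forward pass) of how one can read off the probability the model assigns to each sequence's last non-padded token:

```python
import torch

def prob_of_last_real_token(logits, input_ids, attention_mask):
    # logits: [B, T, V]; input_ids / attention_mask: [B, T], right-padded
    probs = torch.softmax(logits.float(), dim=-1)
    last_pos = attention_mask.sum(dim=1) - 1                      # position of the last non-pad token
    batch_idx = torch.arange(input_ids.shape[0], device=input_ids.device)
    # causal LM: the distribution over the token at position t comes from the logits at t - 1
    dist = probs[batch_idx, last_pos - 1, :]                      # [B, V]
    return dist[batch_idx, input_ids[batch_idx, last_pos]]        # [B]
```

With MPT-7B and batch sizes larger than 1, these values come out unreasonably small, which is how the issue shows up in the classification eval.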

So basically, the modification I made to your classification eval script (I only tested this without kv-caching) was to change this: https://github.com/Luodian/Otter/blob/93a0d5e2d6f114e70d6e327b975bce4671cea9bb/pipeline/eval/evaluate.py#L1155-L1161 to:

```python
logits = outputs.logits.detach().float()
probs = torch.softmax(logits, dim=-1)

# get the probability of the generated class-name tokens
gen_probs = probs[:, ctx_and_prompt_input_ids.shape[1] - 1 : _lang_x.shape[1], :]
gen_probs = torch.gather(gen_probs, 2, classname_tokens[:, :, None]).squeeze(-1).cpu()

# filter out the probs corresponding to the last token before the padding tokens,
# since these tokens get unreasonably low probs during batching;
# implemented as a simple hack that masks out very low probs
threshold = 1e-20
mask = gen_probs > threshold
gen_probs_masked = torch.where(mask, gen_probs, torch.tensor(1.0).to(gen_probs.device))
class_prob = torch.prod(gen_probs_masked, 1).numpy()
```
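As a side note on the design choice: instead of thresholding very low probabilities, one could mask the offending position explicitly from the padding boundary. A rough sketch, assuming right-padding and an `attention_mask` aligned with `_lang_x` (the exact offset depends on how the batch is built, so treat the indexing here as an assumption rather than the repo's convention):

```python
# absolute position of the last non-pad token in each sequence
last_real_pos = attention_mask.sum(dim=1) - 1
# its index within the class-name span covered by gen_probs
anomalous_idx = (last_real_pos - ctx_and_prompt_input_ids.shape[1]).cpu()
token_idx = torch.arange(gen_probs.shape[1])
keep = token_idx.unsqueeze(0) != anomalous_idx.unsqueeze(1)       # [B, num_class_tokens]
gen_probs_masked = torch.where(keep, gen_probs, torch.ones_like(gen_probs))
class_prob = torch.prod(gen_probs_masked, 1).numpy()
```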

This worked well for me on the classification tasks I am working with, and I also tested it across different batch sizes with consistent results. Please let me know if this is a known issue and whether you've encountered it before.
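For completeness, the batch-size consistency check I am referring to is just something like the following (`run_classification_eval`, `model`, and `eval_dataset` are hypothetical stand-ins for the eval loop above, returning one predicted class index per example):

```python
import numpy as np

# hypothetical wrapper around the classification eval loop above
preds_bs1 = np.asarray(run_classification_eval(model, eval_dataset, batch_size=1))
preds_bs8 = np.asarray(run_classification_eval(model, eval_dataset, batch_size=8))
# with the fix applied, the predictions no longer depend on the batch size
assert (preds_bs1 == preds_bs8).all()
```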