ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Enable custom masks as optional input for models for batch processing #1044

Open nath1295 opened 1 month ago

From a quick read of the code, it seems the attention mask is created on the fly during inference and then broadcast across every sequence in the batch in each layer (the batch size is normally one, but batch inference should not assume this). That is a problem for batch inference, because there is no way to mask the padding tokens when prompts have different lengths. The only option right now is to process the pad tokens like real tokens, but that changes the output. I just want to make sure my understanding here is correct.
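To illustrate why attending to pad tokens changes the output, here is a toy sketch in plain Python (no MLX involved). The `masked_softmax` helper and the score values are made up for illustration; it only shows that unmasked pad keys take attention weight away from the real tokens.

```python
import math


def masked_softmax(scores, allowed):
    """Softmax over `scores`, giving zero weight to keys where `allowed` is False."""
    exps = [math.exp(s) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]


# Attention scores of the last query against four keys.
# Hypothetical setup: the first two keys are left-padding tokens.
scores = [0.5, 0.5, 1.0, 2.0]

# Causal mask only: the last query sees every key, including the pads.
causal_only = masked_softmax(scores, [True, True, True, True])

# Causal mask plus padding mask: the pad keys are excluded.
with_padding_mask = masked_softmax(scores, [False, False, True, True])

print(causal_only)
print(with_padding_mask)
```

The weights on the two real tokens differ between the two cases, so the attention output (and therefore the logits) for the shorter prompt would differ as well.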

It would be useful to be able to pass in custom mask tensors, so that the prompt cache entries for padded tokens can simply be filled with zeros without affecting the output. I am not sure whether this is possible.
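A minimal sketch of what such a custom mask could contain, assuming left padding and a boolean convention where `True` means "query i may attend to key j". This is plain Python for clarity, not the mlx_lm API; `build_attention_mask` is a hypothetical helper, and letting pad positions attend to themselves is my own choice to avoid fully masked rows.

```python
def build_attention_mask(prompt_lens):
    """Build mask[b][i][j] for a left-padded batch.

    True means query position i may attend to key position j.
    """
    max_len = max(prompt_lens)
    masks = []
    for n in prompt_lens:
        pad = max_len - n  # number of left-padding positions in this row
        masks.append([
            [
                # Causal (j <= i) and not a pad key (j >= pad). The diagonal
                # is always allowed so that pad queries never end up with an
                # empty row (a fully masked row gives NaN after softmax).
                (j <= i and j >= pad) or i == j
                for j in range(max_len)
            ]
            for i in range(max_len)
        ])
    return masks


# Two prompts with 2 and 4 tokens, padded to length 4.
mask = build_attention_mask([2, 4])
for row in mask[0]:
    print(row)
```

Real tokens never attend to a pad key under this mask, so the pad entries in the cache could hold any values without changing the result for the real positions.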

A simple example use case:

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache
import mlx.core as mx

model, tokenizer = load('my/model/path')
pad_token_id = tokenizer.bos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id

# Get the lists of token ids for all the prompts
prompts = [
    'The weather is nice out there',
    'The weather is awful out there, and this is a longer prompt'
]
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]

# Get the masks for each token in all the prompts
prompt_lens = [len(pt) for pt in prompt_tokens]
max_prompt_len = max(prompt_lens)
# 1 marks a real token, 0 marks left padding
mask = [[0] * (max_prompt_len - n) + [1] * n for n in prompt_lens]
mask = mx.array(mask, dtype=mx.int16)

# Pad the shorter prompts
prompt_tokens = [[pad_token_id] * (max_prompt_len - n) + tks for tks, n in zip(prompt_tokens, prompt_lens)]
prompt_tokens = mx.array(prompt_tokens)

# Make the cache
cache = make_prompt_cache(model)

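# Proposed call: `mask` here is the new optional argument this issue asks for.
# With left padding, the last position of each row gives the next-token logits for that prompt.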
logits = model(prompt_tokens, mask=mask, cache=cache)

I realise this might take a lot of rework, but I am just wondering if this is possible?