mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

Why does `cache=None` produce different outputs? #88

Status: open. Opened by andsteing 7 months ago.


When computing the log probabilities for

```python
prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)
```

I get very different values depending on whether I use `cache=RotatingBufferCache(...)` or `cache=None`:

`use_cache=False`:

```text
[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [  5045] ▁blue   :   -0.81 44.68%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [ 12937] ▁pink   :   -2.45  8.59%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [   287] ▁b      :   -5.00  0.67%
[   287] ▁b       -> [ 10364] acon    :   -0.04 96.17%
```

`use_cache=True`:

```text
[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [  5045] ▁blue   :   -2.39  9.13%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [ 12937] ▁pink   :   -4.82  0.81%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [   287] ▁b      :   -7.59  0.05%
[   287] ▁b       -> [ 10364] acon    :   -4.41  1.21%
```

The values without the cache do not make any sense (for example, `acon` after `▁b` gets 96.17%), while the values with the cache seem reasonable.

Why is this? How can I use the model without a cache?
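A plausible explanation, offered here as an assumption rather than something confirmed in this issue: with `cache=None` the forward pass may run attention without any mask, so every token in the packed batch can attend to later tokens (including the very token it is supposed to predict) and to the other prompts. That would make the no-cache numbers look too confident, as with 96.17% for `acon` after `▁b`. The toy NumPy sketch below is not the library's code; `block_causal_mask` and `attention` are hypothetical helpers using single-head attention with identity projections. It builds the block-diagonal causal mask a packed batch described by `seqlens` needs, then checks that with the mask, replacing the last token leaves every earlier output unchanged, while without a mask every output changes.

```python
import numpy as np


def block_causal_mask(seqlens):
    """True where query i may attend to key j: same prompt and j <= i (hypothetical helper)."""
    seq_ids = np.repeat(np.arange(len(seqlens)), seqlens)  # prompt index of each packed token
    pos = np.arange(len(seq_ids))
    same_seq = seq_ids[:, None] == seq_ids[None, :]
    causal = pos[None, :] <= pos[:, None]
    return same_seq & causal


def attention(x, mask=None):
    """Toy single-head self-attention with identity Q/K/V projections."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)  # masked keys get zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


rng = np.random.default_rng(0)
seqlens = [5, 5, 6]  # token counts of the three prompts in the tables, including <s>
x = rng.normal(size=(sum(seqlens), 8))  # stand-in embeddings for the packed tokens
x_changed = x.copy()
x_changed[-1] = rng.normal(size=8)  # replace only the final token ("acon")

mask = block_causal_mask(seqlens)
masked_same = np.allclose(attention(x, mask)[:-1], attention(x_changed, mask)[:-1])
unmasked_same = np.allclose(attention(x)[:-1], attention(x_changed)[:-1])
print(masked_same, unmasked_same)  # True False
```

If the assumption holds, using the model without a cache would require giving the attention layers an equivalent block-diagonal causal mask; otherwise the logits at each position are conditioned on future tokens.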

The full code is in this Colab: https://colab.research.google.com/drive/1lNk_JgFFAakTRtEVkpxQ42jlGCygwfSb

Code from the Colab:

```python
import numpy as np
import torch
import mistral.cache  # provides RotatingBufferCache, as used below

# `model` and `tokenizer` are loaded earlier in the Colab (not shown here).


def get_logprobs(model, tokenizer, prompts, *, use_cache):
    """Returns `(encoded_prompts, logprobs)`, optionally using the cache."""
    encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts[:3]]
    seqlens = [len(x) for x in encoded_prompts]
    # All prompts are packed into one flat tensor; `seqlens` marks the boundaries.
    concatenated_prompts = torch.tensor(sum(encoded_prompts, []), device=model.device, dtype=torch.long)
    if use_cache:
        sliding_window = model.args.sliding_window
        sliding_window = min(max(seqlens), sliding_window)
        cache = mistral.cache.RotatingBufferCache(
            model.args.n_layers,
            model.args.max_batch_size,
            sliding_window,
            model.args.n_kv_heads,
            model.args.head_dim,
        )
        cache.to(device=model.device, dtype=model.dtype)
        cache.reset()
    else:
        cache = None
    prelogits = model.forward(
        concatenated_prompts,
        seqlens=seqlens,
        cache=cache,
    )
    # Log-softmax over the vocabulary: one row of log-probs per packed input token.
    logits = torch.log_softmax(prelogits, dim=-1)
    logprobs = [[] for _ in range(len(prompts))]
    offset = 0
    for i_seq, sequence in enumerate(encoded_prompts):
        # Row `offset + i` holds the prediction for token `i + 1` of this prompt.
        logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
        offset += len(sequence)
    return encoded_prompts, logprobs


def print_logprobs(id2token, encoded_prompts, logprobs):
    """Prints `(encoded_prompts, logprobs)` tokens / transition probabilities."""
    for i, t in enumerate(encoded_prompts):
        for j, (t1, t2) in enumerate(zip(t, t[1:])):
            logit = float(logprobs[i][j])
            print(
                f'[{t1:6}] {id2token(t1):8} '
                f'-> [{t2:6}] {id2token(t2):8}: '
                f'{logit:7.2f} '
                f'{np.exp(logit):6.2%}'
            )
        print()


prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

for use_cache in (False, True):
    print(f'use_cache={use_cache}:\n')
    print_logprobs(tokenizer._model.id_to_piece, *get_logprobs(model, tokenizer, prompts, use_cache=use_cache))
```
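The offset bookkeeping at the end of `get_logprobs`, which maps one packed `[total_tokens, vocab]` log-prob matrix back to per-prompt next-token scores, can be checked in isolation. The sketch below is a vectorized NumPy version under a hypothetical helper name, `gather_next_token_logprobs`; it mirrors the Colab's loop on toy data and is not part of the library.

```python
import numpy as np


def gather_next_token_logprobs(logprobs, encoded_prompts):
    """For each prompt, return log p(token[i+1] | prefix) from a packed
    [total_tokens, vocab] log-prob matrix whose rows follow the concatenation order."""
    out = []
    offset = 0
    for seq in encoded_prompts:
        rows = offset + np.arange(len(seq) - 1)          # rows that predict a next token
        targets = np.asarray(seq[1:], dtype=np.intp)     # the token actually observed next
        out.append(logprobs[rows, targets].tolist())
        offset += len(seq)                               # jump to the first row of the next prompt
    return out


# Toy data: vocabulary of 4, two packed prompts of lengths 3 and 2 (5 rows total).
encoded = [[0, 2, 3], [0, 1]]
logits = np.arange(5 * 4, dtype=float).reshape(5, 4)
logprobs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))  # log-softmax per row
print(gather_next_token_logprobs(logprobs, encoded))  # ≈ [[-1.44, -0.44], [-2.44]]
```

Because each prompt's last row predicts a token that is not in the input, it is skipped, which is why a prompt of length n yields n - 1 scores.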