**Open** · andsteing opened this issue 7 months ago
When computing the log probabilities for

```python
prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)
```

I get very different values, depending on whether I use `cache=RotatingBufferCache(...)` or `cache=None`:
With `cache=None` (`use_cache=False`):

```
[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [  5045] ▁blue   :   -0.81 44.68%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [ 12937] ▁pink   :   -2.45  8.59%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [   287] ▁b      :   -5.00  0.67%
[   287] ▁b       -> [ 10364] acon    :   -0.04 96.17%
```

With `cache=RotatingBufferCache(...)` (`use_cache=True`):

```
[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [  5045] ▁blue   :   -2.39  9.13%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [ 12937] ▁pink   :   -4.82  0.81%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [   287] ▁b      :   -7.59  0.05%
[   287] ▁b       -> [ 10364] acon    :   -4.41  1.21%
```
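(To read the dump: the two numeric columns are `log p(next token | previous tokens)` and the same value exponentiated into a probability, exactly as computed by the print code below. A quick check of the first `▁the -> ▁sky` transition:)

```python
import math

# Last two columns of each line: log-probability, then exp() of it.
# E.g. `▁the -> ▁sky` in the no-cache output:
print(f'{math.exp(-0.98):.2%}')  # ~37.5% (the 37.41% above uses the unrounded log value)
```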
The values without cache do not make sense (e.g. `▁b -> acon` at 96.17% looks implausibly confident), whereas the values with cache seem reasonable.
Why is this? How can I use the model without cache?
Full code is in this Colab: https://colab.research.google.com/drive/1lNk_JgFFAakTRtEVkpxQ42jlGCygwfSb
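For reference, this is the kind of per-prompt, no-cache scoring I would expect to work (a sketch with a hypothetical helper name, untested; it only uses the same `model.forward(..., seqlens=..., cache=None)` call as the full code below):

```python
import torch

# Hypothetical sketch: score one prompt per forward call, so the no-cache
# path never has to separate several sequences packed into a single tensor.
# If even a single prompt per call gives implausible values, the problem is
# not in how multiple concatenated sequences are separated.
def get_logprobs_one_prompt(model, tokenizer, prompt):
    tokens = tokenizer.encode(prompt, bos=True)
    x = torch.tensor(tokens, device=model.device, dtype=torch.long)
    prelogits = model.forward(x, seqlens=[len(tokens)], cache=None)
    logits = torch.log_softmax(prelogits, dim=-1)
    # log p(tokens[i + 1] | tokens[: i + 1]) for each transition.
    return [logits[i, tokens[i + 1]].item() for i in range(len(tokens) - 1)]
```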
Code from Colab:
```python
import numpy as np
import torch

import mistral.cache

# `model` and `tokenizer` are assumed to be loaded in an earlier Colab cell.


def get_logprobs(model, tokenizer, prompts, *, use_cache):
    """Returns `(encoded_prompts, logprobs)`, optionally using the cache."""
    encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts[:3]]
    seqlens = [len(x) for x in encoded_prompts]
    # All prompts are packed into a single flat tensor; `seqlens` tells the
    # model where one sequence ends and the next begins.
    concatenated_prompts = torch.tensor(
        sum(encoded_prompts, []), device=model.device, dtype=torch.long)
    if use_cache:
        sliding_window = model.args.sliding_window
        sliding_window = min(max(seqlens), sliding_window)
        cache = mistral.cache.RotatingBufferCache(
            model.args.n_layers,
            model.args.max_batch_size,
            sliding_window,
            model.args.n_kv_heads,
            model.args.head_dim,
        )
        cache.to(device=model.device, dtype=model.dtype)
        cache.reset()
    else:
        cache = None
    prelogits = model.forward(
        concatenated_prompts,
        seqlens=seqlens,
        cache=cache,
    )
    logits = torch.log_softmax(prelogits, dim=-1)
    # For every prompt, collect log p(token[i + 1] | token[: i + 1]).
    logprobs = [[] for _ in range(len(prompts))]
    offset = 0
    for i_seq, sequence in enumerate(encoded_prompts):
        logprobs[i_seq].extend([
            logits[offset + i, sequence[i + 1]].item()
            for i in range(len(sequence) - 1)
        ])
        offset += len(sequence)
    return encoded_prompts, logprobs


def print_logprobs(id2token, encoded_prompts, logprobs):
    """Prints `(encoded_prompts, logprobs)` tokens / transition probabilities."""
    for i, t in enumerate(encoded_prompts):
        for j, (t1, t2) in enumerate(zip(t, t[1:])):
            logprob = float(logprobs[i][j])
            print(
                f'[{t1:6}] {id2token(t1):8} '
                f'-> [{t2:6}] {id2token(t2):8}: '
                f'{logprob:7.2f} '
                f'{np.exp(logprob):6.2%}'
            )
        print()


prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

for use_cache in (False, True):
    print(f'use_cache={use_cache}:\n')
    print_logprobs(
        tokenizer._model.id_to_piece,
        *get_logprobs(model, tokenizer, prompts, use_cache=use_cache),
    )
```
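A diagnostic that reuses `get_logprobs` from above (a sketch; the `np.allclose` tolerance is an arbitrary choice): with correct sequence masking, a prompt should get identical logprobs whether it is scored alone or packed together with the other prompts.

```python
# Compare each prompt scored jointly (all three prompts in one forward
# call) vs. alone. A mismatch points at how the concatenated sequences
# are separated in the respective code path.
for use_cache in (False, True):
    _, joint = get_logprobs(model, tokenizer, prompts, use_cache=use_cache)
    for k, prompt in enumerate(prompts):
        _, alone = get_logprobs(model, tokenizer, (prompt,), use_cache=use_cache)
        same = np.allclose(joint[k], alone[0], atol=1e-3)
        print(f'use_cache={use_cache}, {prompt!r}: joint == alone? {same}')
```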