@ad8e hey Kevin
thanks for reporting
i quickly checked with a test script and it seems to be fine:
```python
import torch
from x_transformers import (
    TransformerWrapper,
    Decoder,
    AutoregressiveWrapper
)

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 8,
        depth = 1,
        heads = 4
    )
)

model = AutoregressiveWrapper(model)

# token ids must be a long tensor, not float
prompts = torch.zeros((1, 1), dtype = torch.long)

# greedy decoding without the kv cache
generated = model.generate(
    prompts,
    seq_len = 100,
    temperature = 0.,
    cache_kv = False
)

# greedy decoding with the kv cache
kv_cache_generated = model.generate(
    prompts,
    seq_len = 100,
    temperature = 0.,
    cache_kv = True
)

assert torch.allclose(generated, kv_cache_generated)
```
could you modify the script so that it breaks? perhaps you are using some hyperparameter that is incompatible with the kv cache (if so, it would be good to put in a patch to the can_cache_kv logic)
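for reference, a minimal sketch of the kind of guard meant here (the helper name and fallback behaviour are illustrative, not the library's actual code):

```python
# illustrative only: fall back to uncached decoding when the network
# reports that its configuration is incompatible with kv caching
def resolve_cache_kv(requested: bool, net) -> bool:
    can_cache = getattr(net, 'can_cache_kv', False)
    return requested and can_cache
```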
`logits` and `logits2` in your code have different shapes. you need to compare `logits` and `logits2[:, -1:]`
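for illustration, a small self-contained sketch of the shape mismatch (the tensors here are stand-ins, not the thread's attached script): a full-sequence forward returns logits for every position, while a kv-cached step returns logits for only the newest token.

```python
import torch

batch, seq, vocab = 1, 10, 20000

logits = torch.randn(batch, seq, vocab)   # full forward: one row per position
logits2 = logits[:, -1:].clone()          # a cached step returns only the last position

# comparing the full tensor against the single step fails on shape alone;
# slice the last position before comparing
assert logits2.shape == (batch, 1, vocab)
assert torch.allclose(logits[:, -1:], logits2)
```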
@ad8e as for your offer, it is ok, as the library is model-architecture specific. what you mention is all training related
@ad8e ah, got to the bottom of it Kevin
so it turns out the default (absolute positional embedding) is not kv cache friendly once you exceed the maximum sequence length (context window). however, it should still work when decoding from the 1st token up to the max context window size
i added an assert to prevent this, but also defaulted the enwik8 training script to rotary positions, which are the preferred positional embeddings these days (Llama) and are kv cache friendly when exceeding the context length.
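for anyone landing here, a minimal sketch of the rotary configuration (assuming the `rotary_pos_emb` and `use_abs_pos_emb` flags; check the README for current names):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_abs_pos_emb = False,       # drop the absolute positional embedding
    attn_layers = Decoder(
        dim = 8,
        depth = 1,
        heads = 4,
        rotary_pos_emb = True      # rotary positions stay consistent under kv caching
    )
)
model = AutoregressiveWrapper(model)

# with rotary positions, cached decoding can exceed max_seq_len safely
prompts = torch.zeros((1, 1), dtype = torch.long)
generated = model.generate(prompts, seq_len = 2048, temperature = 0., cache_kv = True)
```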
ok, back to the holidays; have a great New Year's, Kevin
Thanks, this fixes the issue and is better than what I would have PR'd. Sorry that I dumped two bad testcases and then went to sleep.

> @ad8e as for your offer, it is ok

To clarify, do you mean it is ok to do it, or "it is ok" as in it is not necessary?

> have a great New Year's, Kevin

You too!
it isn't necessary, not for this lib
If I turn the kv cache on, autoregressive generation produces junk.
Here's a diagnostic: replace autoregressive_wrapper.py with this one: autoregressive_wrapper.txt (rename to .py)
This is the diff:
Running this on a testcase, such as your enwik8 example, shows that the logits are indeed scrambled between kv cache on vs off.
When testing generation on a small dataset, like Karpathy's Tiny Shakespeare, we see that the KV cache breaks generation:
KV cache on:

KV cache off (by forcing `self.can_cache_kv = 0`):

If you want, you can fix it, or you can also just wait and I'll figure out the cause and submit a PR in a week.
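A minimal sketch of reproducing the on/off comparison (hyperparameters are stand-ins, and this assumes the pre-fix library where the default absolute positional embedding is still in use):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,    # small context window to force the overflow
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
))

prompt = torch.zeros((1, 1), dtype = torch.long)

# decode past max_seq_len twice, once with and once without the kv cache
cached   = model.generate(prompt, seq_len = 200, temperature = 0., cache_kv = True)
uncached = model.generate(prompt, seq_len = 200, temperature = 0., cache_kv = False)

# with the default absolute positional embedding these diverge once
# generation passes the context window
print((cached == uncached).all())
```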
As an aside, I'm running some changes to the enwik8 example: LR schedule, removing gradient accumulation, increasing batch size, and turning on flash attention. This significantly improves the performance and is codesize-neutral. Would you welcome these changes?
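For context, a hedged sketch of the kinds of changes meant (the schedule, batch size, and `attn_flash` flag here are illustrative, not the eventual PR):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True    # turn on flash attention
    )
))

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

# a decaying LR schedule in place of a fixed learning rate
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max = 100_000)

# larger batches, stepping once per batch instead of accumulating gradients:
#   loss = model(batch)      # batch: (batch_size, seq_len + 1) long tensor
#   loss.backward()
#   optim.step()
#   scheduler.step()
#   optim.zero_grad()
```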