Thanks @KelianM, you are right, it's broken. We tested it with this option turned off, I believe, so it would be awesome to have it working!
Hi @KelianM!
Thanks for the detailed description.
I remember I tested Lag-Llama inference with and without the KV cache, and it gave the same results (and inference was much faster with the KV cache turned on), so I do not completely follow the issue. Can you please make a PR - then it'll be clearer for me.
Since you mention you get faster inference with "similar" accuracy, I wonder where the difference in accuracy comes from - do you have an idea? If possible, can you post the exact accuracy and speed of inference for a dataset (say, one of the benchmark datasets) before the fix, and after the fix?
How about without using the KV cache - do you get the said "similar" accuracy, and what is the speed of inference? I know that without the KV cache, `is_causal=True` is used, so I believe there should still be a difference in accuracy as per your description of the Causal Attention Error.
Hi @ashok-arjun,
On the slight differences in accuracy, I'm really not sure what is causing this. All I know from my own debugging, stepping through things, is that `F.scaled_dot_product_attention` seemed to be what was causing the variance even after my fixes. I think this will really require a deep dive to understand further.

I tried testing the original version using your zero-shot notebook, making sure to seed for deterministic results. Using the `aus_retail` dataset from the zero-shot notebook, 30 samples, context length 32, batch size 4, seed 1234:
- `use_kv_cache=False`: Inference duration: 13.91 seconds, CRPS: 0.07899402871357544
- `use_kv_cache=True`: Inference duration: 8.56 seconds, CRPS: 0.08842500045081104
- `use_kv_cache=True` with fixes: Inference duration: 10.92 seconds, CRPS: 0.08004362005563928
Would appreciate it if someone could try to reproduce this and test more extensively. I will make a PR for my changes but won't be able to spend much longer on this. There's still an accuracy difference, but it's much smaller; I'm convinced it's down to `F.scaled_dot_product_attention` though. The slight decrease in speed is also expected, since previously the cache wasn't resetting properly, so the first step was skipping the full initialisation of the cache.
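For reference, the comparison above can be reproduced with something along these lines. This is only a sketch following the zero-shot notebook's predictor setup; the `use_kv_cache` argument and dataset loading are assumptions from this thread, and exact names may differ:

```python
import time
from gluonts.evaluation import Evaluator, make_evaluation_predictions

def benchmark(estimator, test_dataset, num_samples=30):
    # Build a predictor from the pretrained checkpoint, as in the zero-shot notebook.
    transformation = estimator.create_transformation()
    lightning_module = estimator.create_lightning_module()
    predictor = estimator.create_predictor(transformation, lightning_module)

    # Time only the forecasting pass.
    start = time.time()
    forecast_it, ts_it = make_evaluation_predictions(
        dataset=test_dataset, predictor=predictor, num_samples=num_samples
    )
    forecasts, tss = list(forecast_it), list(ts_it)
    duration = time.time() - start

    # mean_wQuantileLoss is the sample-based CRPS approximation reported above.
    agg_metrics, _ = Evaluator()(iter(tss), iter(forecasts))
    return duration, agg_metrics["mean_wQuantileLoss"]
```

The estimator would be built once with `use_kv_cache=False`, once with `use_kv_cache=True`, and once with the fixed fork, keeping the seed and all other settings identical.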
@KelianM the metrics are also slightly noisy inherently due to the sampling, so small variations between runs are normal I believe, as they are calculated from the empirical samples.
Yes, I did seed it though, using `utils.utils.set_seed(seed)` and `lightning.seed_everything(seed)`, and that allowed me to get the same results each time.
Right... also note that GPU ops are inherently non-deterministic, so even if you seed things there is no guarantee on the GPU side, due to the way the CUDA kernels are scheduled etc. The solution is to use deterministic mode in CUDA, but that makes things super slow...
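For reference, a minimal sketch of what that looks like with stock PyTorch/Lightning switches (the seeding call is the one mentioned above; full determinism also needs `CUBLAS_WORKSPACE_CONFIG` set before CUDA is initialised, and it does slow things down):

```python
import os
import torch
import lightning

# Required by some deterministic CUDA kernels; set before any CUDA work.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

lightning.seed_everything(1234)           # seeds Python, NumPy and torch
torch.use_deterministic_algorithms(True)  # raise on ops with no deterministic kernel
torch.backends.cudnn.benchmark = False    # disable non-deterministic autotuning
```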
I see, maybe that's why the attention function gave slightly different results... Although I did test this on CPU a while back and saw similar effects
Added a PR. I think a good way to test this further might be to use a very high prediction length, since if there are errors in kv_cache they will compound after each prediction. This will also make the speedup more visible, since it scales with the number of predictions.
also @KelianM during inference we do not really need the causal mask... is there a way to remove it?
You can just use the previous method where `att_mask=None, is_causal=False`. If you look at the previous code though, it was using causal attention even during inference. It was only when kv_cache was enabled that you used `is_causal=False`.
You still need causal attention when working with the full sequence though. I think I have a more elegant solution than manually defining a mask. If you want to use the `is_causal` flag rather than a mask, you need to use causal attention for the first step with kv_cache, since you are using the full sequence, and then only afterwards, once kv_cache has been initialised and we are only working with the last token, set `is_causal` to False. Will add it to the PR.
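In rough form, the idea is something like this (a simplified sketch of one attention call, not the actual `CausalSelfAttention` code; the cache layout and tensor shapes in Lag-Llama differ):

```python
import torch
import torch.nn.functional as F

def attend(q, k, v, kv_cache, use_kv_cache: bool):
    if use_kv_cache and kv_cache is not None:
        # Cache already initialised: q holds only the newest token, which may
        # attend to every cached position, so no causal mask is needed.
        k = torch.cat([kv_cache[0], k], dim=-2)
        v = torch.cat([kv_cache[1], v], dim=-2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    else:
        # First step (or no cache): full sequence, standard causal attention.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    new_cache = (k, v) if use_kv_cache else None
    return out, new_cache
```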
@KelianM you mean when fine-tuning? or what do you mean "when working with the full sequence" ?
Yes please, let's discuss on the PR.
Thank you for posting the results! I'll investigate the issue once I get some time.
Kv_cache as currently implemented appears to completely break forecasts. Even though it does function and is way faster, it gets terrible forecast accuracy, at least on my dataset, and theoretically it should get the same results as without kv_cache. I identified and fixed 3 main issues in my fork which caused this:

1. The `self.y_cache` flag is not being reset to false in the `reset_cache` function.
2. Currently it incorrectly sets the `y_cache` flag to false for each individual transformer block instead of for the whole model.
3. `F.scaled_dot_product_attention` in the `CausalSelfAttention` module appears to give completely incorrect results even when `is_causal=True` is set when using kv_cache. I found a related issue on this function which suggested building your own causal attention mask using `torch.nn.Transformer.generate_square_subsequent_mask` instead of using the `is_causal` flag, and I found this fixed the issue if I reshaped that to a `1 x seq_len` mask when using kv_cache, while the `is_causal` flag just would not give good results with kv_cache. I think it does not build the attention mask correctly when you only have a single query vector (as is the case when using kv_cache); maybe it can only build square attention masks, but I'm still not sure why this happens. A sketch of this workaround is included after this message.

I have implemented these fixes in my fork of the repository, and it led to what was a 90 min forecast being cut down to roughly 20 mins with similar accuracy. Before proceeding with a pull request, I would appreciate it if you could just verify that kv_cache is indeed broken on your end and that you agree with the reasoning.
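For anyone reproducing point 3, a minimal sketch of the explicit-mask workaround (shapes are assumptions; one reading of the `1 x seq_len` reshape is to keep only the last row of the square mask):

```python
import torch
import torch.nn.functional as F

def attention_with_explicit_mask(q, k, v, use_kv_cache: bool):
    # k/v cover every position seen so far: (batch, n_head, seq_len, head_dim).
    seq_len = k.size(-2)
    mask = torch.nn.Transformer.generate_square_subsequent_mask(seq_len)
    mask = mask.to(device=q.device, dtype=q.dtype)
    if use_kv_cache:
        # Only the newest token is queried, so keep just the last row of the
        # square mask (shape 1 x seq_len): that query can attend to everything.
        mask = mask[-1:, :]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```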