
Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting
Apache License 2.0

Broken kv_cache and fixes. #81

Closed · KelianM closed this 4 days ago

KelianM commented 1 week ago

The kv_cache path as currently implemented appears to completely break forecasts. Even though it does run, and is much faster, it gets terrible forecast accuracy, at least on my dataset, when theoretically it should give the same results as running without kv_cache. I identified three main issues that caused this.

I have implemented fixes for these issues in my fork of the repository, and they cut what was a 90-minute forecast down to roughly 20 minutes with similar accuracy. Before proceeding with a pull request, I would appreciate it if you could verify that kv_cache is indeed broken on your end and that you agree with the reasoning.

kashif commented 1 week ago

Thanks @KelianM, you are right, it's broken. We tested it with this option turned off, I believe, so it would be awesome to have it working!

ashok-arjun commented 1 week ago

Hi @KelianM!

Thanks for the detailed description.

KelianM commented 1 week ago

Hi @ashok-arjun,

On the slight differences in accuracy, I'm really not sure what is causing this. All I know from my own debugging, stepping through the code, is that F.scaled_dot_product_attention seemed to be the source of the variance even after my fixes. I think this will require a deeper dive to understand fully.

I tried testing the original version using your zero-shot notebook, making sure to seed for deterministic results. Using the aus_retail dataset from that notebook, with 30 samples, context length 32, batch size 4, and seed 1234:

| Configuration | Inference duration (s) | CRPS |
| --- | --- | --- |
| `use_kv_cache=False` | 13.91 | 0.07899402871357544 |
| `use_kv_cache=True` | 8.56 | 0.08842500045081104 |
| `use_kv_cache=True` with fixes | 10.92 | 0.08004362005563928 |
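
For anyone reproducing the comparison, here is a minimal sketch of the timing and CRPS measurement, assuming a GluonTS predictor built from the zero-shot notebook's `LagLlamaEstimator` with `use_kv_cache` set to the variant under test (GluonTS's `mean_wQuantileLoss` is its sample-based CRPS approximation):

```python
import time

from gluonts.evaluation import Evaluator, make_evaluation_predictions


def time_and_score(predictor, test_dataset, num_samples: int = 30):
    """Time inference and compute the CRPS approximation for one predictor.

    `predictor` is assumed to come from the zero-shot notebook's
    LagLlamaEstimator, built with `use_kv_cache` set to the variant under test.
    """
    start = time.perf_counter()
    forecast_it, ts_it = make_evaluation_predictions(
        dataset=test_dataset, predictor=predictor, num_samples=num_samples
    )
    forecasts, tss = list(forecast_it), list(ts_it)  # sampling happens lazily, here
    duration = time.perf_counter() - start

    # mean_wQuantileLoss is GluonTS's sample-based CRPS approximation
    agg_metrics, _ = Evaluator()(iter(tss), iter(forecasts))
    return duration, agg_metrics["mean_wQuantileLoss"]
```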

Would appreciate it if someone could try to reproduce this and test more extensively. I will make a PR for my changes but won't be able to spend much longer on this. There's still an accuracy difference, but it's much smaller; I'm still convinced it comes down to F.scaled_dot_product_attention. The slight decrease in speed is also expected: previously the cache wasn't being reset properly, so the first step was skipping the full initialisation of the cache.

kashif commented 1 week ago

@KelianM the metrics are also inherently slightly noisy due to the sampling, so small variations between runs are normal, I believe, since they are calculated from the empirical samples.
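
For reference, a minimal sketch of the usual sample-based CRPS estimator, which makes the sampling noise explicit (GluonTS itself reports a quantile-loss approximation, but the idea is the same):

```python
import numpy as np


def sample_crps(samples: np.ndarray, y: float) -> float:
    """Empirical CRPS of forecast samples against one observation.

    CRPS(F, y) ~= E|X - y| - 0.5 * E|X - X'|, estimated from the drawn samples,
    so a small number of samples gives a noisy estimate.
    """
    samples = np.asarray(samples, dtype=float)
    term1 = np.mean(np.abs(samples - y))
    term2 = 0.5 * np.mean(np.abs(samples[:, None] - samples[None, :]))
    return term1 - term2
```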

KelianM commented 1 week ago

> @KelianM the metrics are also inherently slightly noisy due to the sampling, so small variations between runs are normal, I believe, since they are calculated from the empirical samples.

Yes, I did seed it though, using `utils.utils.set_seed(seed)` and `lightning.seed_everything(seed)`, and that allowed me to get the same results each time.

kashif commented 1 week ago

Right... also note that GPU ops are inherently non-deterministic, so even if you seed things there is no guarantee on the GPU side, due to the way the CUDA kernels are scheduled etc. The solution is to use deterministic mode in CUDA, but that makes things super slow...
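
For completeness, a minimal sketch of forcing deterministic behaviour in PyTorch (seeding plus CUDA determinism; as noted above, the deterministic kernels can be much slower):

```python
import os

import torch
from lightning import seed_everything  # also available as lightning.pytorch.seed_everything

# Seed Python, NumPy, and torch RNGs (CPU and all CUDA devices), including dataloader workers
seed_everything(1234, workers=True)

# Force deterministic CUDA kernels; cuBLAS needs this workspace setting on CUDA >= 10.2
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```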

KelianM commented 1 week ago

I see, maybe that's why the attention function gave slightly different results... although I did test this on CPU a while back and saw similar effects.

KelianM commented 1 week ago

Added a PR. I think a good way to test this further might be to use a very high prediction length, since any errors in kv_cache will compound after each prediction. This will also make the speedup more visible, since it scales with the number of predictions.
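
To see why a long prediction length is a good stress test: with a kv_cache the model only processes the newest token at each step, so a stale or mis-indexed cache produces a wrong prediction that is fed back in as the next input, and the error compounds over every subsequent step. A minimal sketch of that autoregressive loop (`model` and `reset_cache` are hypothetical names, not the repository's API; tensors are assumed to be shaped `[batch, time]`):

```python
import torch


def autoregressive_forecast(model, context: torch.Tensor, prediction_length: int,
                            use_kv_cache: bool = True) -> torch.Tensor:
    """Hypothetical decoding loop illustrating how kv_cache errors compound."""
    if use_kv_cache:
        model.reset_cache()  # a cache left over from a previous batch corrupts every step
    sequence = context
    predictions = []
    for _ in range(prediction_length):
        if use_kv_cache and predictions:
            step_input = sequence[:, -1:]  # only the newest token; past K/V come from the cache
        else:
            step_input = sequence          # full sequence: first step, or no cache at all
        next_value = model(step_input, use_kv_cache=use_kv_cache)  # shape [batch, 1]
        predictions.append(next_value)
        # the prediction is fed back in, so any cache error propagates to all later steps
        sequence = torch.cat([sequence, next_value], dim=1)
    return torch.cat(predictions, dim=1)
```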

kashif commented 1 week ago

also @KelianM during inference we do not really need the causal mask... is there a way to remove it?

KelianM commented 1 week ago

You can just use the previous method where `attn_mask=None, is_causal=False`. If you look at the previous code, though, it was using causal attention even during inference; it was only when kv_cache was enabled that `is_causal=False` was used.

KelianM commented 1 week ago

You still need causal attention when working with the full sequence, though. I think I have a more elegant solution than manually defining a mask: if you want to use the `is_causal` flag rather than a mask, you need causal attention for the first step with kv_cache, since you are processing the full sequence; then, once the kv_cache has been initialised and you are only working with the last token, you can set `is_causal` to False. Will add it to the PR.
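
A minimal sketch of that switch (attribute names here are hypothetical, not the exact code in the PR): the first kv_cache step still processes the full sequence and therefore needs causal masking; only the later single-token steps can drop it.

```python
import torch.nn.functional as F


def attend(q, k, v, use_kv_cache: bool, cache_initialised: bool):
    """Sketch: `cache_initialised` is a hypothetical flag meaning the kv_cache already
    holds the past keys/values, i.e. we are past the first full-sequence step."""
    if use_kv_cache and cache_initialised:
        # q covers only the newest token and may attend to the entire cached past,
        # so no causal mask is needed.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
    # First (full-sequence) step, or kv_cache disabled: keep causal masking.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```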

kashif commented 1 week ago

@KelianM do you mean when fine-tuning? Or what do you mean by "when working with the full sequence"?

Yes, please, let's discuss on the PR.

ashok-arjun commented 4 days ago

Thank you for posting the results! I'll investigate the issue once I get some time.