time-series-foundation-models / lag-llama

Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting
Apache License 2.0

Faster inference and more accurate point estimates with Lag-Llama #68

Closed KelianM closed 1 week ago

KelianM commented 4 weeks ago

Inference is really slow if you want multiple samples. After a little digging I found that the model is deterministic during inference (at least it was in my case; I'm not sure whether there are hyperparameters that could change this), meaning that if you specify 20 samples it generates the same distribution parameters 20 times and randomly samples each one once, which is very inefficient. I propose two fixes that let you use fewer samples for better accuracy and correspondingly much faster inference:

  1. A point-estimate mode: instead of taking a random sample from each distribution, take the mean. This is what I use in my fork with just one sample, and I got much higher forecast accuracy than with 30 samples under the previous approach.
  2. Multiple samples from each distribution: instead of taking a single sample from each set of distribution parameters, take several. If the model is deterministic, you should really just take all your random samples from one distribution.

Both of these have merit even if the model is not deterministic: for point estimates you should always use the most likely value (the mean) rather than a random sample. Point 2 is a little different if the model is non-deterministic, but it could still be applied by averaging the distribution parameters across all parallel runs and then sampling, or at least by taking more than one random sample per distribution.
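
To sketch what I mean (made-up Student's t parameters and plain torch.distributions here, not the actual Lag-Llama internals):

```python
import torch
from torch.distributions import StudentT

# one forward pass yields one deterministic set of parameters (made-up values here)
distr = StudentT(df=torch.tensor(5.0), loc=torch.tensor(2.0), scale=torch.tensor(0.5))

point_estimate = distr.mean      # fix 1: use the distribution mean as the point forecast
samples = distr.sample((20,))    # fix 2: draw all 20 samples from this single distribution
                                 #        instead of re-running the model 20 times
```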

ashok-arjun commented 3 weeks ago

Hi!

Thanks for the detailed description.

In our model, the uncertainty itself comes from sampling the Student's t distribution head at the end - you are right that the parameters of the distribution are deterministic, but to cover the "range" of the distribution for probabilistic forecasting, our model relies on multiple samples from that distribution. There are definitely better ways to construct probabilistic estimates, which can be found in many other models.

As for the fixes you propose:

  1. You could take just 1 sample from the distribution - but do keep in mind that you do not get any uncertainty estimate from the model in that case. The same goes for taking the "mean" of 30 samples, except that the "mean" would be slightly better. Our model was trained for probabilistic forecasting and we never tested the accuracy of point forecasting using 1 sample or the mean; I would say it probably doesn't do well in that setup, but if it works for you, that's great.
  2. I do not completely understand this point. I would say this is what we do currently: we take K parallel samples from the distribution.

Would be happy to discuss/clarify more! :)
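
To illustrate what I mean by covering the range of the distribution, here is a rough sketch with made-up parameters (plain torch.distributions, not our actual implementation):

```python
import torch
from torch.distributions import StudentT

# deterministic parameters from a single forward pass (made-up values)
distr = StudentT(df=torch.tensor(4.0), loc=torch.tensor(10.0), scale=torch.tensor(1.5))

# the probabilistic forecast comes from sampling this head repeatedly
samples = distr.sample((100,))
lo, median, hi = torch.quantile(samples, torch.tensor([0.1, 0.5, 0.9]))
print(f"80% interval: [{lo.item():.2f}, {hi.item():.2f}], median: {median.item():.2f}")
```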

KelianM commented 3 weeks ago

Hi,

Thanks for responding. On the two points:

  1. Just to clarify: for my "point estimate" mode I am not taking the mean of 1 random sample; I am fetching the mean of the underlying distribution produced by the distribution head with distr.mean. I manage to get great results on my dataset by doing this with only 1 parallel sample. Since the distribution parameters are deterministic (e.g. distr.mean returns 30 equal means for 30 parallel samples), increasing the number of samples does not improve the accuracy of the point estimate. To get these results I pretrained directly on point estimates, using an absolute-error loss function with distr.mean as my point estimate.

  2. I agree that you do currently take K samples from the distribution. Where I thought it could improve is in the way this is done (by duplicating the input K times), which as I understand it leads to K parallel forward passes of the network, all producing the same set of distribution parameters. I think you could do this far more efficiently with a single forward pass where you save the distribution parameters (degrees of freedom, mean and scale) and then sample the saved distribution K times (i.e. don't duplicate the input, and instead use distr.sample(K)); see the sketch below.
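
A rough sketch of the difference (model here is just a stand-in that returns the distribution head's output; the names are illustrative, not the actual Lag-Llama API):

```python
import torch

def sample_by_duplication(model, context, k):
    # current behaviour: duplicate the context k times, run k forward passes,
    # and take one sample from each of the k identical distributions
    distr = model(context.repeat(k, 1))
    return distr.sample()                     # shape (k,)

def sample_single_pass(model, context, k):
    # proposed: one forward pass, keep the parameters (df, loc, scale),
    # then draw all k samples from that single saved distribution
    distr = model(context.unsqueeze(0))
    return distr.sample((k,)).squeeze(-1)     # shape (k,)
```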

Let me know if I've misunderstood. :)

ashok-arjun commented 2 weeks ago
  1. I see, thanks for clarifying. Yes, if you've trained directly on point estimates, that would be the case. I'm glad it worked out!
  2. Yes, you are correct. The implementation can be optimized a lot based on this. Thanks a ton for pointing this out! If you have this implementation, please feel free to make a PR. If not, I'll get this implemented soon-ish :)
KelianM commented 2 weeks ago

Hi @ashok-arjun,

PR submitted. As mentioned in the PR description, I had to make the greedy assumption that the previous sample is the mean in order to keep a single forward pass. I think this is why the previous approach was used (to build a history of possible samples).

I think the mean is a fair assumption and forecast accuracy will actually likely improve (as it did in my case), but not maintaining parallel paths of predictions will have some impact on the uncertainty estimates. Even then, I think the previous approach would need a very large number of samples to truly benefit from the parallel paths - and for most people it is simply not computationally feasible to make that many forward passes.
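
Roughly, the approach in the PR looks like the sketch below (the model and helper names are placeholders, not the actual implementation):

```python
import torch

@torch.no_grad()
def greedy_forecast(model, context, horizon, num_samples):
    # single-path decoding: at each step, sample the predicted distribution
    # num_samples times, but continue the history with its mean (the greedy
    # assumption), so only one forward pass per step is needed
    history = context.clone()
    step_samples = []
    for _ in range(horizon):
        distr = model(history.unsqueeze(0))                 # one forward pass
        step_samples.append(distr.sample((num_samples,)).squeeze(-1))
        history = torch.cat([history, distr.mean.reshape(1)])
    return torch.stack(step_samples, dim=-1)                # (num_samples, horizon)
```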

Let me know what you think :)

KelianM commented 2 weeks ago

Actually, considering this, I should probably make it a mode you can enable, or something of the sort.

ashok-arjun commented 2 weeks ago

Thank you! I'll check out the PR.

Indeed, assuming the previous forecast is the mean is not true "probabilistic forecasting". But I think it is useful anyway, as you said, and the user can decide whether they want to use this mode. Can you please make that a mode?

KelianM commented 2 weeks ago

Sure, you can now enable it with use_single_pass_sampling and it's turned off by default.
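
Usage looks roughly like this (the surrounding arguments are just illustrative, and exactly where the flag is passed may differ slightly; only the flag name itself is the new addition):

```python
from lag_llama.gluon.estimator import LagLlamaEstimator

# illustrative arguments; use_single_pass_sampling is the new flag
estimator = LagLlamaEstimator(
    ckpt_path="lag-llama.ckpt",
    prediction_length=24,
    context_length=32,
    use_single_pass_sampling=True,  # off by default
)
```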