mitchelloharawild / distributional

Vectorised distributions for R
https://pkg.mitchelloharawild.com/distributional
GNU General Public License v3.0

Truncated distribution performance is very bad #49

Closed: Fuco1 closed this issue 3 years ago

Fuco1 commented 3 years ago

mean(dist_truncated(dist_normal(2, 5), 1, 2)) takes a noticeable time to compute. I don't have a precise number, but since the lag is noticeable it must be somewhere on the order of 100 ms. This makes truncated distributions practically unusable with fable, where forecast() produces a .mean column: it takes the calculation from practically 0 to about 1 second per SKU, so over 200k SKUs my forecasts would run for about 55 hours.

I understand from the code it is now written in a generic way. Could there be some more optimized variants for specific distributions, like the normal distribution in particular?

Furthermore, it seems the mean value returned varies between calls. Is this done using a simulation? I suppose closed-form formulas are really difficult to come by.

Suppose I don't care about any "nice" statistical properties and I want to simply clamp any output (mean, quantile, median...) between lower and upper (for example if the wrapped distribution returns mean = 3 and my lower is 4, I simply return 4). I can do this easily on the resulting tibble, but I'd like to bake this into a distribution so that I can use it seamlessly with the fable framework. How would I go about implementing something like this as a distribution?

mitchelloharawild commented 3 years ago

> I understand from the code it is now written in a generic way. Could there be some more optimized variants for specific distributions, like the normal distribution in particular?

Yes, this is possible and should be added. I know the mean of a truncated Normal distribution can be computed quickly, which would speed up your example (and most applications of truncated distributions in forecasting). A faster method for the mean of any truncated distribution would be much nicer, though. It is also likely possible to implement a better and faster method for computing a distribution's mean in general (by querying the density instead of simulating values).

> Furthermore, it seems the mean value returned varies between calls. Is this done using a simulation? I suppose closed-form formulas are really difficult to come by.

Correct, I haven't added a method for the truncated distribution's mean yet and so it is defaulting to a simulation and computing the sample mean.
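For illustration, a simulation fallback of this kind can be sketched in base R. This is an assumption about the general approach, not distributional's code: it draws from the truncated Normal by inverse-CDF sampling (mapping uniform draws on [F(a), F(b)] back through qnorm()) and returns the sample mean, which is why repeated calls give slightly different values. The helper name simulate_truncated_mean() is hypothetical.

```r
# Hypothetical helper (not part of distributional): Monte Carlo estimate of the
# mean of Normal(mu, sigma) truncated to [lower, upper]
simulate_truncated_mean <- function(mu, sigma, lower, upper, n = 10000) {
  # Inverse-CDF sampling: uniform draws on [F(lower), F(upper)] ...
  u <- runif(n, min = pnorm(lower, mu, sigma), max = pnorm(upper, mu, sigma))
  # ... mapped back through the Normal quantile function
  x <- qnorm(u, mu, sigma)
  # The sample mean only approximates the truncated mean
  mean(x)
}

# Two calls return slightly different estimates of the same quantity
est1 <- simulate_truncated_mean(2, 5, lower = 1, upper = 2)
est2 <- simulate_truncated_mean(2, 5, lower = 1, upper = 2)
```

Each call costs n quantile evaluations, and the result still carries Monte Carlo error of roughly sd / sqrt(n).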

> Suppose I don't care about any "nice" statistical properties and I want to simply clamp any output (mean, quantile, median...) between lower and upper (for example if the wrapped distribution returns mean = 3 and my lower is 4, I simply return 4). I can do this easily on the resulting tibble, but I'd like to bake this into a distribution so that I can use it seamlessly with the fable framework. How would I go about implementing something like this as a distribution?

To satisfy your description, "clamping" a distribution would have to produce a degenerate distribution (dist_degenerate()) at the bound, which is quite different from truncating it:

library(distributional)
mean(dist_normal(3))
#> [1] 3
mean(dist_truncated(dist_normal(3), lower = 4))
#> [1] 4.508183
mean(dist_degenerate(4))
#> [1] 4

Created on 2020-09-11 by the reprex package (v0.3.0)
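As a cross-check, the mean of a Normal distribution truncated only from below has a well-known closed form, mu + sigma * phi(alpha) / (1 - Phi(alpha)) with alpha = (lower - mu) / sigma, where phi and Phi are the standard Normal density and CDF. A base-R sketch (the function name is hypothetical) gives about 4.525 for dist_truncated(dist_normal(3), lower = 4), so the 4.508183 above reflects the simulation-based sample mean that was in use at the time rather than the exact value.

```r
# Hypothetical helper: mean of Normal(mu, sigma) truncated below at `lower`
# (no upper bound), using the inverse Mills ratio
lower_truncated_normal_mean <- function(mu, sigma, lower) {
  alpha <- (lower - mu) / sigma
  # dnorm(alpha) / (1 - pnorm(alpha)); lower.tail = FALSE is more accurate
  # than 1 - pnorm(alpha) far in the upper tail
  mu + sigma * dnorm(alpha) / pnorm(alpha, lower.tail = FALSE)
}

m_exact <- lower_truncated_normal_mean(mu = 3, sigma = 1, lower = 4)
round(m_exact, 6)
#> [1] 4.525135
```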

mitchelloharawild commented 3 years ago

I've added a shortcut for truncated distribution means when the underlying distribution is Normal or sample-based.

# Before
library(distributional)
bench::mark(mean(dist_truncated(dist_normal(2, 5), 1, 2)))
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 1 x 6
#>   expression                                       min median `itr/sec`
#>   <bch:expr>                                    <bch:> <bch:>     <dbl>
#> 1 mean(dist_truncated(dist_normal(2, 5), 1, 2)) 90.3ms 95.5ms      10.4
#> # … with 2 more variables: mem_alloc <bch:byt>, `gc/sec` <dbl>

# After
devtools::load_all("~/github/distributional/")
#> Loading distributional
bench::mark(mean(dist_truncated(dist_normal(2, 5), 1, 2)))
#> # A tibble: 1 x 6
#>   expression                                      min median `itr/sec` mem_alloc
#>   <bch:expr>                                    <bch> <bch:>     <dbl> <bch:byt>
#> 1 mean(dist_truncated(dist_normal(2, 5), 1, 2)) 336µs  362µs     2582.        NA
#> # … with 1 more variable: `gc/sec` <dbl>

Created on 2020-10-06 by the reprex package (v0.3.0)
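The shortcut itself isn't shown here, but the standard closed form for a Normal(mu, sigma) truncated to [a, b] is mu + sigma * (phi(alpha) - phi(beta)) / (Phi(beta) - Phi(alpha)), with alpha = (a - mu) / sigma and beta = (b - mu) / sigma. The base-R sketch below (hypothetical function names, not distributional's internals) evaluates it for the benchmarked example, and also shows one plausible way to handle a sample-based distribution: average only the draws that fall inside the bounds.

```r
# Hypothetical helper: closed-form mean of Normal(mu, sigma) truncated to
# [lower, upper]; vectorised, and works with infinite bounds
truncated_normal_mean <- function(mu, sigma, lower = -Inf, upper = Inf) {
  alpha <- (lower - mu) / sigma
  beta <- (upper - mu) / sigma
  # Probability mass kept by the truncation: Phi(beta) - Phi(alpha)
  mass <- pnorm(beta) - pnorm(alpha)
  mu + sigma * (dnorm(alpha) - dnorm(beta)) / mass
}

# Hypothetical helper: for a sample-based distribution, keep only the draws
# inside the bounds and take their mean
truncated_sample_mean <- function(x, lower = -Inf, upper = Inf) {
  mean(x[x >= lower & x <= upper])
}

# The benchmarked example: Normal(2, 5) truncated to [1, 2]
m_normal <- truncated_normal_mean(2, 5, lower = 1, upper = 2)

# The sample-based version approaches the same value for a large sample
set.seed(1)
m_sample <- truncated_sample_mean(rnorm(1e6, 2, 5), lower = 1, upper = 2)
```

The closed form needs only a few dnorm() and pnorm() calls, which is why it is so much cheaper than simulating. When both bounds sit far in the same tail, Phi(beta) - Phi(alpha) can cancel to zero and the formula returns NaN, so a robust version would need more careful evaluation there.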

Fuco1 commented 3 years ago

How do you do this? Is there a formula you solved for and now can provide a fast calculation? What would it take to make this work for negative-binomial or poisson distributions?

mitchelloharawild commented 3 years ago

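The mean of a truncated distribution is computed much like the mean of any distribution. A continuous distribution on the domain (-Inf, Inf) has a mean of $\int_{-\infty}^{\infty} x f(x)\,dx$, where f(x) is the density. For a truncated distribution, this integral is instead taken over the domain (a, b) and divided by F(b) - F(a), the probability retained by the truncation; for some distributions, such as the Normal, it can be solved analytically with some math.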
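For the Normal case the integral can be worked by hand. A short derivation of the standard result (not taken from the package source), substituting z = (x - mu) / sigma so that f(x) dx = phi(z) dz, where phi and Phi are the standard Normal density and CDF:

```latex
\begin{aligned}
\int_a^b x\, f(x)\,dx
  &= \int_\alpha^\beta (\mu + \sigma z)\,\phi(z)\,dz,
  \qquad \alpha = \tfrac{a-\mu}{\sigma},\ \beta = \tfrac{b-\mu}{\sigma} \\
  &= \mu\,[\Phi(\beta) - \Phi(\alpha)] + \sigma\,[\phi(\alpha) - \phi(\beta)]
  \qquad \text{since } \tfrac{d}{dz}\phi(z) = -z\,\phi(z) \\
\mathbb{E}[X \mid a < X < b]
  &= \frac{\int_a^b x\, f(x)\,dx}{F(b) - F(a)}
   = \mu + \sigma\,\frac{\phi(\alpha) - \phi(\beta)}{\Phi(\beta) - \Phi(\alpha)}
\end{aligned}
```

The final expression needs only density and CDF evaluations at the two standardised bounds, so no simulation or numerical integration is required.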

A faster fallback than sampling for the mean (instead of NextMethod() here) would be to numerically approximate the integral.

https://github.com/mitchelloharawild/distributional/blob/e996ba403730d27bf2a7c25274ba937c90fbd623/R/truncated.R#L91-L104
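A sketch of that numerical fallback in base R, assuming only a density and a CDF are available: integrate x f(x) over the truncation bounds with stats::integrate() and divide by F(b) - F(a). The helper name truncated_mean_numeric() is hypothetical and not part of distributional.

```r
# Hypothetical helper (not part of distributional): mean of a continuous
# distribution truncated to (lower, upper) via numerical integration
truncated_mean_numeric <- function(density, cdf, lower, upper) {
  # Probability mass retained by the truncation: F(upper) - F(lower)
  mass <- cdf(upper) - cdf(lower)
  # Numerator: integral of x * f(x) over the bounds (infinite limits allowed)
  num <- stats::integrate(function(x) x * density(x), lower, upper)$value
  num / mass
}

# Example: Normal(mean = 2, sd = 5) truncated to [1, 2]
m <- truncated_mean_numeric(
  density = function(x) dnorm(x, mean = 2, sd = 5),
  cdf = function(x) pnorm(x, mean = 2, sd = 5),
  lower = 1, upper = 2
)
```

Unlike sampling, this returns the same value on every call, and it works for any continuous distribution with a vectorised density and CDF; integrate() uses adaptive quadrature, so its default relative tolerance (about 1e-4) bounds the accuracy.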

Fuco1 commented 3 years ago

Makes sense. The integral is, in a way, a sum of each possible outcome weighted by its probability (I'm glossing over the details). So in principle, as long as we have explicit formulas for the density and the cumulative distribution function, we can use this trick.
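For count distributions such as the Poisson or negative binomial, the same idea becomes a finite sum over the retained outcomes. A sketch, assuming truncation keeps the integers k with lower <= k <= upper (whether distributional treats the bounds as inclusive for discrete distributions is an assumption here) and that lower is finite; an infinite upper bound is handled by stopping the sum at a very high quantile. The helper name is hypothetical.

```r
# Hypothetical helper (not part of distributional): mean of a count
# distribution truncated to the integers k with lower <= k <= upper.
# `lower` must be finite; an infinite `upper` is cut off at the (1 - tol) quantile.
truncated_mean_discrete <- function(pmf, cdf, quantile, lower, upper, tol = 1e-12) {
  lo <- ceiling(lower)
  hi <- if (is.finite(upper)) floor(upper) else quantile(1 - tol)
  k <- seq(lo, hi)
  # Probability-weighted sum of the retained outcomes ...
  total <- sum(k * pmf(k))
  # ... divided by P(lo <= X <= hi) = F(hi) - F(lo - 1)
  total / (cdf(hi) - cdf(lo - 1))
}

# Zero-truncated Poisson(2): keep k >= 1
m_pois <- truncated_mean_discrete(
  pmf = function(k) dpois(k, lambda = 2),
  cdf = function(k) ppois(k, lambda = 2),
  quantile = function(p) qpois(p, lambda = 2),
  lower = 1, upper = Inf
)

# Negative binomial (size = 3, mean = 4) truncated to 0..10
m_nbinom <- truncated_mean_discrete(
  pmf = function(k) dnbinom(k, size = 3, mu = 4),
  cdf = function(k) pnbinom(k, size = 3, mu = 4),
  quantile = function(p) qnbinom(p, size = 3, mu = 4),
  lower = 0, upper = 10
)
```

The zero-truncated Poisson case has the known closed form lambda / (1 - exp(-lambda)), which makes a convenient check on the sum.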

So far I'm really happy with the change and the normal distribution is enough. I suppose we can address the rest on a case-by-case basis if the need arises for someone. Thanks for the fix!