Tradeshift / blayze

A fast and flexible Naive Bayes implementation for the JVM

Bayesian naive bayes #19

Closed · rasmusbergpalm closed 5 years ago

rasmusbergpalm commented 5 years ago

Upgrade to Bayesian naive Bayes.

Bayesian naive Bayes differs from plain naive Bayes by treating the parameters of the estimated distributions as random variables (i.e. unknown) and integrating them out. Previously we just used maximum likelihood estimates (MLE). The main effect is that we're less certain about our predictions when we've only seen a small number of samples. As the number of samples grows, the distributions approach their MLE counterparts.

For instance, if we observe the samples x = [-1.0, 1.0], then the MLE estimate is a Gaussian with mean = (-1.0 + 1.0)/2 = 0 and variance = ((-1 - 0)^2 + (1 - 0)^2)/2 = 1. This is the Gaussian from which the data is most likely to have been observed. But the data could also have come from mean = 1, variance = 1; it's less likely, but possible. In fact there are infinitely many Gaussians the data could have come from. Being Bayesian about it, we put a prior over all of them and integrate them out. If T represents the parameters of the distributions and D the observed data, then:

p(outcome | inputs, D) ~ ∫ p(inputs | outcome, T) p(outcome | T) p(T | D) dT

using p(T | D) ~ p(D | T) p(T) (dropping the constant p(D))

p(outcome | inputs, D) ~ ∫ p(inputs | outcome, T) p(outcome | T) p(D | T) p(T) dT

By selecting mathematically convenient priors (so-called conjugate priors), the integrals can be done analytically. Luckily, for most distributions someone has already done this. See https://en.wikipedia.org/wiki/Conjugate_prior#Table_of_conjugate_distributions

The quantity we're interested in is the last column, a.k.a. the posterior predictive. This is p(inputs | D), i.e. the distribution of the inputs after the parameters have been integrated out.
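As a concrete instance from that table (a sketch using the table's parameterization, not anything in the code): for a categorical feature with a symmetric Dirichlet(α) prior over K possible values, if value k has been counted n_k times out of N observations under a given outcome, then

p(input = k | outcome, D) = (n_k + α) / (N + K·α)

This is just add-α (Laplace) smoothing: with no data it falls back to the uniform 1/K, and as N grows it approaches the MLE ratio n_k / N.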

So to implement new features we just need to look up the posterior predictive and replace the log likelihood with it. I did that for all the current features and the outcome prior.
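For example, for a Gaussian feature with unknown mean and variance the conjugate prior is the Normal-Inverse-Gamma(μ0, κ0, α0, β0), and the table's posterior predictive is a Student-t. Below is a minimal, self-contained sketch of what such a feature could look like. The class and method names are illustrative only, not blayze's actual API, the prior hyperparameters are assumed defaults, and logGamma comes from Apache Commons Math (assumed to be on the classpath):

```kotlin
import org.apache.commons.math3.special.Gamma
import kotlin.math.PI
import kotlin.math.ln

// Hypothetical sketch, not blayze's actual Feature interface.
// Scores new values with the Student-t posterior predictive instead of the MLE log likelihood.
class GaussianFeatureSketch(
        // Normal-Inverse-Gamma prior hyperparameters (assumed defaults).
        private val mu0: Double = 0.0,
        private val kappa0: Double = 1e-3,
        private val alpha0: Double = 1e-3,
        private val beta0: Double = 1e-3,
        // Sufficient statistics of the data seen so far.
        private val n: Int = 0,
        private val sum: Double = 0.0,
        private val sumSq: Double = 0.0
) {
    // Returns a new feature with the sample folded into the sufficient statistics.
    fun update(x: Double) =
            GaussianFeatureSketch(mu0, kappa0, alpha0, beta0, n + 1, sum + x, sumSq + x * x)

    // log p(x | D): Student-t posterior predictive with the Gaussian parameters integrated out.
    fun logPosteriorPredictive(x: Double): Double {
        val mean = if (n > 0) sum / n else 0.0
        val sse = sumSq - n * mean * mean                          // sum of squared deviations
        val kappaN = kappa0 + n
        val muN = (kappa0 * mu0 + sum) / kappaN
        val alphaN = alpha0 + n / 2.0
        val betaN = beta0 + 0.5 * sse + kappa0 * n * (mean - mu0) * (mean - mu0) / (2.0 * kappaN)
        val nu = 2.0 * alphaN                                      // degrees of freedom
        val scale2 = betaN * (kappaN + 1) / (alphaN * kappaN)      // squared scale
        val z = (x - muN) * (x - muN) / scale2
        return Gamma.logGamma((nu + 1) / 2) - Gamma.logGamma(nu / 2) -
                0.5 * ln(nu * PI * scale2) -
                (nu + 1) / 2 * ln(1 + z / nu)
    }
}
```

For the x = [-1.0, 1.0] example above this predictive is noticeably heavier-tailed than the MLE Gaussian, which is exactly the extra uncertainty we want at small sample sizes; as n grows the degrees of freedom increase and it converges to the MLE Gaussian.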

Benchmark comparison, log scale, s/op, lower is better. The two versions are very similar, with one outlier. For the outlier the error on 3.1 is ±5 times the value, and on master (3.0) it is ±3 times the value, so it's hard to say anything conclusive. A slight drop in performance is expected since the log-gamma function is somewhat expensive.

[benchmark comparison plot]