fasiha / ebisu

Public-domain Python library for flashcard quiz scheduling using Bayesian statistics. (JavaScript, Java, Dart, and other ports available!)
https://fasiha.github.io/ebisu
The Unlicense

Look at accuracy of Beta-posterior when rebalancing? #18

Closed: fasiha closed this issue 1 year ago

fasiha commented 4 years ago

With version 1.0, we rebalance the model so that its t is roughly the half-life, keeping a and b from growing too different.

But are there some t's that yield Beta distributions that are more faithful to the GB1 posterior than others?

When we updateRecall, ought we spend a bit more time finding the t' whose final-Beta's higher moments are the least different from the GB1 posterior's higher moments?

If so, that'll also impact #17.
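
To make that concrete, here's a rough sketch of the comparison I have in mind, assuming a successful quiz and using the true posterior's raw moments (derived later in this thread); `betaFromMoments` and `thirdMomentError` are hypothetical helpers for illustration, not Ebisu API:

```python
from scipy.special import beta as B

def gb1MomentsAfterSuccess(alpha, b, t, tnow, tback, maxN=3):
  # Raw moments E[recall(tback)^N] of the exact posterior after a success.
  dt, et = tnow / t, tback / tnow
  return [B(alpha + dt + N * dt * et, b) / B(alpha + dt, b) for N in range(1, maxN + 1)]

def betaFromMoments(m1, m2):
  # Method-of-moments Beta fit to the first two raw moments.
  k = m1 * (1 - m1) / (m2 - m1**2) - 1
  return m1 * k, (1 - m1) * k

def thirdMomentError(alpha, b, t, tnow, tback):
  # How far the moment-matched Beta's third raw moment is from the truth.
  m1, m2, m3 = gb1MomentsAfterSuccess(alpha, b, t, tnow, tback)
  a2, b2 = betaFromMoments(m1, m2)
  m3fit = a2 * (a2 + 1) * (a2 + 2) / ((a2 + b2) * (a2 + b2 + 1) * (a2 + b2 + 2))
  return abs(m3 - m3fit)

# E.g., score a few candidate tbacks for model (3, 4, 10) after a success at tnow=5:
print([thirdMomentError(3., 4., 10., 5., tb) for tb in (5., 10., 20.)])
```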

fasiha commented 3 years ago

Here's a script that addresses this question, which is definitely worth considering. The script is only lightly documented and takes a couple of minutes to run, but hopefully it will do, since my goal here is just to explain the findings briefly—

In Ebisu, we have the exact analytical posterior probability distribution on recall at any time in the future. That is, for a model (alpha, beta, t), which specifies that recall at time t follows a Beta(alpha, beta) distribution, a quiz at some time tnow gives us an exact analytical posterior on recall at any time tback.
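
Concretely (this is what makePost in the script below computes), writing δ = tnow/t and ε = tback/tnow, the exact posterior density on the recall probability q at time tback is

$$f(q) = \frac{\left(1 - q^{1/(\delta\varepsilon)}\right)^{\beta-1}\left[c\, q^{(\alpha+\delta)/(\delta\varepsilon)-1} + d\, q^{\alpha/(\delta\varepsilon)-1}\right]}{\delta\varepsilon\left[c\, B(\alpha+\delta,\beta) + d\, B(\alpha,\beta)\right]},$$

where B is the Beta function and (c, d) = (1, 0) for a success, (−1, 1) for a failure.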

We eventually want to collapse this complicated distribution to a Beta random variable for storage—that Beta is obviously an approximation to the true distribution. Right now we are somewhat cavalier about choosing what tback to use when collapsing to the Beta. There's a coarse "rebalancing" step where I find the approximate Beta for tback=t, find that model's halflife coarsely (within a factor of ~2x), and then rerun the approximation from true posterior to a Beta at that approximate halflife. We've talked about always exact-rebalancing in #31, i.e., picking tback to be the exact halflife of the posterior (the time at which the posterior's predicted recall probability is 0.5).
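
For reference, a sketch of that always-exact-rebalance idea in terms of the public API (`exactRebalance` is a hypothetical helper, and it uses the provisional Beta fit's halflife as a stand-in for the true posterior's):

```python
from ebisu import updateRecall, modelToPercentileDecay

def exactRebalance(model, successes, total, tnow):
  # Update once without moving tback, find that posterior's halflife,
  # then redo the Beta approximation pinned at that halflife.
  provisional = updateRecall(model, successes, total, tnow, rebalance=False)
  halflife = modelToPercentileDecay(provisional)  # time at which recall = 0.5
  return updateRecall(model, successes, total, tnow, rebalance=False, tback=halflife)
```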

But which tback is best, in the sense of minimizing the approximation error between the true posterior at tback and the best-fit Beta at tback?

The Kullback-Leibler (KL) divergence measures how bad an approximating probability distribution is relative to a "true" distribution. Its minimum is zero, attained when the two distributions are identical.
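
To fix ideas before the script, here's that measure for two arbitrary Beta densities via scipy quadrature (a toy example, not taken from the script):

```python
import numpy as np
import scipy.integrate as integ
from scipy.stats import beta

def klBetas(aTrue, bTrue, aApprox, bApprox):
  # KL(true || approx) = ∫ true(p) * log(true(p) / approx(p)) dp
  f = lambda p: beta.pdf(p, aTrue, bTrue) * np.log(
      beta.pdf(p, aTrue, bTrue) / beta.pdf(p, aApprox, bApprox))
  return integ.quad(f, 0, 1)[0]

print(klBetas(3, 4, 3, 4))      # 0 (to machine precision): identical distributions
print(klBetas(3, 4, 3.5, 4.5))  # small positive: a nearby approximation
```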

The script below implements the analytical exact posterior for any tback (there are a couple of prints that show it's correct).

Then, for model (alpha=3, beta=4, t=10), which has halflife of 8.01 time units, we have the following:

[Figure: KL divergence vs. tback for the four (success/fail) × (tnow=5/tnow=30) scenarios]

Above: x axis: tback between 3 and 50 time units; y axis: KL measure of how bad the Beta at tback is, given the true posterior at the same tback, for all four combinations of (success, fail) × (quiz at 5 time units, quiz at 30 time units). The recall halflife is 8 time units.

Observations: for all four of these, there definitely is ONE tback where the KL divergence seems to drop to zero, or at least falls orders of magnitude below nearby tbacks. Looking at this tells me that we should be more thoughtful about the time at which we choose to fit our posterior.

So then, for this model (3, 4, 10), and for a few tnows, what is the tback that produces the lowest KL divergence between true posterior and Beta fit? (Recall tnow is the time that the quiz happened.)

| quiz | quiz time (tnow) | updated halflife | tback with lowest KL | lowest KL (nats) |
|-------|------|--------|--------|------------|
| True  | 1    | 8.211  | 10.000 | -9.860e-15 |
| True  | 5    | 9.010  | 10.000 | -1.441e-15 |
| True  | 15   | 10.981 | 10.000 | -9.297e-16 |
| True  | 30   | 13.896 | 10.000 | -8.849e-16 |
| False | 1    | 6.285  | 8.408  | 1.427e-07  |
| False | 5    | 6.564  | 8.884  | 2.888e-08  |
| False | 15   | 7.066  | 11.260 | 4.917e-08  |
| False | 30   | 7.490  | 14.029 | 8.415e-06  |

Iiiiiiinteresting? For successful quizzes, the best tback is t, i.e., don't move the time between the initial model and its update. Furthermore, the KL divergence here is 0 (to machine precision)—the true posterior is a Beta. That makes sense in hindsight: at tback = t, a success just multiplies the Beta(alpha, beta) prior density by the likelihood p^(tnow/t), yielding exactly a Beta(alpha + tnow/t, beta) density. So if we rebalance to use tback equal to the posterior's halflife, we'll be throwing away an exact update (KL=0) for an approximation (KL>0).
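
A quick standalone check of that exactness: at tback = t we have dt·et = 1 in the notation of the script below, and the true posterior's raw moments match those of Beta(alpha + dt, beta) term by term:

```python
from scipy.special import beta as B
from scipy.stats import beta as betarv

alpha, b, t, tnow = 3., 4., 10., 5.
dt = tnow / t
# True-posterior raw moments at tback = t after a success:
exact = [B(alpha + dt + N, b) / B(alpha + dt, b) for N in (1, 2, 3)]
# Raw moments of the candidate Beta(alpha + dt, b):
fit = [betarv.moment(N, alpha + dt, b) for N in (1, 2, 3)]
print(exact)
print(fit)  # same as `exact`
```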

But for failed quizzes, what's this? The best tback is greater than the prior's halflife (8 time units), and can be significantly more? This made me curious if there was maybe some simple relationship between tnow and the best tback:

[Figure: three stacked sub-plots vs. tnow: best tback, minimum KL, and quadrature error]

Above: for all three sub-plots, the x axis is tnow; the y axes are described in the plot. Bottom plot: I calculate the KL divergence using scipy.integrate.quad numerical integration, which yields not just the KL divergence (middle plot) but also the error estimate associated with that calculation. That error is at least two orders of magnitude lower than the KL divergence estimate, giving us confidence that we're not just looking at numerical issues here.

So the relationship between the best tback and the quiz time—when the quiz resulted in failure—doesn't seem straightforward. I'm surprised it's non-monotonic, but I could very easily be running into numerical instability and roundoff with my naive implementation of the posterior.
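
If it is roundoff, one thing to try would be evaluating the density in log-space. That's straightforward only for the success branch (the failure branch is a difference of two terms, so it'd need a signed log-sum-exp), but as a sketch:

```python
import numpy as np
from scipy.special import betaln

def logPostSuccess(p, alpha, b, dt, et):
  # Log of the success-case density from makePost below, using log1p
  # and betaln to stay accurate where the direct form under- or overflows.
  de = dt * et
  return ((b - 1) * np.log1p(-p**(1 / de)) + ((alpha + dt) / de - 1) * np.log(p)
          - np.log(de) - betaln(alpha + dt, b))
```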

A lot of food for thought.

```python
from ebisu import modelToPercentileDecay, updateRecall
import numpy as np
import scipy.integrate as integ
from scipy.stats import beta
from scipy.special import beta as betafn
from scipy.optimize import minimize_scalar
import pylab as plt
plt.ion()

def makePost(model, tnow, tback, c, d):
  # Exact posterior pdf on the recall probability at time tback, given a
  # quiz at time tnow; (c, d) = (1, 0) encodes a success, (-1, 1) a failure.
  alpha, beta, t = model
  dt = tnow / t
  et = tback / tnow
  B = betafn
  den = c * dt * et * B(alpha + dt, beta) + d * dt * et * B(alpha, beta)

  def post(p):
    first = (1 - p**(1 / (dt * et)))**(beta - 1)
    secondExp = (alpha + dt) / (dt * et) - 1
    second = c * p**secondExp
    thirdExp = alpha / (dt * et) - 1
    third = d * p**thirdExp
    num = first * (second + third)
    return num / den

  return post

def validatePost(x=True):
  # Sanity checks: the posterior should integrate to 1, and its quadrature
  # moments should match the analytical moments below.
  pre = (3., 4., 10.)
  c, d = (1, 0) if x else (-1, 1)
  tnow = 5
  tback = 8
  post = makePost(pre, tnow, tback, c, d)

  print('should be 1', integ.quad(post, 0, 1))

  def moments(maxN):
    # Analytical raw moments E[recall(tback)^N] of the exact posterior.
    alpha, beta, t = pre
    dt = tnow / t
    et = tback / tnow

    nums = [
        c * betafn(alpha + dt + N * dt * et, beta) + d * betafn(alpha + N * dt * et, beta)
        for N in range(1, 1 + maxN)
    ]
    den = c * betafn(alpha + dt, beta) + d * betafn(alpha, beta)
    return [num / den for num in nums]

  moms = moments(5)
  qmoms = [integ.quad(lambda p: p**n * post(p), 0, 1) for n in range(1, 1 + len(moms))]
  print(list(zip(moms, qmoms)))

validatePost(True)
validatePost(False)

def kl(tback, pre, x, tnow):
  # KL divergence (plus quad's error estimate) between the exact posterior
  # at tback and Ebisu's Beta approximation at the same tback.
  c, d = (1, 0) if x else (-1, 1)
  post = makePost(pre, tnow, tback, c, d)
  updated = updateRecall(pre, 1 if x else 0, 1, tnow, tback=tback, rebalance=False)
  approx = lambda p: beta.pdf(p, updated[0], updated[1])
  divergence = integ.quad(lambda p: post(p) * np.log(post(p) / approx(p)), 0, 1, limit=500)
  return divergence

def kltest(x=True, tnow=5.0):
  # Sweep tback over a log-spaced grid and record the KL divergence.
  pre = (3., 4., 10.)
  tbacks = np.logspace(np.log10(3), np.log10(50))
  kls, errs = np.vectorize(lambda foo: kl(foo, pre, x, tnow))(tbacks)
  return tbacks, kls, errs, x, tnow

for x in [True, False]:
  for tnow in [5., 30.]:
    tbacks, kls, errs, x, tnow = kltest(x, tnow)
    plt.figure()
    plt.loglog(tbacks, kls, '.-')
    plt.title(f"{x}, tnow={tnow:0.2f}")

pre = (3., 4., 10.)
print('pre', modelToPercentileDecay(pre))
for x in [True, False]:
  for tnow in [1., 5., 15., 30.]:
    updated = updateRecall(pre, 1 if x else 0, 1, tnow)
    minimized = minimize_scalar(
        lambda tback: kl(tback, pre, x, tnow)[0], bounds=[3., 50.], method='bounded')
    print(
        dict(
            x=x,
            tnow=tnow,
            hl=modelToPercentileDecay(updated),
            tbackMin=minimized.x,
            klmin=minimized.fun))

# For failures, find the KL-minimizing tback across a range of quiz times.
tnows = np.linspace(0.1, 100, 50)
bestTbacks = np.vectorize(lambda tnow: minimize_scalar(
    lambda tback: kl(tback, pre, False, tnow)[0], bounds=[1., 60.], method='bounded').x)(
        tnows)
quads = [kl(tback, pre, False, tnow) for (tnow, tback) in zip(tnows, bestTbacks)]

fig, axs = plt.subplots(3)
axs[0].plot(tnows, bestTbacks, '.-')
axs[0].set_ylabel('best tback')

axs[1].semilogy(tnows, np.array([val for val, err in quads]))
axs[1].set_ylabel('min KL')

axs[2].semilogy(tnows, np.array([err for val, err in quads]))
axs[2].set_ylabel('error')
axs[2].set_xlabel('tnow')

fig.set_tight_layout(True)
for a in axs:
  a.grid()
```
fasiha commented 3 years ago

So the work above led to a nice analytical result showing the exact posterior for a series of True/False quizzes: https://fasiha.github.io/ebisu/#appendix-exact-ebisu-posteriors

[Screenshot of the appendix's exact-posterior formulas]

For now, with 2.1.0, I've chosen to rebalance to the new halflife despite the findings above (and despite the computational burden of rebalancing). The difference in KL divergence between the “best” tback and the halflife, at least in the examples above, seems minor: 1e-3 nats or less. And we get the benefits of always rebalancing: never having to calculate a halflife separately, a potentially faster approximate predictRecall, and cleaner code.
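
In API terms, the 2.1.0 default versus the pinned-tback alternative explored above looks like this (illustrative numbers):

```python
from ebisu import updateRecall, modelToPercentileDecay

pre = (3., 4., 10.)
# Default: rebalance, so the updated model's t is (close to) its halflife:
balanced = updateRecall(pre, 1, 1, 5.)
# Alternative from the experiments above: pin tback at the prior's t:
pinned = updateRecall(pre, 1, 1, 5., rebalance=False, tback=pre[2])
print('balanced', balanced, modelToPercentileDecay(balanced))
print('pinned  ', pinned, modelToPercentileDecay(pinned))
```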

(Also, I'm not sure how long Ebisu will continue to use this Beta-on-recall model given my efforts in creating a new model per #43.)

fasiha commented 1 year ago

Closed this because Ebisu v3 is moving away from Beta priors on recall to Gamma priors on half-life.