pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Pathfinder gives confident wrong answer with small sample prediction #279

Open fonnesbeck opened 11 months ago

fonnesbeck commented 11 months ago

This example is taken from the baseball case study in pymc-examples. We fit a beta-binomial model to some baseball batting data:

import pandas as pd
import pymc as pm

data = pd.read_csv(pm.get_data("efron-morris-75-data.tsv"), sep="\t")

N = len(data)
player_names = data["FirstName"] + " " + data["LastName"]
# coords = {"player_names": player_names.tolist()}

with pm.Model() as baseball_model:
    at_bats = pm.MutableData("at_bats", data["At-Bats"].to_numpy())
    n_hits = pm.MutableData("n_hits", data["Hits"].to_numpy())
    baseball_model.add_coord("player_names", player_names, mutable=True)

    phi = pm.Uniform("phi", lower=0.0, upper=1.0)

    kappa_log = pm.Exponential("kappa_log", lam=1.5)
    kappa = pm.Deterministic("kappa", pm.math.exp(kappa_log))

    theta = pm.Beta("theta", alpha=phi * kappa, beta=(1.0 - phi) * kappa, dims="player_names")
    y = pm.Binomial("y", n=at_bats, p=theta, observed=n_hits, dims="player_names")

and then add a prediction for a fictional player that has zero hits in 4 appearances:

with baseball_model:
    theta_new = pm.Beta("theta_new", alpha=phi * kappa, beta=(1.0 - phi) * kappa)
    y_new = pm.Binomial("y_new", n=4, p=theta_new, observed=0)
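
The Pathfinder call itself isn't shown; presumably it was run along these lines (a sketch, assuming the pymc_experimental.fit entry point; seed handling is omitted):

import pymc_experimental as pmx

# Sketch of the presumed reproduction: pymc-experimental's fit() with
# method="pathfinder" wraps the BlackJAX Pathfinder implementation.
with baseball_model:
    idata_pathfinder = pmx.fit(method="pathfinder")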

What should happen (and does with either pm.sample or pm.fit) is that, because the sample size for y_new is so small, its estimate should be shrunk towards the population mean. Here is the population of players:

[image: posterior estimates of theta across the players in the dataset]

and the population mean is given by phi:

[image: posterior distribution of phi, the population mean]

However, the estimate for theta_new is far too large (larger than that of the most extreme player in the fitting dataset), with a high degree of posterior confidence:

[image: Pathfinder posterior of theta_new, concentrated at a value larger than any player's estimate]

Running the same model with pm.fit or pm.sample returns more reasonable estimates just under the population mean.
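
For intuition on the size of the expected shrinkage (a worked note, not from the thread; the phi and kappa values below are made up for illustration only):

# Conditional on (phi, kappa) the beta prior is conjugate to the binomial
# likelihood, so 0 hits in n = 4 at-bats just adds n to the beta parameter:
#   theta_new | y_new = 0  ~  Beta(phi * kappa, (1 - phi) * kappa + n)
# with posterior mean phi * kappa / (kappa + n).
phi, kappa, n, hits = 0.27, 30.0, 4, 0
post_mean = (phi * kappa + hits) / (kappa + n)
print(post_mean)  # ~0.24: pulled below, but close to, the population mean phi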

Using PyMC 5.10.1 and pymc-experimental from the main repo.

junpenglao commented 11 months ago

Not sure; Pathfinder returns results that underestimate kappa and theta:

[image: Pathfinder estimates of kappa and theta]

But this is probably a property of Pathfinder; I don't work with it enough to provide a good perspective. @ColCarroll has a bit more experience, maybe he has some idea?

ricardoV94 commented 11 months ago

See also the issues I found before with the 8-schools example, where it would basically return whatever the initval for mu was: https://gist.github.com/ricardoV94/eafd20ac47d63525253b0a8adf5e5d76

junpenglao commented 11 months ago

Yeah, Pathfinder has a jaxopt dependency that has a convergence gap (compared to scipy.optimize.minimize). I think on the BlackJAX side we can be more explicit about detecting non-convergence.

junpenglao commented 11 months ago

In the interim, I suggest adding some noise to the initial position: https://github.com/pymc-devs/pymc-experimental/blob/00d7a2b3cf3379e0a9420fb436667ab781e5a5e7/pymc_experimental/inference/pathfinder.py#L104, so that at the very least we can run Pathfinder a couple of times.
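
A minimal sketch of that suggestion (hypothetical helper; ip_flat stands for the flattened initial position built around the linked line, and the noise scale is arbitrary):

import numpy as np

rng = np.random.default_rng()

def jitter(ip_flat: np.ndarray, scale: float = 1.0) -> np.ndarray:
    # Add U(-scale, scale) noise to the flattened initial position so that
    # repeated Pathfinder runs start from different points.
    return ip_flat + rng.uniform(-scale, scale, size=ip_flat.shape)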

ricardoV94 commented 11 months ago

You can use this to add jitter to RVs: https://github.com/pymc-devs/pymc/blob/0fd7b9e1d2208f1250b1c804bf5421013dba9023/pymc/initial_point.py#L111
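
For example (a sketch using make_initial_point_fn from that file; seed 0 is arbitrary):

from pymc.initial_point import make_initial_point_fn

# jitter_rvs adds U(-1, 1) jitter on the unconstrained space for the listed
# variables; calling the returned function with different seeds gives
# different starting points for repeated Pathfinder runs.
ipfn = make_initial_point_fn(
    model=baseball_model,
    jitter_rvs=set(baseball_model.free_RVs),
)
start = ipfn(0)  # dict of jittered initial values, keyed by value-var name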

aphc14 commented 2 months ago

After trying, and failing, to find where in the BlackJAX Pathfinder backend the calculations causing the poor posterior estimates were coming from, I decided to compare the Stan Pathfinder estimates with BlackJAX's.

The comparison between Stan and PyMC is in the two notebook links below (apologies for the untidy notebooks and code; the images should hopefully provide a good enough summary):

Eightschools Data https://gist.github.com/aphc14/a32d1f81b8993b8cc57867cd4466edbb

MLB Data (from above) https://gist.github.com/aphc14/9f38b2e45fd220ae4bf1eb6b967ca886

The most surprising outcome is that Stan's version of Pathfinder also gives a poor estimate of the posterior on these two datasets. When initialising with jitter_rvs for PyMC + BlackJAX, the outputs are somewhat close to Stan's.

From these comparisons, is it safe to say there are no big issues with the backend calculations in BlackJAX Pathfinder?

junpenglao commented 2 months ago

Or both are wrong, but in different places 😬 Is there any other implementation we can cross-reference against?

fonnesbeck commented 2 months ago

The Pathfinder paper seems to show decent results for the centered model using multi-path (Fig 19).

aphc14 commented 2 months ago

There is another package that implements Pathfinder, although it's in Julia. But since the Pathfinder paper uses Stan, shouldn't we cross-check our results against Stan? I'll code up a more streamlined comparison covering other scenarios from posteriordb. I could try measuring performance with the scaled 1-Wasserstein metric against the same or a similar reference posterior, to see whether PyMC's results resemble Stan's. I'll get around to this, probably after improving our PyMC implementation.
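
A sketch of the metric in question (my reading: a per-marginal comparison against reference draws, scaled by the reference posterior's standard deviation; the helper name is made up):

import numpy as np
from scipy.stats import wasserstein_distance

def scaled_w1(draws: np.ndarray, reference: np.ndarray) -> float:
    # 1-Wasserstein distance between 1-D marginal posterior draws, scaled by
    # the reference posterior's standard deviation so parameters with
    # different scales are comparable.
    return wasserstein_distance(draws, reference) / np.std(reference)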