tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Could higher level inference/fitting functions be part of TFP? #1161

Closed jeffpollock9 closed 3 years ago

jeffpollock9 commented 3 years ago

Hi,

Most of the time when I see comparisons of TFP to other probabilistic frameworks, one of the downsides to TFP is how hard it can be to write code to do inference (see for example https://colcarroll.github.io/ppl-api).

I am wondering if there is any scope to add a higher level API to TFP which wraps up some of the common patterns and provides sensible defaults? There is already the fit_with_hmc function for structural time series models which does this and I think it's really useful.

If a user codes up some sort of JointDistribution then I think it encodes almost everything you need to call a simple function do_inference(joint_distribution) which behaves sensibly by default. I tried a small example with MAP estimation for a simple linear regression and it seems to work. I've added the full code as a gist but the main parts are:

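```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfl = tf.linalg
Root = tfd.JointDistributionCoroutine.Root

# NUM_COEFFICIENTS, NUM_OBSERVATIONS and design_matrix are defined in the full gist.


@tfd.JointDistributionCoroutine
def joint_dist():
    # Priors: intercept ~ Normal(0, 1), coefficients ~ Normal(0, I), scale ~ Exponential(1).
    intercept = yield Root(tfd.Normal(loc=0.0, scale=1.0, name="intercept"))
    coefficients = yield Root(
        tfd.MultivariateNormalDiag(
            loc=0.0, scale_diag=tf.ones([NUM_COEFFICIENTS]), name="coefficients"
        )
    )
    scale = yield Root(tfd.Exponential(rate=1.0, name="scale"))
    # The trailing axis lets a batch of intercepts broadcast over the observations.
    loc = intercept[..., tf.newaxis] + tfl.matvec(design_matrix, coefficients)
    # Likelihood: observations ~ Normal(loc, scale**2 * I).
    yield tfd.MultivariateNormalLinearOperator(
        loc=loc,
        scale=tfl.LinearOperatorScaledIdentity(
            num_rows=NUM_OBSERVATIONS, multiplier=scale
        ),
        name="observations",
    )
```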

```python
# find_map is the helper from the gist (not part of TFP); batch_shape=[10]
# runs ten optimizations from different starting points in parallel.
estimates, opt = find_map(
    joint_distribution=joint_dist,
    condition_on={"observations": observations},
    batch_shape=[10],
)

# Keep the batch member with the lowest final objective value.
best_estimate = {k: v[tf.argmin(opt.objective_value)] for k, v in estimates.items()}

# The simulated "true" parameters from the gist; the zip below relies on the
# dict order of best_estimate matching this list.
truths = [random_intercept, random_coefficients, random_scale]

for (name, estimate), truth in zip(best_estimate.items(), truths):
    print(f"{name}:")
    print(f"\testimate: {estimate.numpy()}")
    print(f"\ttruth:    {truth.numpy()}")

# intercept:
#   estimate: 0.22806133329868317
#   truth:    0.32746851444244385
# coefficients:
#   estimate: [ 0.15262237 -0.8363303   0.31761777  0.10883351 -0.46800813]
#   truth:    [ 0.08422458 -0.86090374  0.37812304 -0.00519627 -0.49453196]
# scale:
#   estimate: 2.6733169555664062
#   truth:    2.6272153854370117
```

I added the batch_shape argument to see if I could run many optimizations from different starting points and then pick out the "best" one, i.e. the one with the lowest final objective value.
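The multi-start idea is independent of TFP, so here is a minimal plain-Python sketch of it: optimize from several random starting points and keep the result with the lowest objective, which is the role `tf.argmin(opt.objective_value)` plays above. The helper names, the fixed-step gradient descent and the toy double-well objective are all hypothetical choices for illustration; they are not what find_map or TFP's optimizers do internally.

```python
import random


def gradient_descent(grad, x0, step=0.01, num_steps=2000):
    """Plain fixed-step gradient descent on a scalar parameter."""
    x = x0
    for _ in range(num_steps):
        x -= step * grad(x)
    return x


def multi_start_minimize(objective, grad, starts, **kwargs):
    """Optimize from every starting point, then keep the lowest objective value.

    This mirrors tf.argmin(opt.objective_value) over a batch of optimizations.
    """
    finals = [gradient_descent(grad, x0, **kwargs) for x0 in starts]
    values = [objective(x) for x in finals]
    best = min(range(len(values)), key=values.__getitem__)
    return finals[best], values[best], finals, values


# A tilted double well: local minima near x = -1 and x = +1, the left one lower.
def objective(x):
    return (x * x - 1.0) ** 2 + 0.3 * x


def grad(x):
    return 4.0 * x * (x * x - 1.0) + 0.3


rng = random.Random(0)
starts = [rng.uniform(-2.0, 2.0) for _ in range(10)]  # analogue of batch_shape=[10]
best_x, best_value, finals, values = multi_start_minimize(objective, grad, starts)
```

Starts to the right of the hump near x = 0.075 get stuck in the worse local minimum, which is exactly why keeping only the lowest objective value across the batch helps.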

I guess the main functions to add would be run_nuts(joint_distribution, ...) or similar.
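Whatever the sampler or optimizer, a generic run_nuts or do_inference ultimately consumes the model's joint log density. As a minimal plain-Python sketch (no TensorFlow; the function names are hypothetical and design_matrix is assumed to be a list of rows), this is the density the coroutine model above encodes. Conditioning on observations turns it into an unnormalized log posterior over intercept, coefficients and scale.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def normal_logpdf(x, loc, scale):
    """log Normal(x | loc, scale**2)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * LOG_2PI


def joint_log_prob(intercept, coefficients, scale, design_matrix, observations):
    """Joint log density of the linear regression model, in plain Python.

    design_matrix is a list of rows (one per observation); coefficients is a list.
    """
    if scale <= 0.0:
        return -math.inf  # outside the support of the Exponential prior
    lp = normal_logpdf(intercept, 0.0, 1.0)                      # intercept ~ Normal(0, 1)
    lp += sum(normal_logpdf(c, 0.0, 1.0) for c in coefficients)  # coefficients ~ Normal(0, I)
    lp += -scale  # scale ~ Exponential(rate=1): log(rate) - rate * scale
    for row, y in zip(design_matrix, observations):
        loc = intercept + sum(x * c for x, c in zip(row, coefficients))
        lp += normal_logpdf(y, loc, scale)  # observations ~ Normal(loc, scale**2 * I)
    return lp
```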

I'd be happy to spend some time on this, if at all useful.

Thanks again.

ColCarroll commented 3 years ago

Hi Jeff! I wrote the linked blog post, and have also been thinking a lot about this sort of thing. I think a few things that have changed since I wrote it (and you picked up on these) are just about ready to be used to put together high-level functions.

Specifically, you have used JointDistribution* as a handy "frontend" for writing models, and then experimental_pin and experimental_default_event_space_bijector to condition on data and unconstrain parameters. I am excited for these last two to come out of "experimental". A good run_nuts function will at least require a NUTS analogue of tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo, which allows more efficient sampling by using a non-identity mass matrix for the momentum, plus a way to select that mass matrix automatically (this is in progress).
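To make the "condition and unconstrain" step concrete, here is a minimal plain-Python sketch of what pinning and an event-space bijector accomplish for a single positive parameter. `pin` and `unconstrain` are hypothetical stand-ins, not the experimental_pin or experimental_default_event_space_bijector APIs, and exp is chosen only for illustration; the bijector TFP picks for a given distribution may differ (softplus, for example).

```python
import functools
import math


def log_prob(scale, observations):
    """Joint log density of scale ~ Exponential(1), y_i ~ Normal(0, scale**2)."""
    if scale <= 0.0:
        return -math.inf
    lp = -scale  # log Exponential(scale | rate=1)
    for y in observations:
        lp += -0.5 * (y / scale) ** 2 - math.log(scale) - 0.5 * math.log(2.0 * math.pi)
    return lp


def pin(joint_log_prob, **observed):
    """Fix the observed values, leaving an unnormalized posterior over the rest."""
    return functools.partial(joint_log_prob, **observed)


def unconstrain(target_log_prob, forward, forward_log_det_jacobian):
    """Pull a density on a constrained variable back to an unconstrained z.

    With scale = forward(z), the density of z is
    target(forward(z)) + log|d forward(z) / dz|.
    """
    def unconstrained_log_prob(z):
        return target_log_prob(forward(z)) + forward_log_det_jacobian(z)
    return unconstrained_log_prob


posterior = pin(log_prob, observations=[0.5, -1.2, 2.0])
# exp maps the real line onto (0, inf); log|d exp(z)/dz| = z.
unconstrained = unconstrain(posterior, math.exp, lambda z: z)
```

Every real z now has a finite log density, so an unconstrained optimizer or HMC/NUTS can move freely. Because of the Jacobian term, the mode in z-space is generally not the image of the mode in scale-space, which matters when the same machinery is reused for MAP estimation.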

To be more concrete, this is a nice find_map function! There is an effort to port Statistical Rethinking to TFP, and I think you could really improve the quap function there (and you could mostly rely on its existing tests). What do you think?
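For context, quap in Statistical Rethinking is a quadratic (Laplace) approximation: find the posterior mode, then approximate the posterior by a Gaussian whose variance is the inverse of the negative curvature of the log density at that mode. A minimal one-dimensional sketch, assuming finite differences in place of autodiff; `quap_1d` is a hypothetical name, and a TFP version would presumably build on something like find_map plus a Hessian from tf.GradientTape.

```python
import math


def quap_1d(log_prob, x0, lr=0.01, num_steps=5000, h=1e-4):
    """Quadratic approximation of a 1-D posterior.

    Finds the mode by gradient ascent (central finite differences), then uses
    the curvature at the mode as the precision of a Gaussian approximation.
    Returns (mode, standard deviation).
    """
    def d1(x):
        return (log_prob(x + h) - log_prob(x - h)) / (2.0 * h)

    def d2(x):
        return (log_prob(x + h) - 2.0 * log_prob(x) + log_prob(x - h)) / (h * h)

    x = x0
    for _ in range(num_steps):
        x += lr * d1(x)
    curvature = d2(x)
    if curvature >= 0.0:
        raise ValueError("log density is not locally concave at the final point")
    return x, math.sqrt(-1.0 / curvature)


# Globe tossing: 6 "water" in 9 tosses with a flat prior on p,
# so the log posterior is 6 log p + 3 log(1 - p) up to a constant.
def globe_log_post(p):
    return 6.0 * math.log(p) + 3.0 * math.log(1.0 - p)


mode, sd = quap_1d(globe_log_post, x0=0.5)
```

For the globe-tossing data from the book this gives a mode of about 0.67 and a standard deviation of about 0.16 (exactly 6/9 and sqrt((6/9)(3/9)/9) in the limit).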

jeffpollock9 commented 3 years ago

Hi Colin, thanks for your reply and for your very helpful blog!

Great to hear that this is already in progress.

I'll take a look at quap next week and try to send over some ideas - cheers!

ColCarroll commented 3 years ago

Closing this to keep issues tidy, but feel free to ping here or in a new issue!