tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Feature Request - Internal MCMC Variable Traces #268

Closed alexcordover closed 4 years ago

alexcordover commented 5 years ago

Posted on Stack Overflow, but no answers there, so it might be a feature request.

I have the basic model:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import edward2 as ed

# Simulated data: number of "heads" in 10 coin flips for each outcome.
flips1 = sum(np.round(np.random.rand(10)))
flips2 = sum(np.round(np.random.rand(10)))

def m():
    prior1 = ed.Beta(concentration0=10., concentration1=10., name='prior1')
    outcome1 = ed.Binomial(total_count=10., probs=prior1, name='outs1')

    prior2 = ed.Beta(concentration0=1., concentration1=1., name='prior2')
    outcome2 = ed.Binomial(total_count=10., probs=prior2, name='outs2')

    c = outcome1 + outcome2
    return c

fnc = ed.make_log_joint_fn(m)

def target(p1, p2):
    return fnc(prior1=p1, outs1=flips1, prior2=p2, outs2=flips2)

step_size = tf.get_variable(
    name='step_size',
    initializer=.001,
    use_resource=True,  # For TFE compatibility.
    trainable=False)

hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target,
    num_leapfrog_steps=3,
    step_size=step_size,
    step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy())

samples, kernel_results = tfp.mcmc.sample_chain(
    num_results=int(10e3),
    num_burnin_steps=int(1e3),
    current_state=[.01, .01],
    kernel=hmc)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    samples_, res = sess.run([samples, kernel_results])

From here, samples_ gives me the Markov chain samples for prior1 and prior2. However, I am additionally interested in seeing the values for other variables inside the model. For example, I would like to see the values that outcome1 and outcome2 take on.

Most importantly though, I would like to see the value of c without parametrizing it. In PyMC3, this would be accomplished by making it a pm.Deterministic. I have not been able to find equivalent functionality in TF Probability. Is it available?

I suppose my immediate guess would be to do the MCMC without using sample_chain, get the tensors of interest with e.g. tf.get_default_graph().get_tensor_by_name(...), and run them all in the sess.run(...) call. But it would be nice to get the results of interest directly from sample_chain.

csuter commented 5 years ago

Sorry we overlooked the Stack Overflow question (I thought I had tag-watching set up for tfp, but I didn't; I do now!).

I'm a bit confused by the question, though. In this case, by currying outcome1 and outcome2 with flips1 and flips2, you've set things up so that outcome{1,2} are not being sampled by the Markov chain -- only values for the latents prior{1,2} are being sampled from the posterior. The values of outcome{1,2} are fixed. During the Markov chain sampling process, no new values are being computed for the outcome variables, so there are no such internal values to trace.

alexcordover commented 5 years ago

Sorry, I was trying to set up a more basic example than the model I was intending to run and ended up not asking a very good question.

The use case is a basic basketball model defined as:

sample_shape = 100  # placeholder value; in practice this matches the shape of the observed data

def model():
    # attempted fgs
    hyper_fg_mu = ed.Normal(loc=85., scale=10., name='hyper_fg_mu')
    hyper_fg_sigma = ed.Uniform(low=1., high=20., name='hyper_fg_sigma')
    fg_dist = ed.Normal(loc=hyper_fg_mu, scale=hyper_fg_sigma, name='fg_attempts', sample_shape=sample_shape)

    # attempted threes
    attempted_three_pct = ed.Uniform(low=0., high=1., name='three_attempts', sample_shape=sample_shape)
    threes = fg_dist * attempted_three_pct
    twos = fg_dist * (1. - attempted_three_pct)

    # fg points
    made_two_pt = ed.Uniform(low=0., high=1., name='two_pct', sample_shape=sample_shape)
    made_three_pt = ed.Uniform(low=0., high=1., name='three_pct', sample_shape=sample_shape)
    three_pts = 3. * threes * made_three_pt
    two_pts = 2. * twos * made_two_pt

    points = three_pts + two_pts

    return points

I have data for each of fg_dist, attempted_three_pct, made_two_pt, and made_three_pt, so I am able to sample hyper_fg_mu and hyper_fg_sigma with no issues. But I am also interested in finding the distribution of points = three_pts + two_pts, and potentially three_pts and two_pts individually. I've previously built a similar model in PyMC3, whose pm.Deterministic wrapper is able to return the predicted distributions over each of the internal variables of interest (points, three_pts, etc.).

My thought on how to accomplish this was to take the estimated parameters of the sampled variables (hyper_fg_mu and hyper_fg_sigma here), build out the model in TFP, and just sample it some number of times, but I didn't know whether functionality similar to PyMC3's was already available here.

davmre commented 5 years ago

Hi Alex, it's possible that ed.Deterministic and ed.tape might help you. You can wrap any computed value as a deterministic random variable, e.g., points = ed.Deterministic(three_pts + two_pts, name='points'); this makes it visible to traces and other methods that manipulate the model.
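
For instance, a minimal toy sketch (using the tfp.edward2 API from this era; toy_model and its variable names are made up for illustration):

from tensorflow_probability import edward2 as ed

def toy_model():
  a = ed.Normal(loc=0., scale=1., name='a')
  b = ed.Normal(loc=0., scale=1., name='b')
  total = ed.Deterministic(a + b, name='total')  # derived value, now a named RV
  return total

with ed.tape() as trace:
  toy_model()

print(list(trace.keys()))  # ['a', 'b', 'total'] -- the Deterministic shows up like any other RV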

Running a model inside of ed.tape() builds a trace of all random variables. So

with ed.tape() as prior_sample:
  model()

would construct an OrderedDict prior_sample mapping RV names to objects including the values they took in that model execution (a prior sample by default). Then you can use an interceptor to set the parent values so that you get values conditioned on posterior samples (what I think you want here?). For example:

with ed.tape() as posterior_predictive_sample:
  with ed.interception(make_value_setter(
    hyper_fg_mu=hyper_fg_mu_sample,
    hyper_fg_sigma=hyper_fg_sigma_sample)):
      model()

(here make_value_setter is a simple interceptor defined several times across the Edward2 examples, e.g., https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/deep_exponential_family.py#L154 -- you'd need to copy it over, though we're working to get this as a built-in soon)

would build an OrderedDict posterior_predictive_sample in which hyper_fg_mu and hyper_fg_sigma are set to the passed-in values, and all downstream variables (including points) are sampled conditionally on those values.
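
For reference, the interceptor in that linked example is only a few lines; roughly (paraphrasing the deep_exponential_family.py code, so details may differ slightly between versions):

def make_value_setter(**model_kwargs):
  """Creates an interceptor that fixes named random variables to given values."""
  def set_values(f, *args, **kwargs):
    # If a value was provided for this RV's name, pin the RV to it;
    # otherwise let the model sample as usual.
    name = kwargs.get("name")
    if name in model_kwargs:
      kwargs["value"] = model_kwargs[name]
    return ed.interceptable(f)(*args, **kwargs)
  return set_values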

Does any of this seem relevant? Unfortunately, there's not currently a nice way to sample from an Edward2 model in parallel over a batch of MCMC posterior samples, beyond using something like tf.map_fn (or crossing your fingers and hoping broadcasting just 'does the right thing' when you pass in a batch tensor). So you might get more speed from a custom approach, but 'set model variables with an interceptor and re-simulate from the model' is, IMHO, at least a nice, conceptually clean set of ideas to have in your toolbox.
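
As a rough illustration of the tf.map_fn route (hypothetical names: hyper_fg_mu_chain and hyper_fg_sigma_chain stand in for the posterior chains, and it assumes points was wrapped as ed.Deterministic(..., name='points') as suggested above):

def simulate_one(hypers):
  mu, sigma = hypers
  with ed.tape() as tape:
    with ed.interception(make_value_setter(hyper_fg_mu=mu, hyper_fg_sigma=sigma)):
      model()
  return tape['points'].value

# One posterior-predictive draw of points per MCMC sample.
points_draws = tf.map_fn(
    simulate_one,
    [hyper_fg_mu_chain, hyper_fg_sigma_chain],
    dtype=tf.float32)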

alexcordover commented 5 years ago

@davmre thanks for the response. This is exactly what I'm looking for. Additionally, the Eight Schools Jupyter notebook has some relevant information on the procedure, which was helpful.

At the risk of going too much into usage here, I just have one more question.

Let's say I perform inference on hyper_fg_mu and hyper_fg_sigma, and I have 2 Markov chains that represent samples from the posterior of the variables. Is there a recommended way of sampling directly from this posterior distribution for use in the interceptor?

with ed.tape() as posterior_predictive_sample:
  with ed.interception(make_value_setter(
    hyper_fg_mu=SOME_SAMPLE_FROM_THE_POSTERIOR,
    hyper_fg_sigma=SOME_SAMPLE_FROM_THE_POSTERIOR)):
      model()

The original Edward had the Empirical random variable, which (while I never used the original Edward) I assume would return values from the Markov chain when sampled. I could randomly choose a value from the posterior chain for the interceptor, but if Edward2 already has this functionality available, I'd like to use it. I did see that #183 discussed this, but I was curious whether there is an alternative in the meantime.

davmre commented 5 years ago

I could randomly choose a value from the posterior chain for the interceptor, but if Edward2 had the functionality already available, I'd like to use it.

I don't think we have a built-in solution at the moment. Note that you'd have to be a little bit careful using something like ed.Empirical (when we check it in), because you want to draw samples from the joint posterior over hyper_fg_mu and hyper_fg_sigma -- this means you'd need to share the same random index across both Tensors. You wouldn't get that automatically from wrapping them as separate Empiricals, though you could probably adapt code from the Empirical implementation.
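
A minimal sketch of the shared-index idea (TF1-style ops; the *_chain Tensors below are placeholders standing in for the [num_results]-shaped chains returned by tfp.mcmc.sample_chain):

import tensorflow as tf

# Placeholders standing in for the posterior chains from tfp.mcmc.sample_chain.
hyper_fg_mu_chain = tf.random_normal([1000], mean=85., stddev=10.)
hyper_fg_sigma_chain = tf.random_uniform([1000], minval=1., maxval=20.)

# Draw ONE random index and use it for both chains, so the pair (mu, sigma)
# comes from the joint posterior rather than from two independent marginals.
idx = tf.random_uniform([], maxval=tf.shape(hyper_fg_mu_chain)[0], dtype=tf.int32)
hyper_fg_mu_sample = tf.gather(hyper_fg_mu_chain, idx)
hyper_fg_sigma_sample = tf.gather(hyper_fg_sigma_chain, idx)

These samples can then be passed to make_value_setter as in the earlier snippet.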

jvdillon commented 5 years ago

As a stopgap measure, you could write the log-prob and sampling code separately. While this means you're essentially specifying the model twice, the upside is that it can sometimes be easier to see how things fit together.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

rv_hyper_fg_mu = tfd.Normal(loc=85., scale=10.)
rv_hyper_fg_sigma = tfd.Uniform(low=1., high=20.)
rv_percent = tfd.Uniform(low=0., high=1.)

# https://github.com/tensorflow/probability/blob/master/discussion/joint_log_prob.md

def joint_log_prob(
                   hyper_fg_mu,
                   hyper_fg_sigma,
                   field_goals,
                   attempt_threes_pct,
                   made_three_pct,
                   made_two_pct):
  rv_field_goals = tfd.Normal(loc=hyper_fg_mu, scale=hyper_fg_sigma)
  return sum([
      rv_hyper_fg_mu.log_prob(hyper_fg_mu),
      rv_hyper_fg_sigma.log_prob(hyper_fg_sigma),
      rv_field_goals.log_prob(field_goals),
      rv_percent.log_prob(attempt_threes_pct), # == 0
      rv_percent.log_prob(made_three_pct),     # == 0
      rv_percent.log_prob(made_two_pct),       # == 0
  ])

# Observed data (per-game values for the conditioned-on variables).
condition_on = [field_goals, attempt_threes_pct, made_three_pct, made_two_pct]
unnorm_log_posterior = lambda *hypers: joint_log_prob(*hypers, *condition_on)

# Do MCMC here. Don't forget to use TransformedTransitionKernel with Sigmoid
# bijectors for the pcts, Identity for hyper_fg_mu, and
# Chain([AffineScalar(shift=1., scale=19.), Sigmoid()]) for hyper_fg_sigma
# (which maps the reals onto its (1, 20) support).

def sample(hyper_fg_mu, hyper_fg_sigma):
  batch_shape = tf.shape(hyper_fg_mu)

  field_goals = tfd.Normal(loc=hyper_fg_mu, scale=hyper_fg_sigma).sample()

  attempt_threes_pct = rv_percent.sample(batch_shape)
  made_three_pct = rv_percent.sample(batch_shape)
  made_two_pct = rv_percent.sample(batch_shape)

  threes = field_goals * attempt_threes_pct
  twos = field_goals * (1. - attempt_threes_pct)
  three_pts = 3. * threes * made_three_pct
  two_pts = 2. * twos * made_two_pct

  return [
      [three_pts, two_pts],  # Deterministic.
      [field_goals, attempt_threes_pct, made_three_pct, made_two_pct], # Random
  ]
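
For what it's worth, a possible usage sketch once the chains are in hand (hypothetical *_chain names): everything in sample broadcasts over the leading chain dimension, so feeding the MCMC draws straight in yields one posterior-predictive draw per sample.

# hyper_fg_mu_chain, hyper_fg_sigma_chain: [num_results]-shaped chains from sample_chain.
deterministic_parts, random_parts = sample(hyper_fg_mu_chain, hyper_fg_sigma_chain)
three_pts_draws, two_pts_draws = deterministic_parts
points_draws = three_pts_draws + two_pts_draws  # [num_results] posterior-predictive points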

srvasude commented 4 years ago

Closing this as Edward2 has moved to https://github.com/google/edward2.

If you're using one of the JointDistribution flavors, then I believe you can pass sample_chain whatever trace_fn you want in order to trace other quantities. It'll be a more condensed version of Josh's answer above.
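
For posterity, a small sketch of that suggestion with the newer API (the toy model and numbers here are made up; the point is just that trace_fn receives the chain state and can return arbitrary derived quantities, which sample_chain stacks and returns alongside the states):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy joint model: a latent z and an observation x | z.
model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),            # z
    lambda z: tfd.Normal(loc=z, scale=1.),   # x | z
])

observed_x = tf.constant(1.5)

def target_log_prob(z):
  return model.log_prob([z, observed_x])

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob,
    step_size=0.1,
    num_leapfrog_steps=3)

# trace_fn computes a derived quantity from the current state at every step,
# playing the role of pm.Deterministic / ed.Deterministic in the discussion above.
samples, derived = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=100,
    current_state=0.,
    kernel=kernel,
    trace_fn=lambda state, kernel_results: 2. * state + 1.)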