tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Correlation matrix [cholesky] bijector #400

Closed: brianwa84 closed this issue 5 years ago.

brianwa84 commented 5 years ago

As discussed on https://groups.google.com/a/tensorflow.org/d/msg/tfprobability/JYNa3_g33qo/asqjrRs0BAAJ, it would be nice to have a bijector that maps unconstrained vectors to LKJ-distributed correlation matrices. Some guessing: our existing LKJ might already implement the right forward transformation, so we would need to add the inverse and the log-determinant of the Jacobian.

bloops commented 5 years ago

I have some ideas on how to achieve this using a UnitNorm and TransformAxis bijector, if this is still up for grabs?

skeydan commented 5 years ago

@bloops I'm happy if you do it, so I can write about it earlier :-) Thanks!

skeydan commented 5 years ago

Hi @bloops can you please let me know as soon as there's anything I can test (even if it's still a bit of work in progress)? I'm very curious to test with my dataset :-)

bloops commented 5 years ago

Sure, I will keep you updated. Planning to take a stab at this this weekend. Thanks!

(Sent by email in reply to Sigrid Keydana's comment of Fri, May 10, 2019, 5:53 AM.)

skeydan commented 5 years ago

great, thanks :-)

bloops commented 5 years ago

Hi @skeydan, I have a prototype for this now. This might have bugs and the performance has a lot of room for improvement. So far, some initial testing looks promising. :)

The LKJ bijector is a combination of two bijectors. The first one maps k unconstrained reals to a unit vector in R^{k+1} whose last coordinate is strictly positive. I called it the HalfSphere bijector. The second one, TransformRowTriL, takes n * (n - 1) / 2 unconstrained reals, maps them to the strictly lower triangular part of an n x n matrix, and then applies the HalfSphere bijector to each row (producing that row's entries up to and including the diagonal). The image of the combined bijector is exactly the set of Cholesky factors of n x n correlation matrices. This image is also the support of the LKJ distribution with input_output_cholesky=True.
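The index bookkeeping behind TransformRowTriL can be sketched in plain Python. This is an illustrative sketch, not the bijector's code, and the helper names are hypothetical. It solves for the full matrix size N from m = N * (N - 1) / 2. By contrast, `_get_triangular_n` in the code below solves m = n(n+1)/2 for n = N - 1.

```python
import math


def num_free_entries(n):
    """Strictly-lower-triangular entries of an n x n matrix: n * (n - 1) / 2."""
    return n * (n - 1) // 2


def matrix_size(m):
    """Solve m = n * (n - 1) / 2 for the matrix size n (raise if m is not triangular)."""
    # Positive root of n^2 - n - 2m = 0.
    n = int(round(0.5 + math.sqrt(0.25 + 2.0 * m)))
    if num_free_entries(n) != m:
        raise ValueError("%d unconstrained values do not fill a triangle" % m)
    return n


def row_slices(n):
    """[start, end) of the unconstrained vector consumed by rows 1 .. n-1.

    Row 0 is always [1, 0, ..., 0]. Row k takes k values, and the rows before
    it have used 0 + 1 + ... + (k - 1) = k * (k - 1) // 2 values.
    """
    return [(k * (k - 1) // 2, k * (k + 1) // 2) for k in range(1, n)]


print(matrix_size(3), row_slices(3))
print(matrix_size(6), row_slices(4))
```

Using integer division throughout keeps the slice bounds integral, which matters once the same arithmetic is done on int32 tensors.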

New Bijectors

HalfSphere bijector

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import tensorshape_util

class HalfSphere(tfb.Bijector):
  def __init__(self, validate_args=False, name="halfsphere"):
    super(HalfSphere, self).__init__(
        validate_args=validate_args, 
        forward_min_event_ndims=1, 
        name=name)

  def _forward_event_shape(self, input_shape):
    if not input_shape[-1:].is_fully_defined():
      return input_shape
    return input_shape[:-1].concatenate(input_shape[-1] + 1)

  def _forward_event_shape_tensor(self, input_shape):
    return tf.concat([input_shape[:-1], [input_shape[-1] + 1]], axis=0)

  def _inverse_event_shape(self, output_shape):
    if not output_shape[-1:].is_fully_defined():
      return output_shape
    if output_shape[-1] <= 1:
      raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1])
    return output_shape[:-1].concatenate(output_shape[-1] - 1)

  def _inverse_event_shape_tensor(self, output_shape):
    if self.validate_args:
      # It is not possible for a negative shape so we need only check <= 1.
      is_greater_one = assert_util.assert_greater(
          output_shape[-1], 1, message="Need last dimension greater than 1.")
      output_shape = distribution_util.with_dependencies(
          [is_greater_one], output_shape)
    return tf.concat([output_shape[:-1], [output_shape[-1] - 1]], axis=0)

  def _forward(self, x):
    # Pad 1 to the back and divide by the norm. This transformation corresponds
    # to a bijection between R^{n} and the positive half of the n-sphere in
    # R^{n+1}; particularly, the subset of the n-sphere (x_1, x_2, ... x_n, x_{n+1})
    # satisfying x_{n+1} > 0.
    y = distribution_util.pad(x, value=1., axis=-1, back=True)

    # Set shape hints.
    if tensorshape_util.rank(x.shape) is not None:
      last_dim = tf.compat.dimension_value(x.shape[-1])
      shape = tensorshape_util.concatenate(
          x.shape[:-1],
          None if last_dim is None else last_dim + 1)
      tensorshape_util.set_shape(y, shape)

    return y / tf.norm(y, axis=-1)[..., tf.newaxis]

  def _inverse(self, y):
    # The last coordinate is the reciprocal of the norm. Just multiply by the 
    # norm and truncate to first n coordinates to recover the inverse.
    return y[..., :-1] / y[..., -1:]

  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))

  def _forward_log_det_jacobian(self, x):
    # The map defines an embedding from R^{n} to an n-dimensional manifold in
    # R^{n + 1}. The change in volume is no longer given by det(J), since it
    # is not a diffeomorphism between open subsets of R^{n}.
    # The change in n-dimensional volume (or surface area) is instead given by
    # sqrt(|det J^T J|), where J is the (n + 1) x n Jacobian matrix.
    # 
    # To aid calculation, observe that by considerations of symmetry, the volume
    # element can only depend on norm(x); w.l.o.g. we can assume that x_1 = r
    # and x_2, ..., x_n = 0.
    # 
    # https://en.wikipedia.org/wiki/Volume_element#Volume_element_of_manifolds

    n = tf.cast(distribution_util.prefer_static_shape(x), dtype=x.dtype)[-1]
    r_squared = tf.reduce_sum(tf.square(x), axis=-1)
    return -0.5 * (n + 1) * tf.math.log1p(r_squared)
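Here is a quick numerical sanity check of the closed-form log-det-Jacobian. It is written in NumPy rather than TensorFlow, so it runs without TFP. The check builds the analytic (n + 1) x n Jacobian of the HalfSphere forward map and compares log sqrt(det(J^T J)) with -0.5 * (n + 1) * log1p(|x|^2). The function names are hypothetical re-implementations of the bijector's math, not TFP APIs.

```python
import numpy as np


def half_sphere_forward(x):
    # Append 1 and normalize: R^n -> upper half of the unit n-sphere in R^(n+1).
    y = np.append(x, 1.0)
    return y / np.linalg.norm(y)


def half_sphere_inverse(y):
    # Divide by the (positive) last coordinate and drop it.
    return y[:-1] / y[-1]


def half_sphere_jacobian(x):
    # Analytic (n + 1) x n Jacobian of half_sphere_forward, with s = sqrt(1 + |x|^2):
    # d(x_i / s)/dx_j = delta_ij / s - x_i x_j / s^3 and d(1 / s)/dx_j = -x_j / s^3.
    s = np.sqrt(1.0 + x @ x)
    top = np.eye(x.shape[0]) / s - np.outer(x, x) / s**3
    bottom = -x[np.newaxis, :] / s**3
    return np.vstack([top, bottom])


def fldj_closed_form(x):
    # Same formula as HalfSphere._forward_log_det_jacobian.
    n = x.shape[0]
    return -0.5 * (n + 1) * np.log1p(x @ x)


def fldj_from_jacobian(x):
    # log sqrt(det(J^T J)): the n-dimensional volume change of the embedding.
    J = half_sphere_jacobian(x)
    _, logdet = np.linalg.slogdet(J.T @ J)
    return 0.5 * logdet


x = np.random.default_rng(0).normal(size=4)
y = half_sphere_forward(x)
print(np.linalg.norm(y), y[-1] > 0)
print(np.allclose(half_sphere_inverse(y), x))
print(np.isclose(fldj_closed_form(x), fldj_from_jacobian(x)))
```

The agreement is expected. J^T J has eigenvalue 1/s^4 along x and 1/s^2 in the n - 1 orthogonal directions. Its determinant is therefore s^(-2(n+1)), and its square root is (1 + r^2)^(-(n+1)/2).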

TransformRowTriL

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import tensorshape_util

def _make_shape(batch_shape, event_dim):
  return tf.concat([batch_shape, [event_dim]], axis=-1)

def _get_last_dimension(x):
  x = tf.convert_to_tensor(value=x, name="x")
  m = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
  if m is not None:
    return m
  return tf.shape(input=x)[-1]

class TransformRowTriL(tfb.Bijector):
  def __init__(self, row_bijector, validate_args=False, name="transform_row"):
    super(TransformRowTriL, self).__init__(
        validate_args=validate_args, 
        forward_min_event_ndims=1,
        inverse_min_event_ndims=2,
        name=name)
    self._row_bijector = row_bijector

  def _get_triangular_n(self, x):
    x = tf.convert_to_tensor(value=x, name="x")
    m = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
    if m is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(m)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = tf.shape(input=x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so we
      # omit it.  We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = tf.cast(
          tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)), dtype=tf.int32)
      static_final_shape = tensorshape_util.with_rank_at_least(
          x.shape, 1)[:-1].concatenate([None, None])
    return n

  def _forward(self, x):
    x = tf.convert_to_tensor(value=x, name="x")
    # Find n such that the last dimension of x equals n(n+1)/2.
    n = self._get_triangular_n(x)

    batch_shape = tf.shape(x)[:-1]
    rows = []

    # Map the first row. This is just [1, 0,... 0].
    tril_part = tf.ones(_make_shape(batch_shape, 1), dtype=x.dtype)
    triu_part = tf.zeros(_make_shape(batch_shape, n), dtype=x.dtype)
    rows.append(tf.concat([tril_part, triu_part], axis=-1))

    # Map the rest of the rows.
    # TODO convert to tf.while_loop.
    for k in range(1, n + 1):
      # Take the next k elements, run it through the row bijector to create the
      # lower triangular (incl. diagonal) part of the kth row. The rest of the
      # row's elements are filled with zeros.
      start = tf.cast((k * (k - 1)) / 2, tf.int32)
      end = tf.cast((k * (k + 1)) / 2, tf.int32)

      tril_part = self._row_bijector.forward(x[..., start:end])
      triu_part = tf.zeros(_make_shape(batch_shape, n - k), dtype=x.dtype)
      row = tf.concat([tril_part, triu_part], axis=-1)
      rows.append(row)

    return tf.stack(rows, axis=-2)

  def _inverse(self, y):
    n = _get_last_dimension(y)

    x_tril_parts = []
    # TODO convert to tf.while_loop; maybe using TensorArray.
    # NOTE: TensorArray doesn't support concatenating along axis=-1.
    # So we would have to transpose the elements before writing it.

    # Ignore the first row.
    for k in range(1, n):
      tril_part = self._row_bijector.inverse(y[..., k, :k + 1])
      x_tril_parts.append(tril_part)

    return tf.concat(x_tril_parts, axis=-1)

  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))

  def _forward_log_det_jacobian(self, x):
    # Just sum all the constituent fldjs for each row.
    n = self._get_triangular_n(x) + 1
    batch_shape = tf.shape(x)[:-1]

    def add_row_fldj(i, current_sum):
      # Integer division keeps the slice bounds int32; `/` would yield floats.
      start = i * (i - 1) // 2
      end = i * (i + 1) // 2
      fldj = self._row_bijector.forward_log_det_jacobian(
          x[..., start:end], event_ndims=1)
      return i + 1, current_sum + fldj

    _, fldj_sum = tf.while_loop(
      cond=lambda i, *args: i < n,
      body=add_row_fldj,
      loop_vars=[
          tf.ones([], tf.int32, name='iter'),
          tf.zeros(batch_shape, dtype=x.dtype)
      ])

    return fldj_sum
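For readers who want to experiment without TensorFlow, here is a NumPy sketch of the same composition. The helpers are hypothetical and this is not the TFP implementation. It covers three pieces:

- the forward map, with each row built by HalfSphere;
- the inverse, which reads each row's free entries back off by dividing by the diagonal;
- the total log-det-Jacobian, computed as the sum of the per-row HalfSphere terms.

Here n is the matrix size, so row k consumes k unconstrained entries.

```python
import numpy as np


def corr_cholesky_forward(x, n):
    """Map n * (n - 1) // 2 unconstrained reals to an n x n correlation Cholesky factor."""
    L = np.zeros((n, n))
    L[0, 0] = 1.0  # First row is always [1, 0, ..., 0].
    for k in range(1, n):
        start, end = k * (k - 1) // 2, k * (k + 1) // 2  # integer slice bounds
        row = np.append(x[start:end], 1.0)                # HalfSphere: pad with 1 ...
        L[k, :k + 1] = row / np.linalg.norm(row)          # ... and normalize.
    return L


def corr_cholesky_inverse(L):
    """Recover the unconstrained vector row by row (row 0 carries no information)."""
    n = L.shape[0]
    return np.concatenate([L[k, :k] / L[k, k] for k in range(1, n)])


def corr_cholesky_fldj(x, n):
    """Sum of per-row HalfSphere log-det terms; row k has k free entries."""
    total = 0.0
    for k in range(1, n):
        start, end = k * (k - 1) // 2, k * (k + 1) // 2
        r2 = x[start:end] @ x[start:end]
        total += -0.5 * (k + 1) * np.log1p(r2)
    return total


x = np.random.default_rng(1).normal(size=6)  # 6 = 4 * 3 // 2 free entries
L = corr_cholesky_forward(x, 4)
R = L @ L.T                                  # a 4 x 4 correlation matrix
print(np.diag(R))
print(np.allclose(corr_cholesky_inverse(L), x))
print(corr_cholesky_fldj(x, 4))
```

Each row depends only on its own slice of x, so the per-row log-det terms simply add up. The `_forward_log_det_jacobian` loop above accumulates the same sum.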

Putting it all together

Generate fake data with a known correlation matrix.

Centered at mu=0 for simplicity.

corr3d = np.array([[ 1.00, -0.46,  0.18],
                   [-0.46,  1.00,  0.61],
                   [ 0.18,  0.61,  1.00]]).astype(np.float32)
scale_tril = tf.linalg.cholesky(corr3d)
dist = tfd.MultivariateNormalTriL(loc=0., scale_tril=scale_tril)

num_samples = 10000
data3d = dist.sample(num_samples)

Define our simple model's log posterior using JointDistributionSequential

def log_prob(sigma):
  """Computes `joint_log_prob` pinned at `data3d`."""

  lkj_and_mvn = tfd.JointDistributionSequential([
      tfd.LKJ(dimension=3, concentration=1.1, input_output_cholesky=True),
      lambda s: tfd.MultivariateNormalTriL(loc=0., scale_tril=s),
  ])

  # reshape to [num_observed, 1, num_dims] so that we can transparently handle
  # multiple chains.
  data_reshaped = data3d[:, tf.newaxis, :]

  # log_prob will have [num_observed, num_chains]. Sum across all observed
  # data points to output one scalar log probability per chain.
  return tf.reduce_sum(lkj_and_mvn.log_prob([sigma, data_reshaped]), axis=0)

Set up MCMC kernel with the new LKJ bijector.

@tf.function
def sample(num_chains, num_results, num_burnin_steps):
  """Samples from the model."""
  hmc = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=log_prob,
      num_leapfrog_steps=10,
      step_size=0.005)

  initial_state = [
      # Start with fully uncorrelated components. This corresponds to the
      # correlation matrix being the identity, whose Cholesky factor is also the
      # identity matrix.
      tf.eye(3, batch_shape=[num_chains], name='init_sigma')
  ]

  # Constrain `sigma` to the Cholesky space of correlation matrices.
  constraining_bijectors = [
      TransformRowTriL(row_bijector=HalfSphere()), # sigma
  ]
  kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=hmc, bijector=constraining_bijectors)

  samples, kernel_results = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=initial_state,
      kernel=kernel)

  acceptance_probs = tf.reduce_mean(
      tf.cast(kernel_results.inner_results.is_accepted, tf.float32), axis=0)

  return samples, acceptance_probs

Run the chain and inspect results.

samples, acceptance_probs = sample(num_chains=10, num_results=1000, num_burnin_steps=100)
# samples[0] holds the sigma draws, shape [num_results, num_chains, 3, 3].
sigma_samples = samples[0]
estimated_scale_tril = tf.reduce_mean(sigma_samples, axis=[0, 1])

print('Actual Value: ', scale_tril)
print('Estimated Value: ', estimated_scale_tril)

I got these results:

('Actual Value: ', <tf.Tensor: id=1435, shape=(3, 3), dtype=float32, numpy=
array([[ 1.        ,  0.        ,  0.        ],
       [-0.46000001,  0.88791889,  0.        ],
       [ 0.18000001,  0.7802515 ,  0.59900546]], dtype=float32)>)
('Estimated Value: ', <tf.Tensor: id=6046, shape=(3, 3), dtype=float32, numpy=
array([[ 1.        ,  0.        ,  0.        ],
       [-0.43991616,  0.89801389,  0.        ],
       [ 0.18703762,  0.73523128,  0.65143347]], dtype=float32)>)
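For reference, the "Actual Value" above is simply the Cholesky factor of corr3d. Because corr3d has a unit diagonal, each row of its Cholesky factor has unit norm and a positive diagonal entry. That places it in the image of TransformRowTriL(HalfSphere()). A small NumPy check:

```python
import numpy as np

corr3d = np.array([[ 1.00, -0.46,  0.18],
                   [-0.46,  1.00,  0.61],
                   [ 0.18,  0.61,  1.00]])
L = np.linalg.cholesky(corr3d)  # lower-triangular, with L @ L.T == corr3d

print(np.round(L, 4))
print(np.linalg.norm(L, axis=1))  # every row has unit norm
print(np.diag(L) > 0)             # diagonal strictly positive
```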
skeydan commented 5 years ago

Hi Anudhyan,

thanks so much for the quick implementation! Unfortunately I have problems testing this locally... When I execute your code it runs fine, but with my own code (appended below) I get NameError: name 'np' is not defined:

Error in py_call_impl(callable, dots$args, dots$keywords) :
  RuntimeError: Evaluation error: NameError: name 'np' is not defined

Detailed traceback:
  File "/home/key/anaconda3/envs/r-tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py", line 326, in sample_chain
    previous_kernel_results = kernel.bootstrap_results(current_state)
  File "/home/key/anaconda3/envs/r-tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py", line 346, in bootstrap_results
    transformed_init_state))
  File "/home/key/anaconda3/envs/r-tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/hmc.py", line 558, in bootstrap_results
    kernel_results = self._impl.bootstrap_results(init_state)
  File "/home/key/anaconda3/envs/r-tensorflow/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/metropolis_hastings.py", line 265, in bootstrap_results
    pkr = self.inner_kernel.bootstrap_results(init_state)
  File "/home/key/anaconda3/envs/r-tensorflow/lib/python

I assume I need to integrate the new files into an actual build instead of sourcing them locally, could that be?

(((Unfortunately I also need to figure out how to make my local TF build working again, as it seems like current TF does not work with bazel 0.25 and the only way to get earlier bazel builds seems to be building from source if I'm not mistaken)))

Anyway, thanks again!

Here is the code that throws that error (on actual sampling):

import tensorflow as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

n_cafes = 20
# r.d is the R data frame `d`, accessed from Python via reticulate.
cafe_id = tf.constant(r.d["cafe"], dtype=tf.int64) - 1
afternoon = tf.constant(r.d["afternoon"], dtype=tf.float32)
wait = tf.constant(r.d["wait"])

def model(cafe_id):
  return tfd.JointDistributionSequential(
    [
      # chol(rho), the prior correlation matrix between intercepts and slopes
      tfd.LKJ(2, 2, input_output_cholesky=True),
      # sigma, prior for the waiting time
      tfd.Sample(tfd.Exponential(rate = 1), sample_shape=1),
      # sigma_cafe, prior of variances for intercepts and slopes (vector of 2)
      tfd.Sample(tfd.Exponential(rate = 1), sample_shape=2),
      # b, the prior for the slopes
      tfd.Sample(tfd.Normal(loc = -1, scale = 0.5), sample_shape=1),
      # a, the prior for the intercepts
      tfd.Sample(tfd.Normal(loc = 5, scale = 2), sample_shape=1),
      # mvn, multivariate distribution of intercepts and slopes
      # shape: batch size, 20, 2
      lambda a,b,sigma_cafe,sigma,chol_rho:
        tfd.Sample(
            tfd.MultivariateNormalTriL(
                loc = tf.concat([a,b], axis = -1),
                scale_tril=tf.linalg.LinearOperatorDiag(sigma_cafe).matmul(chol_rho)),
            sample_shape=n_cafes),
      # waiting time
      # shape should be batch size, 200
      lambda mvn, a, b, sigma_cafe, sigma: tfd.Independent(
        # need to pull out the correct cafe_id in the middle column
        tfd.Normal(
          loc=(tf.gather(mvn[..., 0], cafe_id, axis=-1) +
               tf.gather(mvn[..., 1], cafe_id, axis=-1) * afternoon), # Shape [batch, 200,]
          scale=sigma),  # Shape [batch,  1]
        reinterpreted_batch_ndims=1
      )
    ]
  )

m = model(cafe_id)
ds, xs = m.sample_distributions(3)

for d in ds:
  print(d)

s = m.sample(3)
m.log_prob(s)

# mcmc
constraining_bijectors = [
      TransformRowTriL(row_bijector=HalfSphere()), # sigma
      tfb.Identity(),
      tfb.Identity(),
      tfb.Identity(),
      tfb.Identity(),
      tfb.Identity(),
  ]

logp = lambda rho, sigma, sigma_cafe, b, a, mvn: m.log_prob([rho, sigma, sigma_cafe, b, a, mvn, wait])

number_of_steps = 500
burnin = 500
nchain = 4

rho, sigma, sigma_cafe, b, a, mvn, _ = m.sample(nchain)

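def trace_fn(_, pkr):
  # Results are nested: SimpleStepSizeAdaptation -> TransformedTransitionKernel -> HMC.
  return (pkr.inner_results.inner_results.is_accepted,
          pkr.inner_results.inner_results.accepted_results.step_size)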

hmc=tfp.mcmc.HamiltonianMonteCarlo(
  target_log_prob_fn=logp,
  num_leapfrog_steps=3,
  step_size=.1)

kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=hmc, bijector=constraining_bijectors)

kernel=tfp.mcmc.SimpleStepSizeAdaptation(
  kernel,
  target_accept_prob=.8,
  num_adaptation_steps=burnin
)

@tf.function
def run_mcmc():
  mcmc_trace, (is_accepted, step_size) = tfp.mcmc.sample_chain(
    num_results = number_of_steps,
    num_burnin_steps = burnin,
    current_state=[rho, sigma, sigma_cafe, b, a, mvn],
    kernel=kernel,
    trace_fn=trace_fn)
  return mcmc_trace, (is_accepted, step_size)

mcmc_trace, (is_accepted, step_size)=run_mcmc()
bloops commented 5 years ago

Ah, I think I forgot to import numpy in the original code. Did you already try adding this import: import numpy as np right above import tensorflow as tf?

skeydan commented 5 years ago

Yeah, I did...

bloops commented 5 years ago

Hi @skeydan I have modified your code and added on top of your earlier colab to include the new bijector. (Also, I've fixed a dtype bug in the new bijector.)

https://colab.sandbox.google.com/drive/1BTnfyl_XtDHW-GpFawWOsQw4O3Gukukx

It runs the chain now, although I haven't checked whether it mixes well.

There were (probably unrelated) problems with getting trace_fn to work correctly so I removed that part in the colab for now.

skeydan commented 5 years ago

Great, thanks!! Can't test now, but I'll check tomorrow!

skeydan commented 5 years ago

Hi @bloops thanks again. I found a way to test the bijectors locally with my original code (jumping from R to Python to R), and I'll describe the results below.

I've also inserted the actual data into the colab, if you'd like to test for yourself. These are simulated data, and the true negative correlation is -0.7. The code in the colab recovers a correlation of -0.5, which I think is not close enough yet, given how clean the data is :-)

(((As an aside, the trace_fn needed an additional inner_results because of the additional TransformedTransitionKernel; I must have forgotten to change that in the Python version I uploaded)))

As to the sampling behavior, I see the following now. The upper row, columns 1 to 4, has the components of rho; then come the 3 sigmas, and then the partially pooled mvn parameters.

image

Comparing this to how it looked without the bijectors

image

it's visible that the rhos look better, but the chains downstream in the model now mix less well.

Would you have an idea why this could be?

skeydan commented 5 years ago

Oops I may be comparing apples with oranges here. The non-bijector version used 1000 steps and burnin steps each (as chains were only starting to converge after ~ 400 steps overall), while the bijector one uses 500. I'll re-test using 1000 now.

skeydan commented 5 years ago

Hm, so with 1000 steps the chains look better (see below), but I still get a mean correlation of ~ -0.5.

image

Also, the posterior distribution looks too certain now (the ellipse used to be a lot wider):

image

Just speculating, but could it be that the bijectors somehow remove too many degrees of freedom?

skeydan commented 5 years ago

Me again. I see that Stan too arrives at a reconstructed correlation of -0.5 - so this should be okay :-)

Now comparing 500 steps in Stan against 500 steps with TFP, I see the following pooling behavior:

First, these are the empirical estimates and shrunk parameters for intercept and slope, which look a bit different between TFP and Stan. TFP:

image

Stan image

But on the final outcome scale the results look very similar:

TFP image

Stan image

Overall I'd say this runs well, although I'd like to experiment more with the number of steps and with initial step sizes.

Do you think the bijectors could be added to master? (My current workflow is a bit awkward with the switching between files/languages.)

bloops commented 5 years ago

Hi Sigrid, that's awesome progress and I'm glad you got it working locally! It's exciting to see this come together :)

> I've also inserted the actual data into the colab if you'd like to test for yourself?

Sure. Where is the colab that you mentioned with the real data? Is it edited in the same link I sent?

Regarding why the reconstructed correlation is -0.5 while the real correlation is -0.7: could this be because your LKJ concentration prior is too strong? If you set the concentration to 1.0 (or, say, 1.1), it should act as a weaker regularizer, and the reconstruction should be closer to the real one.
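To get a feel for how strong the concentration prior is, note that the LKJ density is proportional to det(R)^(concentration - 1). For a 2 x 2 correlation matrix, det(R) = 1 - rho^2. The sketch below compares the unnormalized prior density at rho = -0.7 with that at rho = -0.5 for a few concentrations. It only describes the prior, not how far the posterior moves given the data.

```python
def lkj_2x2_unnormalized(rho, concentration):
    # LKJ density is proportional to det(R)**(concentration - 1);
    # for R = [[1, rho], [rho, 1]], det(R) = 1 - rho**2.
    return (1.0 - rho ** 2) ** (concentration - 1.0)


def prior_ratio(concentration, rho_a=-0.7, rho_b=-0.5):
    """Prior density at rho_a relative to rho_b (normalizing constants cancel)."""
    return (lkj_2x2_unnormalized(rho_a, concentration)
            / lkj_2x2_unnormalized(rho_b, concentration))


for eta in (1.0, 1.1, 2.0):
    print(eta, round(prior_ratio(eta), 3))
```

With concentration 1 the prior is uniform over rho. With concentration 2, rho = -0.7 gets about 0.68 times the prior density of rho = -0.5, which pulls estimates somewhat toward zero.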

I didn't understand what you mean by empirical estimates and shrunk parameters of the slope and intercept. Did you mean the posterior estimates of the a and b variables using the samples from the chain?

One reason for the discrepancy between Stan and TFP could be that, at 500 steps, the TFP chains don't look like they're mixing well, especially for the correlation matrix (rho) parameter. Can you try Stan and TFP with 1000 steps? Increasing num_leapfrog_steps (to 10 or so) in the TFP HMC kernel should help with the mixing too.

It is also possible that the new LKJ bijector is not as good as Stan's, and that this is causing the TFP version to perform worse.

Yes, I think I can get this checked into master, pending code reviews etc. It should take a week or so.

skeydan commented 5 years ago

Hi Anudhyan,

thanks, that sounds great! I can perform further tests on Wednesday.

> Regarding why the reconstructed correlation is -0.5 while the real correlation is -0.7. Could this be because your LKJ concentration prior is too strong?

I took that from the Stan code this is based on, and Stan too arrives at ~ -0.5... So in that regard the behavior is similar, although I need to do some more testing (for example, on one run the posterior correlation ended up at ~ -0.35, which again is different...)

> One reason for the discrepancy between Stan and TFP could be that, for 500 steps, the TFP chains don't look like they're mixing well. Especially the correlation matrix (rho) parameter. Can you try Stan and TFP with 1000 steps? Increasing the num_leapfrog_steps (to 10 or so) in the TFP HMC kernel should help with the mixing too.

Yes, I have that impression too (((somehow it's a bit ambiguous now, partly the results look similar to those in Stan and partly they don't ...))) ... need to do more systematic tests and will try your suggestions...

> Where is the colab that you mentioned with the real data? Is it edited in the same link I sent?

Yes, I inserted the data in the colab you modified last (originally created by Chris).

bloops commented 5 years ago

I don't think I can see your edits in the colab. Here is the link to the colab I was using. Perhaps you can copy it and share the link to your edited version?

skeydan commented 5 years ago

Oh, sorry, got it... I saved it to some sandbox repo and added the simulated data there:

https://github.com/skeydan/sb/blob/master/jds/cafe_wait_times_with_LKJ_bijector.ipynb

(((everything else is unchanged there right now, I made the edits in a text editor)))

I also wanted to add the Stan model I'm comparing to here:

data{
  vector[200] wait;
  int afternoon[200];
  int cafe[200];
}
parameters{
  vector[20] b_cafe;
  vector[20] a_cafe;
  real a;
  real b;
  vector<lower=0>[2] sigma_cafe;
  real<lower=0> sigma;
  corr_matrix[2] Rho;
}
model{
  vector[200] mu;
  Rho ~ lkj_corr( 2 );
  sigma ~ exponential( 1 );
  sigma_cafe ~ exponential( 1 );
  b ~ normal( -1 , 0.5 );
  a ~ normal( 5 , 2 );
  {
    vector[2] YY[20];
    vector[2] MU;
    MU = [ a , b ]';
    for ( j in 1:20 ) YY[j] = [ a_cafe[j] , b_cafe[j] ]';
    YY ~ multi_normal( MU , quad_form_diag(Rho , sigma_cafe) );
  }
  for ( i in 1:200 ) {
    mu[i] = a_cafe[cafe[i]] + b_cafe[cafe[i]] * afternoon[i];
  }
  wait ~ normal( mu , sigma );
}
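The Stan model's `quad_form_diag(Rho, sigma_cafe)` and the TFP model's `scale_tril=tf.linalg.LinearOperatorDiag(sigma_cafe).matmul(chol_rho)` describe the same covariance. With D = diag(sigma_cafe), (D L)(D L)^T = D L L^T D = D Rho D. Below is a small NumPy check. The numbers are illustrative only, roughly the Stan posterior means reported further down.

```python
import numpy as np

rho = -0.5
Rho = np.array([[1.0, rho], [rho, 1.0]])
sigma_cafe = np.array([0.96, 0.59])

# Stan: quad_form_diag(Rho, sigma_cafe) == diag(sigma_cafe) * Rho * diag(sigma_cafe)
cov_stan = np.diag(sigma_cafe) @ Rho @ np.diag(sigma_cafe)

# TFP model: scale_tril = diag(sigma_cafe) @ chol(Rho); covariance = scale_tril @ scale_tril.T
scale_tril = np.diag(sigma_cafe) @ np.linalg.cholesky(Rho)
cov_tfp = scale_tril @ scale_tril.T

print(np.allclose(cov_stan, cov_tfp))
```

Scaling rows of a lower-triangular matrix by a positive diagonal keeps it lower triangular, so `scale_tril` remains a valid Cholesky-style scale.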
skeydan commented 5 years ago

For a more systematic comparison to the results I'm trying to reproduce, here first are the most relevant characteristics of the above Stan model on the data. The Stan model has 1000 steps overall (burnin + "real").

Posterior summary statistics (excerpt)

The rhos here are reported in terms of correlation (although internally Stan works with the cholesky).

# a              3.65 0.22  3.31  4.02  2703    1
# b             -1.13 0.14 -1.36 -0.91  2359    1
# sigma_cafe[1]  0.96 0.16  0.74  1.23  2110    1
# sigma_cafe[2]  0.59 0.12  0.42  0.80  1786    1
# sigma          0.47 0.03  0.43  0.51  2103    1
# Rho[1,1]       1.00 0.00  1.00  1.00   NaN  NaN
# Rho[1,2]      -0.50 0.18 -0.75 -0.19  2454    1
# Rho[2,1]      -0.50 0.18 -0.75 -0.19  2454    1
# Rho[2,2]       1.00 0.00  1.00  1.00  2062    1

Trace plots (excerpt)

I think the chains mix very well already during the burnin phase (grey). Also the effective sample sizes are very high (they are reported as sums though).

image image

Shrinkage on parameter and outcome scales

image

image

skeydan commented 5 years ago

Here are the corresponding (in spirit) plots from TFP.

Run 1: 500 burnin, 500 real, leapfrog_steps 3, step_size 0.1 for all, LKJ parameter == 2

Posterior summary statistics (excerpt)

Rhos here are in cholesky terms. Mean rho as correlation is: -0.2463341.

   key            mean     sd  lower  upper    ess   rhat
   <chr>         <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
 1 rho_1         1     0       1      1     NaN    NaN   
 2 rho_2         0     0       0      0     NaN    NaN   
 3 rho_3        -0.294 0.359  -0.723  0.693   6.58   1.53
 4 rho_4         0.880 0.0998  0.699  1.000   9.65   1.33
 5 sigma         0.507 0.103   0.401  0.631 255.     1.35
 6 sigma_cafe_1  0.884 0.196   0.522  1.32   12.1    1.34
 7 sigma_cafe_2  0.496 0.172   0.215  0.809  11.2    2.43
 8 b            -1.12  0.123  -1.33  -0.843  20.1    1.23
 9 a             3.64  0.230   3.14   4.05   11.6    1.27
10 a_cafe_1      4.20  0.203   3.75   4.55   12.1    1.25

Traceplots (in a different order: starting with the rhos, then sigma, sigma_cafe [1:2], b, a)

These are from the "real phase" only (not showing the burnin phase)

image

Shrinkage plots

The second looks fine I think, and shrinkage is visible in both, but the negative correlation between intercept and slope (a and b) could perhaps be a bit stronger in plot 2.

image

image

skeydan commented 5 years ago

Hey, wow!!! I think I can skip the planned experiments!! :-) What I did now was try again to add the Exp bijectors for the sigmas, which had thrown errors in a prior version (I don't recall exactly; probably that was without your bijectors). So now this is what I have :-)

Look:

Mean rho is -0.4283794.

   key            mean     sd  lower  upper   ess   rhat
   <chr>         <dbl>  <dbl>  <dbl>  <dbl> <dbl>  <dbl>
 1 rho_1         1     0       1      1     NaN   NaN   
 2 rho_2         0     0       0      0     NaN   NaN   
 3 rho_3        -0.551 0.172  -0.818 -0.144  37.9   1.04
 4 rho_4         0.809 0.109   0.628  1.000  31.3   1.04
 5 sigma         0.474 0.0268  0.426  0.529 440.    1.01
 6 sigma_cafe_1  0.970 0.170   0.669  1.32   53.6   1.02
 7 sigma_cafe_2  0.592 0.123   0.364  0.813  60.1   1.05
 8 b            -1.14  0.138  -1.39  -0.856  99.6   1.00
 9 a             3.66  0.212   3.27   4.07   62.5   1.02
10 a_cafe_1      4.21  0.206   3.78   4.59   47.0   1.01

image

image

image

Bottom line, @bloops you're a genius, I'll start writing my post, and I'll publish it as soon as your bijectors are on master :-)
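Here is a hedged illustration of why swapping tfb.Identity() for an Exp bijector on sigma and sigma_cafe can help. With Identity, HMC moves sigma directly and can step to negative values outside the Exponential prior's support. With Exp, HMC works on z = log(sigma), and TransformedTransitionKernel adds Exp's forward log-det-Jacobian (which is z) to the target. The NumPy check below assumes an Exponential(rate=1) prior, as in the model above. It verifies that the density in z-space, including that term, still integrates to 1.

```python
import numpy as np


def exponential_log_prob(sigma, rate=1.0):
    # log density of Exponential(rate), valid for sigma > 0
    return np.log(rate) - rate * sigma


def unconstrained_log_prob(z, rate=1.0):
    # Sampler state z = log(sigma); sigma = exp(z) > 0 for every real z.
    # Add Exp's forward log-det-Jacobian: log|d exp(z)/dz| = z.
    return exponential_log_prob(np.exp(z), rate) + z


z = np.linspace(-30.0, 5.0, 200001)
dz = z[1] - z[0]
mass = np.sum(np.exp(unconstrained_log_prob(z))) * dz
mass_without_jacobian = np.sum(np.exp(exponential_log_prob(np.exp(z)))) * dz

print(mass)                   # approximately 1.0
print(mass_without_jacobian)  # far from 1: the Jacobian term is needed
```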

bloops commented 5 years ago

That's awesome! :) Great detective work with identifying the problem and fixing the chain! Yes I will send out for review today and I think we can get it checked in soon.

(Sent by email in reply to Sigrid Keydana's comment of Wed, May 15, 2019, 12:57 AM.)
skeydan commented 5 years ago

wonderful :-)

bloops commented 5 years ago

That said, a mean rho of -0.42 is still a bit far from Stan's -0.5, although it's much, much better than before your fix!

skeydan commented 5 years ago

Hm, wait ... perhaps there is still room for improvement? The bijectors operate at the Cholesky level, right? Could it be that somehow, on the correlation level, the constraints are not strong enough yet? The data are the same ones I hard-coded in

https://github.com/skeydan/sb/blob/master/jds/cafe_wait_times_with_LKJ_bijector.ipynb

if you'd like to test. The actual correlations look like this: the diagonal elements are != 1; instead, they add up to 2:

> rhos[1:5]
[[1]]
           [,1]       [,2]
[1,]  1.4123573 -0.4922588
[2,] -0.4922588  0.5876426

[[2]]
           [,1]       [,2]
[1,]  1.3136871 -0.4639909
[2,] -0.4639909  0.6863130

[[3]]
           [,1]       [,2]
[1,]  1.5377637 -0.4985719
[2,] -0.4985719  0.4622364

[[4]]
           [,1]       [,2]
[1,]  1.4497577 -0.4974693
[2,] -0.4974693  0.5502423

[[5]]
           [,1]       [,2]
[1,]  1.5506017 -0.4974329
[2,] -0.4974329  0.4493982
bloops commented 5 years ago

Ah, that's strange. Just to double-check: are the 500 samples of the Cholesky-space correlations first matmul-ed (with their transpose), and only then averaged into a mean correlation matrix? That would be 500 matmuls for each chain; these can be combined into one batch matmul to make the total time a bit faster.

If it's done in the other order, first take mean of cholesky-space matrices and then create the correlation matrix, then such artifacts could show up.
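Here is a NumPy sketch of why the order matters. The draws are random unit-row lower-triangular factors, not LKJ samples, and the helper name is hypothetical. Any such draws show the effect: averaging L @ L^T per draw keeps a unit diagonal, while averaging the factors first shrinks the row norms.

```python
import numpy as np


def random_corr_cholesky(rng, num, n):
    """Draw `num` random n x n correlation Cholesky factors (unit-norm rows, positive diagonal)."""
    L = np.tril(rng.normal(size=(num, n, n)))
    idx = np.arange(n)
    L[:, idx, idx] = np.abs(L[:, idx, idx]) + 0.1          # keep the diagonal positive
    return L / np.linalg.norm(L, axis=-1, keepdims=True)   # normalize each row


rng = np.random.default_rng(0)
chol_samples = random_corr_cholesky(rng, 500, 2)   # e.g. 500 draws from one chain

# First L @ L^T for all draws in one batched matmul, then average.
corr_samples = chol_samples @ np.swapaxes(chol_samples, -1, -2)
mean_corr = corr_samples.mean(axis=0)

# Other order: average the Cholesky factors first, then form L @ L^T.
mean_chol = chol_samples.mean(axis=0)
corr_from_mean_chol = mean_chol @ mean_chol.T

print(np.diag(mean_corr))            # all ones up to rounding
print(np.diag(corr_from_mean_chol))  # second entry < 1 unless all draws coincide
```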

skeydan commented 5 years ago

It's not that (I do it the first way, or at least I should be), but there must be something wrong in my calculation: the notebook has the correct correlation!! So my bad, sorry!

(I just need to find where my code goes wrong. I'm doing all the post-processing on the R side, and one has to be super careful with shape manipulations...)

back soon...

skeydan commented 5 years ago

Oh, no!!! Found the mistake (((won't tell though ;-)))

So, all my bad: your bijector is perfect.

bloops commented 5 years ago

Awesome! Thanks and great work! Looking forward to the blog post 🙂

skeydan commented 5 years ago

Thanks to you!! The post will need the bijectors to be on master, so I can add regular wrappers for them, but I've already started writing :-)

skeydan commented 5 years ago

Hi @bloops, might I ask when the bijectors will arrive on master? The blog post's ready :-)

bloops commented 5 years ago

It's currently in review... I expect it to land by the end of this week!

skeydan commented 5 years ago

great, thanks!!

bloops commented 5 years ago

This is now checked in. But it is a bit different from the code I had shared earlier. Can you add the wrappers and make this work with your new blog post?

https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/correlation_cholesky.py
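For readers following along, a rough NumPy sketch of the kind of mapping such a bijector performs: an unconstrained vector of length n(n-1)/2 goes to the lower-triangular Cholesky factor of an n x n correlation matrix. This is not the code in correlation_cholesky.py. The construction (unit diagonal, fill the strictly lower triangle, then scale each row to unit norm) and especially the element ordering are assumptions, and `correlation_cholesky_forward` is a hypothetical name.

```python
import numpy as np

def correlation_cholesky_forward(x, n):
    """Hypothetical sketch: map an unconstrained vector of length n*(n-1)/2
    to the Cholesky factor of an n x n correlation matrix.

    The strictly lower triangle is filled row by row from `x` (an assumed
    ordering), the diagonal is set to 1, and each row is scaled to unit norm,
    so L @ L.T has a unit diagonal.
    """
    x = np.asarray(x, dtype=float)
    assert x.shape == (n * (n - 1) // 2,)
    lower = np.eye(n)
    lower[np.tril_indices(n, k=-1)] = x
    return lower / np.linalg.norm(lower, axis=-1, keepdims=True)

# 2x2 case: a single unconstrained number x maps to correlation x / sqrt(1 + x^2).
chol = correlation_cholesky_forward([-0.6], n=2)
corr = chol @ chol.T
print(chol)
print(corr)  # unit diagonal, off-diagonal in (-1, 1)
```

In the 2x2 case the ordering question disappears, which is why it is the one shown; for larger n, check the actual bijector before relying on which input element lands where.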

skeydan commented 5 years ago

Thanks @bloops I already saw earlier today, by chance :-)

For reasons I don't quite understand yet, I can't get a manual build of TFP to work with the most recent TF nightly (specifically when calling from R), so I may have to wait until tomorrow, when I can get matching nightly builds of both TF and TFP. (I'm also starting a TF build now, but that takes a while on my laptop.)

But yeah, very much looking forward to testing; as long as the functionality is unchanged, it should work :-) Thanks!

bloops commented 5 years ago

Ah, I see. Good luck with the import issues. I don't have much experience with calling the code from R, but perhaps others would be able to help -- let us know!

skeydan commented 5 years ago

Hi @bloops, thank you again, but unfortunately this now throws an error for me (one that did not occur with the two original bijectors):

    ValueError: Requires start <= limit when delta > 0: 1/0 for 'mcmc_sample_chain_1/simple_step_size_adaptation___init__/_bootstrap_results/transformed_kernel_bootstrap_results/mh_bootstrap_results/hmc_kernel_bootstrap_results/maybe_call_fn_and_grads/value_and_gradients/correlation_cholesky_1/forward_log_det_jacobian/range_1' (op: 'Range') with input shapes: [], [], [] and with computed input tensors: input[0] = <1>, input[1] = <0>, input[2] = <1>.

Here is the Python version of the code that throws the error:

import numpy as np

import tensorflow as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd=tfp.distributions
tfb=tfp.bijectors

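n_cafes = 20
# `r.d` is the R data frame `d`, accessed from Python via reticulate
cafe_id = tf.constant(r.d["cafe"], dtype=tf.int64) - 1  # 0-based cafe index
afternoon = tf.constant(r.d["afternoon"], dtype=tf.float32)
wait = tf.constant(r.d["wait"], dtype=tf.float32)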

def model(cafe_id):
  return tfd.JointDistributionSequential(
    [
      # chol(rho): Cholesky factor of the prior correlation matrix between intercepts and slopes
      tfd.LKJ(2, 2, input_output_cholesky=True),
      # sigma: prior for the standard deviation of the waiting time
      tfd.Sample(tfd.Exponential(rate = 1), sample_shape=1),
      # sigma_cafe: prior for the standard deviations of intercepts and slopes (vector of 2)
      tfd.Sample(tfd.Exponential(rate = 1), sample_shape=2),
      # b: prior for the average slope
      tfd.Sample(tfd.Normal(loc = -1, scale = 0.5), sample_shape=1),
      # a: prior for the average intercept
      tfd.Sample(tfd.Normal(loc = 5, scale = 2), sample_shape=1),
      # mvn, multivariate distribution of intercepts and slopes
      # shape: batch size, 20, 2
      lambda a,b,sigma_cafe,sigma,chol_rho:
        tfd.Sample(
            tfd.MultivariateNormalTriL(
                loc = tf.concat([a,b], axis = -1),
                scale_tril=tf.linalg.LinearOperatorDiag(sigma_cafe).matmul(chol_rho)),
            sample_shape=n_cafes),
      # waiting time
      # shape should be batch size, 200
      lambda mvn, a, b, sigma_cafe, sigma: tfd.Independent(
        # need to pull out the correct cafe_id in the middle column
        tfd.Normal(
          loc=(tf.gather(mvn[..., 0], cafe_id, axis=-1) +
               tf.gather(mvn[..., 1], cafe_id, axis=-1) * afternoon), # Shape [batch, 200,]
          scale=sigma),  # Shape [batch,  1]
        reinterpreted_batch_ndims=1
      )
    ]
  )

m = model(cafe_id)
# ds, xs = m.sample_distributions(3)
# 
# for d in ds:
#   print(d)
# 
# s = m.sample(3)
# m.log_prob(s)

# mcmc

constraining_bijectors = [
      tfb.CorrelationCholesky(),
      tfb.Exp(),
      tfb.Exp(),
      tfb.Identity(),
      tfb.Identity(),
      tfb.Identity(),
  ]

logp = lambda rho, sigma, sigma_cafe, b, a, mvn: m.log_prob([rho, sigma, sigma_cafe, b, a, mvn, wait])

number_of_steps = 500
burnin = 500
#number_of_steps = 1000
#burnin = 1000
nchain = 4

rho, sigma, sigma_cafe, b, a, mvn, _ = m.sample(nchain)

def trace_fn(_, pkr):
  return (pkr.inner_results.inner_results.is_accepted,
          pkr.inner_results.inner_results.accepted_results.step_size)

hmc=tfp.mcmc.HamiltonianMonteCarlo(
  target_log_prob_fn=logp,
  num_leapfrog_steps=3,
  step_size=.1)

kernel = tfp.mcmc.TransformedTransitionKernel(
      inner_kernel=hmc, bijector=constraining_bijectors)

kernel=tfp.mcmc.SimpleStepSizeAdaptation(
  kernel,
  target_accept_prob=.8,
  num_adaptation_steps=burnin
)

@tf.function
def run_mcmc():
  mcmc_trace, (is_accepted, step_size) = tfp.mcmc.sample_chain(
    num_results = number_of_steps,
    num_burnin_steps = burnin,
    current_state=[rho, sigma, sigma_cafe, b, a, mvn],
    kernel=kernel,
    trace_fn=trace_fn)
  return mcmc_trace, (is_accepted, step_size)

mcmc_trace, (is_accepted, step_size)=run_mcmc()

ess = tfp.mcmc.effective_sample_size(mcmc_trace)
rhat = tfp.mcmc.potential_scale_reduction(mcmc_trace)

lkj_samples = mcmc_trace[0]
print('LKJ samples have shape: ', lkj_samples.shape)

# Turn every Cholesky factor L into a correlation matrix L @ L^T with one
# batched matmul over all draws and chains, then average.
corr_samples = tf.matmul(lkj_samples, lkj_samples, transpose_b=True)
estimated_correlation_matrix = tf.reduce_mean(corr_samples, axis=[0, 1])

print('Estimated correlation matrix: ', estimated_correlation_matrix)
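In the model above, the per-cafe MVN uses `scale_tril = diag(sigma_cafe) @ chol_rho`. A quick NumPy check with made-up values that this parameterization gives covariance diag(σ) R diag(σ), i.e. standard deviations `sigma_cafe` and correlation matrix R = chol_rho @ chol_rho^T:

```python
import numpy as np

# Made-up values: standard deviations for intercept and slope,
# and a 2x2 correlation matrix with correlation -0.5.
sigma_cafe = np.array([1.0, 0.6])
rho = np.array([[1.0, -0.5], [-0.5, 1.0]])
chol_rho = np.linalg.cholesky(rho)  # lower-triangular L with L @ L.T == rho

# Same construction as LinearOperatorDiag(sigma_cafe).matmul(chol_rho).
scale_tril = np.diag(sigma_cafe) @ chol_rho
cov = scale_tril @ scale_tril.T

# Implied standard deviations and correlation recovered from the covariance.
sd = np.sqrt(np.diag(cov))
implied_corr = cov / np.outer(sd, sd)
print(sd)            # equals sigma_cafe
print(implied_corr)  # equals rho
```

Scaling the rows of a lower-triangular factor by positive numbers keeps it lower-triangular with a positive diagonal, so it remains a valid `scale_tril`.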

In case you want to reproduce this with the actual data, I hard-coded it here:

https://github.com/skeydan/sb/blob/master/jds/cafe_wait_times_with_LKJ_bijector.ipynb

Here's a more detailed stack trace

    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/sample.py:326 sample_chain
        previous_kernel_results = kernel.bootstrap_results(current_state)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/simple_step_size_adaptation.py:409 bootstrap_results
        inner_results = self.inner_kernel.bootstrap_results(init_state)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py:346 bootstrap_results
        transformed_init_state))
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/hmc.py:560 bootstrap_results
        kernel_results = self._impl.bootstrap_results(init_state)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/metropolis_hastings.py:265 bootstrap_results
        pkr = self.inner_kernel.bootstrap_results(init_state)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/hmc.py:772 bootstrap_results
        ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn, init_state)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/internal/util.py:233 maybe_call_fn_and_grads
        result, grads = _value_and_gradients(fn, fn_arg_list, result, grads)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/internal/util.py:192 _value_and_gradients
        result = fn(*fn_arg_list)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py:203 new_target_log_prob
        event_ndims=event_ndims)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py:50 fn
        for b, e, sp in zip(bijector, event_ndims, transformed_state_parts)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py:50 <listcomp>
        for b, e, sp in zip(bijector, event_ndims, transformed_state_parts)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/bijectors/bijector.py:1294 forward_log_det_jacobian
        return self._call_forward_log_det_jacobian(x, event_ndims, name, **kwargs)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/bijectors/bijector.py:1264 _call_forward_log_det_jacobian
        kwargs=kwargs)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/bijectors/bijector.py:1065 _compute_inverse_log_det_jacobian_with_caching
        event_ndims)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/bijectors/bijector.py:1359 _reduce_jacobian_det_over_event
        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow_probability/python/bijectors/bijector.py:1371 _get_event_reduce_dims
        return tf.range(-reduce_ndims, 0)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1321 range
        return gen_math_ops._range(start, limit, delta, name=name)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py:7584 _range
        "Range", start=start, limit=limit, delta=delta, name=name)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:788 _apply_op_helper
        op_def=op_def)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:464 create_op
        compute_device=compute_device)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:3623 create_op
        op_def=op_def)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:2034 __init__
        control_input_ops)
    /home/key/anaconda3/envs/r-tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:1874 _create_c_op
        raise ValueError(str(e))

    ValueError: Requires start <= limit when delta > 0: 1/0 for 'mcmc_sample_chain_1/simple_step_size_adaptation___init__/_bootstrap_results/transformed_kernel_bootstrap_results/mh_bootstrap_results/hmc_kernel_bootstrap_results/maybe_call_fn_and_grads/value_and_gradients/correlation_cholesky_1/forward_log_det_jacobian/range_1' (op: 'Range') with input shapes: [], [], [] and with computed input tensors: input[0] = <1>, input[1] = <0>, input[2] = <1>.
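The numbers in the error can be traced to the last TFP frame in the trace, `tf.range(-reduce_ndims, 0)`: the computed inputs are start=1, limit=0, so `reduce_ndims` was -1 there. Below is a plain-Python sketch of that step, assuming `reduce_ndims = event_ndims - min_event_ndims` (our reading of the names in the trace, not something confirmed in this thread); `event_reduce_dims` is a hypothetical name. It shows that asking a side whose minimum event rank is 2 (the matrix side) to reduce over a rank-1 event produces exactly these inputs.

```python
def event_reduce_dims(min_event_ndims, event_ndims):
    """Hypothetical re-creation of the failing step in the trace: the axes
    to sum the log-det-Jacobian over are range(-reduce_ndims, 0)."""
    reduce_ndims = event_ndims - min_event_ndims  # assumed definition
    start, limit = -reduce_ndims, 0
    if start > limit:  # the Range op rejects start > limit for a positive delta
        raise ValueError(
            f"Requires start <= limit when delta > 0: {start}/{limit}")
    return list(range(start, limit))

# Matrix side (min_event_ndims=2) queried with a vector's event_ndims=1.
try:
    event_reduce_dims(min_event_ndims=2, event_ndims=1)
except ValueError as e:
    print(e)  # Requires start <= limit when delta > 0: 1/0
```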
skeydan commented 5 years ago

@bloops, just wanted to let you know that I'll be offline for more than 20 hours soon, so I won't be able to respond right away if you answer in the meantime :-) (... but of course I very much look forward to getting the post publishable ;-))

bloops commented 5 years ago

I think I found the underlying issue -- I'm working on a fix. Meanwhile, the problem appears to be only in graph mode, so if you remove the @tf.function annotation for now, it should work.

skeydan commented 5 years ago

Great, thanks!! I'm leaving in 20 minutes, so I can't test and publish now, but perhaps I can still publish late Friday afternoon this week :-)

bloops commented 5 years ago

Hi, I've submitted a proposed fix: https://github.com/tensorflow/probability/commit/123170326fee7f408b30993c2ca6cd432a632cf3

skeydan commented 5 years ago

Thank you! I tested and I get the same results / performance as with the 2 bijectors before, so I'm happy :-)

... and I just published: https://blogs.rstudio.com/tensorflow/posts/2019-05-24-varying-slopes/

:-)

bloops commented 5 years ago

Blog post looks great! Awesome :)

skeydan commented 5 years ago

Thank you!! And thanks for coding the bijector so quickly! It was great working with you :-)

srvasude commented 5 years ago

Thanks @bloops for working on this, and thanks @skeydan for this wonderful blogpost. Going to close this bug, since I think the work items are done here :).