tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

(complementary) cumulative distribution for use in JointDistributionSequential #501

Closed skeydan closed 4 years ago

skeydan commented 5 years ago

Hi,

I would like to construct a simple model, using JointDistributionSequential, for durations, where part of the data is censored. The non-censored part can be modeled using tfd.Exponential, while the censored data should be modeled using the (exponential) complementary cumulative distribution (exponential_lccdf in Stan, see https://mc-stan.org/docs/2_19/functions-reference/exponential-distribution.html).

Would someone have a suggestion for how to do this? I suppose that in the exponential case, as the ccdf is simple to construct, one would have to wrap the formula exp(-lambda * x) in a TransformedDistribution somehow, or in a subclass of Distribution?

Many thanks for any help!

(BTW there's also a PyMC3 example https://docs.pymc.io/notebooks/bayes_param_survival_pymc3.html where they wrap the ccdf in a Potential class, but I don't think there is a comparable object in TFP...?)

junpenglao commented 5 years ago

I am not sure there is a similar Potential concept right now in JointDistribution*. My approach/workaround would be to get the log prob from a JointDistribution and then add the potential, something like:

joint = tfd.JointDistributionSequential([...])

def log_prob(x):
    lp = joint.log_prob(x)
    # select the censored part of the state to pass to an lccdf function
    potential = exponential_lccdf(x[...])
    return lp + potential
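
Note that TFP distributions expose log_survival_function, which is the lccdf (log(1 - cdf)), so for the exponential case the potential can be computed directly. A minimal sketch, with a made-up rate and made-up censoring times just for illustration:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# for tfd.Exponential, log_survival_function(t) = -rate * t, i.e. Stan's exponential_lccdf
rate = 0.5
censored_times = tf.constant([1.2, 3.4, 5.6])
potential = tf.reduce_sum(
    tfd.Exponential(rate).log_survival_function(censored_times))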
skeydan commented 5 years ago

Thanks a lot @junpenglao, I hadn't thought of that workaround! I'll try it and report back :-)

skeydan commented 5 years ago

Hi @junpenglao, sorry for the delay - I couldn't work on this until something else was solved :-)

As a check that what I'm doing is correct, could you please take a look at the following code (please excuse the clumsy translation to Python ;-))?

Also, in case you'd say it's correct but could easily be made more elegant (in a language-independent way, I mean, not the Python syntax), I'd be happy for any suggestions.

The model is meant to represent durations (censored and uncensored), stored in the variable check_time. Durations are exponentially distributed with rate 1/<a linear model with an intercept and predictors>, so the mean duration equals the linear predictor. I'm using bijectors to make sure the rate doesn't become negative - with the spec below most chains mix pretty well, but some definitely need more tuning (which I haven't attempted yet).

Just wondering, would you have a better idea for dealing with rates for an exponential?

Many thanks!

# complete dataframe with all variables
df = r.df
# all rows where status == completed
df_c = r.df_c
# all rows where status != completed
df_nc = r.df_nc

def model(data):
  return tfd.JointDistributionSequential([
    tfd.Normal(0, 1),
    tfd.Normal(0, 1),
    tfd.Normal(0, 1),
    tfd.Normal(0, 1),
    tfd.Normal(0, 1),
    tfd.Normal(0, 1),
    tfd.Normal(0, 1),
    lambda b6, b5, b4, b3, b2, b1, a:
      tfd.Independent(
        tfd.Exponential(
          rate = 1/(a[:, None] + b1[:, None] * data.depends + b2[:, None] * data.imports +
                 b3[:, None] * data.doc_size + b4[:, None] * data.r_size + b5[:, None] * data.ns_export +
                 b6[:, None] * data.ns_import)
        ), reinterpreted_batch_ndims = 1)
  ])

unconstraining_bijectors = [
  tfb.Exp(),
  tfb.Exp(),
  tfb.Exp(),
  tfb.Exp(),
  tfb.Exp(),
  tfb.Exp(),
  tfb.Exp(),
  tfb.Identity()
]

def get_exponential_lccdf(a, b1, b2, b3, b4, b5, b6, data):
  e = tfd.Independent(
        tfd.Exponential(
          rate = 1/(a[:, None] + b1[:, None] * data.depends + b2[:, None] * data.imports +
                 b3[:, None] * data.doc_size + b4[:, None] * data.r_size + b5[:, None] * data.ns_export +
                 b6[:, None] * data.ns_import)
        ), reinterpreted_batch_ndims = 1)
  cum_prob = e.cdf(data.check_time)
  # the potential added to the joint log_prob must itself be on the log scale
  return tf.math.log(1 - cum_prob)

# construct model for uncensored data only
m = model(df_c)

# this function is written to allow for easy comparison of resulting parameters
# depending on whether we use the potential or not
def get_log_prob(model_data, censored_data = None):
  def log_prob(a, b1, b2, b3, b4, b5, b6):
    lp = m.log_prob([a, b1, b2, b3, b4, b5, b6, model_data.check_time])
    potential = get_exponential_lccdf(a, b1, b2, b3, b4, b5, b6, censored_data) if censored_data is not None else 0
    return lp + potential
  return log_prob

log_prob = get_log_prob(df_c, df_nc)
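
Before handing this to HMC, a quick sanity check of the combined target on a prior sample (a sketch):

samples = m.sample(2)
lp = log_prob(*samples[:-1])  # drop the observed slot; should be finite, shape (2,)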
junpenglao commented 5 years ago

I would use matrix operations and juggle the shapes a bit to cut down some lines:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

# placeholder data for simulation.
data_array = np.hstack([np.ones((1000, 1)),
                        np.random.randn(1000, 6)])

def model(data_array):
  return tfd.JointDistributionSequential([
    tfd.Sample(
        tfd.Normal(loc=0., scale=1.),
        sample_shape=7),
    lambda betas:
      tfd.Independent(
        tfd.Exponential(
          rate = 1/tf.transpose(
              tf.matmul(tf.cast(data_array, dtype=betas.dtype),
                        tf.transpose(betas)))
        ), reinterpreted_batch_ndims = 1)
  ])

unconstraining_bijectors = [
  tfb.Exp(),
  tfb.Identity()
]

def get_exponential_lccdf(betas, data_array, check_time):
  e = tfd.Independent(
        tfd.Exponential(
            rate = 1/tf.transpose(
                tf.matmul(tf.cast(data_array, dtype=betas.dtype),
                          tf.transpose(betas)))
        ), reinterpreted_batch_ndims = 1)
  cum_prob = e.cdf(check_time)
  # the potential added to the joint log_prob must itself be on the log scale
  return tf.math.log(1 - cum_prob)

# construct model for uncensored data only
m = model(data_array)

# this function is written to allow for easy comparison of resulting parameters
# depending on whether we use the potential or not
def get_log_prob(model_data, censored_data = None):
  def log_prob(betas):
    lp = m.log_prob([betas, model_data.check_time])
    potential = get_exponential_lccdf(betas, censored_data, censored_data.check_time) if censored_data is not None else 0
    return lp + potential
  return log_prob
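
A quick shape check of the refactored joint (a sketch, using the placeholder data above):

betas, obs = m.sample(2)
# betas.shape == (2, 7): one coefficient vector per draw
# obs.shape == (2, 1000): one duration per row of data_array
m.log_prob([betas, obs])  # shape (2,)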
skeydan commented 5 years ago

That is a lot nicer, thank you!!

skeydan commented 5 years ago

Hi Junpeng,

thanks again :-) I have one final question (see below for the complete code, with shape annotations)...

With the above refactoring, I had to transpose the target array(s), e.g.

# after transpose: (1, 13523)
check_time_c = tf.transpose(r.check_time_c.to_numpy())

otherwise I'd get an error during logprob calculation while sampling.

To be precise, this would still work:

m = model(df_c)
samples = m.sample(2)
m.log_prob(samples)

but then during sampling I'd get

   ValueError: Dimensions must be equal, but are 4 and 13523 for 'mcmc_sample_chain_1/simple_step_size_adaptation___init__/_bootstrap_results/transformed_kernel_bootstrap_results/mh_bootstrap_results/hmc_kernel_bootstrap_results/maybe_call_fn_and_grads/value_and_gradients/JointDistributionSequential/log_prob/mcmc_sample_chain_1_simple_step_size_adaptation___init____bootstrap_results_transformed_kernel_bootstrap_results_mh_bootstrap_results_hmc_kernel_bootstrap_results_maybe_call_fn_and_grads_value_and_gradients_JointDistributionSequential_log_prob_Independentmcmc_sample_chain_1_simple_step_size_adaptation___init____bootstrap_results_transformed_kernel_bootstrap_results_mh_bootstrap_results_hmc_kernel_bootstrap_results_maybe_call_fn_and_grads_value_and_gradients_JointDistributionSequential_log_prob_Exponential/log_prob/mcmc_sample_chain_1_simple_step_size_adaptation___init____bootstrap_results_transformed_kernel_bootstrap_results_mh_bootstrap_results_hmc_kernel_bootstrap_results_maybe_call_fn_and_grads_value_and_gradients_JointDistributionSequential_log_prob_Exponential/log_prob/mul' (op: 'Mul') with input shapes: [4,13523], [13523,1].

Honestly I don't really understand, as the output shape is not affected by the refactoring... Would you have an idea? Thanks!
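
To make the shapes concrete, here is a minimal sketch of the failing broadcast (matching the [4,13523] vs [13523,1] in the error, with 4 chains):

import numpy as np
rate = np.zeros((4, 13523))      # batch of rates: one row per chain
rate * np.zeros((1, 13523))      # OK: broadcasts to (4, 13523)
# rate * np.zeros((13523, 1))    # raises: (4, 13523) and (13523, 1) don't broadcast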

import tensorflow as tf
import tensorflow_probability as tfp
tfd=tfp.distributions
tfb=tfp.bijectors
mcmc=tfp.mcmc

tf.compat.v2.enable_v2_behavior()

# all rows where status == completed
# shape (13523, 7)
df_c = r.df_c.to_numpy() 
# all rows where status != completed
# shape (103, 7)
df_nc = r.df_nc.to_numpy()

# after transpose: (1, 13523)
check_time_c = tf.transpose(r.check_time_c.to_numpy())
# after transpose: (1, 103)
check_time_nc = tf.transpose(r.check_time_nc.to_numpy())

def model(data):
  return tfd.JointDistributionSequential([
    tfd.Sample(tfd.Normal(0, 1), sample_shape=7),
    lambda betas:
      tfd.Independent(
        tfd.Exponential(
          rate = 1/tf.transpose(
              tf.matmul(tf.cast(data, dtype=betas.dtype),
                        tf.transpose(betas)))
        ), reinterpreted_batch_ndims = 1)
  ])

m = model(df_c)
samples = m.sample(2)
m.log_prob(samples)

unconstraining_bijectors = [
  tfb.Exp(),
  tfb.Identity()
]

def get_exponential_lccdf(betas, data, target):
  e = tfd.Independent(
        tfd.Exponential(
            rate = 1/tf.transpose(
                tf.matmul(tf.cast(data, dtype=betas.dtype), 
                          tf.transpose(betas)))
        ), reinterpreted_batch_ndims = 1)
  cum_prob = e.cdf(tf.cast(target, dtype=betas.dtype))
  return tf.math.log(1 - cum_prob)

def get_log_prob(target_c, censored_data = None, target_nc = None):
  def log_prob(betas):
    lp = m.log_prob([betas, tf.cast(target_c, betas.dtype)])
    potential = get_exponential_lccdf(betas, censored_data, target_nc) if censored_data is not None else 0
    return lp + potential
  return log_prob

log_prob = get_log_prob(check_time_c, df_nc, check_time_nc)

n_chains = 4
n_burnin = 1000
n_steps = 1000

initial_betas = m.sample(n_chains)[0]

hmc = mcmc.HamiltonianMonteCarlo(
  target_log_prob_fn = log_prob,
  num_leapfrog_steps = 6,
  step_size = tf.constant([0.1, 0.1, 0.1, 0.3, 0.15, 0.15, 0.3])
)
transformed_kernel = mcmc.TransformedTransitionKernel(inner_kernel = hmc, bijector = unconstraining_bijectors) 
kernel = mcmc.SimpleStepSizeAdaptation(inner_kernel=transformed_kernel, target_accept_prob = 0.8, num_adaptation_steps = n_burnin)

@tf.function()
def run_mcmc():
  return mcmc.sample_chain(
    num_results = n_steps,
    num_burnin_steps = n_burnin,
    kernel = kernel,
    current_state = tf.ones_like(initial_betas),
    trace_fn = lambda state, pkr: 
      [pkr.inner_results.inner_results.is_accepted,
       pkr.inner_results.inner_results.accepted_results.step_size
       ]
  )

res = run_mcmc()
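
To check whether the sampler accepts anything at all (a sketch; the trace_fn above records is_accepted and the step size):

betas_trace, (is_accepted, step_sizes) = res
acceptance_rate = tf.reduce_mean(tf.cast(is_accepted, tf.float32))
print(acceptance_rate)  # should end up near the 0.8 target after adaptation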
skeydan commented 4 years ago

Hi @junpenglao, sorry for bothering again - I have an additional, more urgent question... The above model seemed to work pretty well, but I'm told that in order to model durations (as in survival analysis), the rate should be calculated as

1/tf.math.exp(tf.transpose(
              tf.matmul(tf.cast(data, dtype=betas.dtype),
                        tf.transpose(betas))))

instead. But when I try that (removing the bijector, although that doesn't make a difference), the model does not work at all (all parameter samples keep their initial values, no steps accepted, in spite of trying different step sizes and initial values). Would you have an idea what the reason could be? Many thanks!!

junpenglao commented 4 years ago

Hi there, sorry about the slow response - but regarding your more pressing issue: the exp() can make the value extremely large, and thus make the log_prob return nan. Try evaluating the log_prob function (and ideally its gradient), and adjust the initial value accordingly.
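
Something like this (a sketch, using the log_prob and initial_betas from your snippet):

with tf.GradientTape() as tape:
  tape.watch(initial_betas)
  lp = log_prob(initial_betas)
grad = tape.gradient(lp, initial_betas)
print(lp, grad)  # nan or inf here means HMC will reject every proposal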

skeydan commented 4 years ago

Hi @junpenglao, thank you!! Whatever I did, I haven't been able to make the log_prob return anything other than nan on the initial samples. In the original model - the one without the additional tf.exp - that was the case too, but there the bijectors helped to make HMC work.

Even with minimal step sizes of 1e-10 etc. I cannot get even a single step accepted in the new model. Honestly I'm not yet completely convinced I do need that tf.exp() in there...

I've extracted the Stan code from brms, an R package that calls Stan in the background, and it looks like this:

functions {
}
data {
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=-1,upper=2> cens[N];  // indicates censoring
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // population-level effects
  real temp_Intercept;  // temporary intercept
}
transformed parameters {
}
model {
  vector[N] mu = temp_Intercept + Xc * b;
  for (n in 1:N) {
    mu[n] = exp(-(mu[n]));
  }
  // priors including all constants
  target += student_t_lpdf(temp_Intercept | 3, 4, 10);
  // likelihood including all constants
  if (!prior_only) {
    for (n in 1:N) {
      // special treatment of censored data
      if (cens[n] == 0) {
        target += exponential_lpdf(Y[n] | mu[n]);
      } else if (cens[n] == 1) {
        target += exponential_lccdf(Y[n] | mu[n]);
      } else if (cens[n] == -1) {
        target += exponential_lcdf(Y[n] | mu[n]);
      }
    }
  }
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = temp_Intercept - dot_product(means_X, b);
}
Would you have an opinion as to which of the following two TFP models

# this model uses a tfb.Exp bijector
def model(data):
  return tfd.JointDistributionSequential([
    tfd.Sample(tfd.Normal(0, 1), sample_shape=7),
    lambda betas:
      tfd.Independent(
        tfd.Exponential(
          rate = 1/tf.transpose(
              tf.matmul(tf.cast(data, dtype=betas.dtype),
                        tf.transpose(betas)))
        ), reinterpreted_batch_ndims = 1)
  ])

# no bijector
def model(data):
  return tfd.JointDistributionSequential([
    tfd.Sample(tfd.Normal(0, 1), sample_shape=7),
    lambda betas:
      tfd.Independent(
        tfd.Exponential(
          rate = 1/tf.math.exp(tf.transpose(
              tf.matmul(tf.cast(data, dtype=betas.dtype),
                        tf.transpose(betas))))
        ), reinterpreted_batch_ndims = 1)
  ])

the Stan model corresponds to, in spirit? (The brms model converges very well.)

Thanks again!

junpenglao commented 4 years ago

> Hi @junpenglao, thank you!! Whatever I did, I haven't been able to make the log_prob return anything other than nan on the initial samples. In the original model - the one without the additional tf.exp - that was the case too, but there the bijectors helped to make HMC work.

Yep, that's precisely the root cause - no matter how small you set the step size, if the initial return value is nan, HMC is not going to work. The one with the bijector actually transforms the initial value to a reasonable range first. In other words, to make it work for the new model you need to set the initial value to something reasonable, and you should keep the bijector, as your model still contains bounded variables and HMC only really works on variables defined on the whole real line.
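
A sketch of what the TransformedTransitionKernel does with the initial state - the chain itself moves on the real line, and the constrained initial value is pulled through the bijector's inverse first:

b = tfb.Exp()
init_constrained = tf.ones([4, 7])                # positive, valid for the model
init_unconstrained = b.inverse(init_constrained)  # log(1) = 0: where HMC actually starts
b.forward(init_unconstrained)                     # mapped back before each log_prob call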

skeydan commented 4 years ago

Hi @junpenglao thanks again :-) Unfortunately, the second model does not work even if I use the original bijector... In case you could spare the time, could you take a look?

I've hardcoded some values in the snippet below (this is actually the complete censored portion, 103 rows, and it's really not that different from the uncensored one...)

In the below code, I've varied initial values and step sizes but I never get a single HMC step accepted... Many thanks in any case!

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
tfd=tfp.distributions
tfb=tfp.bijectors
mcmc=tfp.mcmc

tf.compat.v2.enable_v2_behavior()

# all rows where status != completed
# shape (103, 7)
# df_c = r.df_c_py
df_c = np.array([[1.000000e+00, 0.000000e+00, 5.000000e+00, 0.000000e+00,
        1.262100e-02, 8.000000e+00, 1.000000e+00],
       [1.000000e+00, 2.000000e+00, 1.000000e+00, 5.521000e-03,
        2.758400e-02, 4.000000e+00, 1.000000e+00],
       [1.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00,
        1.535000e-02, 1.000000e+00, 0.000000e+00],
       [1.000000e+00, 3.000000e+00, 1.200000e+01, 0.000000e+00,
        1.983483e+00, 2.970000e+02, 0.000000e+00],
       [1.000000e+00, 3.000000e+00, 0.000000e+00, 0.000000e+00,
        1.716650e-01, 2.000000e+00, 1.000000e+00],
       [1.000000e+00, 0.000000e+00, 4.000000e+00, 0.000000e+00,
        1.051600e-02, 1.600000e+01, 3.000000e+00],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 0.000000e+00,
        1.796000e-02, 1.000000e+00, 4.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00,
        1.459100e-02, 1.000000e+00, 1.000000e+00],
       [1.000000e+00, 6.000000e+00, 0.000000e+00, 0.000000e+00,
        3.203600e-02, 1.400000e+01, 1.300000e+01],
       [1.000000e+00, 4.000000e+00, 2.000000e+00, 0.000000e+00,
        6.047300e-02, 9.000000e+00, 3.000000e+00],
       [1.000000e+00, 1.000000e+00, 4.000000e+00, 1.665000e-02,
        8.791400e-02, 6.000000e+00, 2.600000e+01],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 3.321000e-03,
        4.770000e-04, 1.000000e+00, 1.000000e+00],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 0.000000e+00,
        4.020000e-02, 0.000000e+00, 0.000000e+00],
       [1.000000e+00, 4.000000e+00, 0.000000e+00, 1.282000e-03,
        2.715000e-03, 4.000000e+00, 1.000000e+00],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 0.000000e+00,
        2.269500e-02, 0.000000e+00, 0.000000e+00],
       [1.000000e+00, 0.000000e+00, 0.000000e+00, 2.271000e-03,
        1.007000e-03, 0.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 9.000000e+00, 2.174500e-02,
        1.045630e-01, 1.400000e+01, 9.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 7.345000e-03,
        2.734700e-02, 1.000000e+01, 1.400000e+01],
       [1.000000e+00, 1.000000e+00, 1.000000e+01, 6.077000e-03,
        2.504800e-02, 9.000000e+00, 1.200000e+01],
       [1.000000e+00, 1.000000e+00, 8.000000e+00, 0.000000e+00,
        2.872720e-01, 2.100000e+01, 7.800000e+01],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 6.562600e-02,
        2.662590e-01, 2.100000e+01, 7.900000e+01],
       [1.000000e+00, 1.000000e+00, 5.000000e+00, 1.126100e-02,
        4.537100e-02, 0.000000e+00, 1.000000e+01],
       [1.000000e+00, 1.000000e+00, 1.300000e+01, 0.000000e+00,
        1.185700e-02, 1.200000e+01, 0.000000e+00],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 1.320500e-02,
        1.399400e-02, 4.000000e+00, 1.000000e+01],
       [1.000000e+00, 1.000000e+00, 8.000000e+00, 1.562200e-02,
        1.131680e-01, 9.000000e+00, 1.700000e+01],
       [1.000000e+00, 4.000000e+00, 0.000000e+00, 0.000000e+00,
        1.834900e-02, 3.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.000000e+01, 0.000000e+00,
        1.062500e-02, 1.000000e+01, 0.000000e+00],
       [1.000000e+00, 0.000000e+00, 8.000000e+00, 0.000000e+00,
        1.058470e-01, 9.000000e+00, 2.000000e+01],
       [1.000000e+00, 1.000000e+00, 5.000000e+00, 0.000000e+00,
        1.762700e-02, 5.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 8.000000e+00, 0.000000e+00,
        1.331400e-02, 1.000000e+00, 4.000000e+00],
       [1.000000e+00, 3.000000e+00, 3.000000e+00, 0.000000e+00,
        2.340300e-02, 5.000000e+00, 3.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.400000e+01, 5.908000e-02,
        6.863370e-01, 1.120000e+02, 1.000000e+02],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00,
        6.965000e-03, 3.000000e+00, 2.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 0.000000e+00,
        1.476640e-01, 8.000000e+00, 1.000000e+00],
       [1.000000e+00, 1.000000e+00, 2.200000e+01, 0.000000e+00,
        9.268000e-02, 2.000000e+00, 2.400000e+01],
       [1.000000e+00, 1.000000e+00, 8.000000e+00, 3.395300e-02,
        1.108800e-01, 2.400000e+01, 4.900000e+01],
       [1.000000e+00, 1.000000e+00, 2.000000e+00, 7.750000e-03,
        6.094200e-02, 4.000000e+00, 1.100000e+01],
       [1.000000e+00, 2.000000e+00, 1.100000e+01, 0.000000e+00,
        1.751220e-01, 2.400000e+01, 3.400000e+01],
       [1.000000e+00, 0.000000e+00, 6.000000e+00, 1.927500e-02,
        9.076000e-03, 1.100000e+01, 0.000000e+00],
       [1.000000e+00, 3.000000e+00, 5.000000e+00, 0.000000e+00,
        7.833980e-01, 3.800000e+01, 3.360000e+02],
       [1.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00,
        8.092000e-03, 1.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 7.000000e+00, 0.000000e+00,
        8.413000e-03, 8.000000e+00, 4.000000e+00],
       [1.000000e+00, 1.000000e+00, 2.000000e+00, 0.000000e+00,
        4.934000e-03, 3.000000e+00, 5.000000e+00],
       [1.000000e+00, 4.000000e+00, 3.000000e+00, 0.000000e+00,
        2.814200e-02, 3.000000e+00, 1.000000e+00],
       [1.000000e+00, 1.000000e+00, 3.000000e+00, 0.000000e+00,
        1.418300e-02, 2.000000e+00, 2.800000e+01],
       [1.000000e+00, 1.000000e+00, 9.000000e+00, 0.000000e+00,
        9.025700e-02, 1.200000e+01, 1.000000e+00],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 2.680800e-02,
        1.983000e-02, 0.000000e+00, 3.000000e+00],
       [1.000000e+00, 3.000000e+00, 2.000000e+00, 0.000000e+00,
        1.358800e-02, 9.000000e+00, 1.300000e+01],
       [1.000000e+00, 3.000000e+00, 0.000000e+00, 0.000000e+00,
        4.375100e-02, 4.000000e+00, 3.000000e+00],
       [1.000000e+00, 1.000000e+00, 5.000000e+00, 1.343700e-02,
        3.384100e-02, 0.000000e+00, 4.000000e+00],
       [1.000000e+00, 0.000000e+00, 2.000000e+00, 0.000000e+00,
        8.287000e-03, 5.000000e+00, 6.000000e+00],
       [1.000000e+00, 6.000000e+00, 3.000000e+00, 1.175400e-02,
        3.975400e-02, 5.000000e+00, 1.600000e+01],
       [1.000000e+00, 3.000000e+00, 1.500000e+01, 0.000000e+00,
        4.987710e-01, 1.020000e+02, 1.020000e+02],
       [1.000000e+00, 1.000000e+00, 5.000000e+00, 7.910900e-02,
        1.425550e-01, 5.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.300000e+01, 1.416300e-02,
        4.128300e-02, 4.100000e+01, 9.000000e+00],
       [1.000000e+00, 1.000000e+00, 0.000000e+00, 2.453300e-02,
        1.935200e-02, 0.000000e+00, 1.200000e+01],
       [1.000000e+00, 5.000000e+00, 1.100000e+01, 4.428300e-02,
        1.555380e-01, 1.400000e+01, 4.000000e+00],
       [1.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00,
        1.231690e-01, 7.000000e+00, 2.300000e+01],
       [1.000000e+00, 7.000000e+00, 0.000000e+00, 3.292200e-02,
        6.440000e-02, 2.000000e+00, 1.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 5.220000e-03,
        4.277800e-02, 2.500000e+01, 3.300000e+01],
       [1.000000e+00, 2.000000e+00, 3.000000e+00, 0.000000e+00,
        1.239810e-01, 5.000000e+00, 0.000000e+00],
       [1.000000e+00, 6.000000e+00, 3.000000e+00, 0.000000e+00,
        6.515800e-02, 4.000000e+00, 1.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 6.302000e-03,
        2.263700e-02, 3.000000e+00, 6.000000e+00],
       [1.000000e+00, 5.000000e+00, 1.000000e+00, 4.244800e-02,
        2.148092e+00, 1.000000e+01, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.800000e+01, 0.000000e+00,
        1.811130e-01, 3.600000e+01, 1.500000e+01],
       [1.000000e+00, 0.000000e+00, 1.300000e+01, 1.597900e-02,
        8.130000e-02, 5.400000e+01, 8.000000e+00],
       [1.000000e+00, 0.000000e+00, 6.000000e+00, 0.000000e+00,
        4.465700e-02, 9.000000e+00, 5.000000e+00],
       [1.000000e+00, 1.000000e+00, 2.000000e+00, 8.296000e-03,
        3.941000e-02, 0.000000e+00, 5.000000e+00],
       [1.000000e+00, 0.000000e+00, 8.000000e+00, 0.000000e+00,
        9.251700e-02, 1.700000e+01, 1.900000e+01],
       [1.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00,
        8.827000e-03, 1.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 1.945700e-02,
        1.066270e-01, 1.800000e+01, 2.800000e+01],
       [1.000000e+00, 1.000000e+00, 5.000000e+00, 0.000000e+00,
        2.005400e-01, 3.000000e+00, 1.000000e+00],
       [1.000000e+00, 0.000000e+00, 3.000000e+00, 0.000000e+00,
        3.771000e-03, 3.000000e+00, 1.000000e+00],
       [1.000000e+00, 0.000000e+00, 1.300000e+01, 0.000000e+00,
        4.131400e-02, 2.800000e+01, 9.000000e+00],
       [1.000000e+00, 0.000000e+00, 7.000000e+00, 1.007600e-02,
        3.496900e-02, 1.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 7.000000e+00, 9.670000e-03,
        1.911800e-02, 1.600000e+01, 4.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.400000e+01, 0.000000e+00,
        2.045150e-01, 1.500000e+01, 2.100000e+01],
       [1.000000e+00, 2.000000e+00, 0.000000e+00, 0.000000e+00,
        5.520000e-03, 2.000000e+00, 1.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 0.000000e+00,
        3.069200e-02, 5.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 4.989200e-02,
        9.109500e-02, 2.000000e+00, 3.200000e+01],
       [1.000000e+00, 1.000000e+00, 4.000000e+00, 0.000000e+00,
        2.196200e-02, 1.600000e+01, 7.000000e+00],
       [1.000000e+00, 2.000000e+00, 8.000000e+00, 0.000000e+00,
        5.391200e-02, 5.000000e+00, 8.000000e+00],
       [1.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00,
        1.951800e-02, 1.000000e+00, 1.500000e+01],
       [1.000000e+00, 0.000000e+00, 5.000000e+00, 3.120000e-04,
        2.385000e-02, 1.700000e+01, 1.600000e+01],
       [1.000000e+00, 1.000000e+00, 1.200000e+01, 4.229500e-02,
        8.784700e-02, 1.000000e+00, 2.900000e+01],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 1.042000e-02,
        3.990800e-02, 1.000000e+00, 0.000000e+00],
       [1.000000e+00, 4.000000e+00, 2.000000e+00, 0.000000e+00,
        2.188900e-02, 2.000000e+00, 0.000000e+00],
       [1.000000e+00, 1.000000e+00, 0.000000e+00, 9.689000e-03,
        1.798300e-02, 0.000000e+00, 4.200000e+01],
       [1.000000e+00, 2.000000e+00, 1.000000e+00, 2.652000e-02,
        2.191340e-01, 5.000000e+00, 0.000000e+00],
       [1.000000e+00, 5.000000e+00, 1.200000e+01, 8.357800e-02,
        2.387910e-01, 1.600000e+01, 1.000000e+00],
       [1.000000e+00, 0.000000e+00, 5.000000e+00, 6.750000e-03,
        4.953200e-02, 1.400000e+01, 1.100000e+01],
       [1.000000e+00, 1.000000e+00, 5.000000e+00, 5.909900e-02,
        8.554900e-02, 0.000000e+00, 1.000000e+00],
       [1.000000e+00, 2.000000e+00, 1.100000e+01, 2.933200e-02,
        1.623150e-01, 1.800000e+01, 4.000000e+01],
       [1.000000e+00, 1.000000e+00, 1.000000e+01, 5.667500e-02,
        2.108470e-01, 0.000000e+00, 2.700000e+01],
       [1.000000e+00, 2.000000e+00, 8.000000e+00, 0.000000e+00,
        1.404870e-01, 9.000000e+00, 9.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.100000e+01, 1.478700e-02,
        1.411650e-01, 1.400000e+01, 7.400000e+01],
       [1.000000e+00, 3.000000e+00, 3.000000e+00, 1.144310e-01,
        7.841860e-01, 1.100000e+01, 9.000000e+00],
       [1.000000e+00, 1.000000e+00, 3.000000e+00, 0.000000e+00,
        7.365000e-03, 0.000000e+00, 2.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 0.000000e+00,
        1.720000e-02, 1.000000e+00, 2.000000e+00],
       [1.000000e+00, 1.000000e+00, 1.500000e+01, 5.153100e-02,
        1.997540e-01, 1.000000e+00, 1.600000e+01],
       [1.000000e+00, 0.000000e+00, 6.000000e+00, 0.000000e+00,
        7.004700e-02, 5.000000e+00, 3.400000e+01],
       [1.000000e+00, 1.000000e+00, 1.000000e+00, 1.210900e-02,
        8.348500e-02, 1.000000e+00, 0.000000e+00],
       [1.000000e+00, 2.000000e+00, 1.200000e+01, 3.165700e-02,
        2.924370e-01, 1.500000e+01, 1.290000e+02]])
# all rows where status == completed
# shape (13523, 7)
# df_nc = r.df_nc_py
# just reuse df_c
df_nc = df_c

# after transpose: (1, 103)
# check_time_c = tf.transpose(r.check_time_c)
check_time_c = tf.constant(np.array([[ 10.,  80.,  38., 304.,  89.,  68.,   5.,  77.,  89., 174., 113.,
          5.,  51.,  93.,  53., 144., 110., 139.,  93., 143., 308., 271.,
        237.,  78., 174.,  54., 178.,  97.,   5.,  76.,  95., 696.,  67.,
        170.,  79., 118.,  62., 114., 183., 142.,  42.,   7.,  80.,  82.,
         67., 173.,  62.,  58.,  34.,  81.,  73., 115., 491., 265., 212.,
         88., 158.,  96., 115., 101.,  46.,  83.,  56., 253., 100., 130.,
         66.,  67., 172.,  32., 620.,  61.,   6., 116., 298., 144., 116.,
         43.,  98., 234., 120., 165.,  43.,  62., 183., 276.,  74., 101.,
        346., 350., 140.,   6., 355., 100.,  87., 103., 448.,  59.,  63.,
        269.,  96.,  67., 299.]]))

# after transpose: (1, 13523)
#check_time_nc = tf.transpose(r.check_time_nc)
check_time_nc = check_time_c

def model(data):
  return tfd.JointDistributionSequential([
    tfd.Sample(tfd.Normal(0, 1), sample_shape=7),
    lambda betas:
      tfd.Independent(
        tfd.Exponential(
          rate = 1/tf.math.exp(tf.transpose(
              tf.matmul(tf.cast(data, dtype=betas.dtype),
                        tf.transpose(betas))))
        ), reinterpreted_batch_ndims = 1)
  ])

m = model(df_nc)
samples = m.sample(2)
m.log_prob(samples)

unconstraining_bijectors = [
  tfb.Exp(),
  tfb.Identity()
]

def get_exponential_lccdf(betas, data, target):
  e = tfd.Independent(
        tfd.Exponential(
            rate = 1/tf.exp(tf.transpose(
                tf.matmul(tf.cast(data, dtype=betas.dtype), 
                          tf.transpose(betas))))
        ), reinterpreted_batch_ndims = 1)
  cum_prob = e.cdf(tf.cast(target, dtype=betas.dtype))
  return tf.math.log(1 - cum_prob)

def get_log_prob(target_nc, censored_data = None, target_c = None):
  def log_prob(betas):
    lp = m.log_prob([betas, tf.cast(target_nc, betas.dtype)])
    potential = get_exponential_lccdf(betas, censored_data, target_c) if censored_data is not None else 0
    return lp + potential
  return log_prob

log_prob = get_log_prob(check_time_nc, df_c, check_time_c)

n_chains = 4
n_burnin = 1000
n_steps = 1000

initial_betas = m.sample(n_chains)[0]

hmc = mcmc.HamiltonianMonteCarlo(
  target_log_prob_fn = log_prob,
  num_leapfrog_steps = 6,
  # this was for the model without the tf.math.exp
  #step_size = tf.constant([0.1, 0.1, 0.1, 0.3, 0.15, 0.15, 0.3]
  step_size = tf.constant(0.001)
)
transformed_kernel = mcmc.TransformedTransitionKernel(inner_kernel = hmc, bijector = unconstraining_bijectors) 
kernel = mcmc.SimpleStepSizeAdaptation(inner_kernel=transformed_kernel, target_accept_prob = 0.8, num_adaptation_steps = n_burnin)

@tf.function()
def run_mcmc():
  return mcmc.sample_chain(
    num_results = n_steps,
    num_burnin_steps = n_burnin,
    kernel = kernel,
    #current_state = tf.ones_like(initial_betas),
    #current_state = tf.zeros_like(initial_betas),
    current_state = initial_betas,
    trace_fn = lambda state, pkr: 
      [pkr.inner_results.inner_results.is_accepted,
       pkr.inner_results.inner_results.accepted_results.step_size
       ]
  )

res = run_mcmc()
skeydan commented 4 years ago

Hi,

I decided to publish the post with the previous model that worked fine:

https://blogs.rstudio.com/tensorflow/posts/2019-07-31-censored-data/.

I think that should be ok too.

In case you still had ideas regarding the above, or the shape thing (https://github.com/tensorflow/probability/issues/501#issuecomment-514105449), I would of course be happy to hear them, but I know you have things to work on, so feel free to close :-)

Thanks again for your help!

junpenglao commented 4 years ago

No problem. FWIW I took a quick look; the problem is that the predictor matrix df_nc contains some pretty large outlier values - you can see this by doing plt.hist(df_nc.flatten(), 100). As a result, the lambda parameter of the Exponential distribution almost always contains inf after the transformation:

samples = m.sample(2)
1/tf.math.exp(tf.transpose(
              tf.matmul(tf.cast(df_nc, dtype=tf.float32), 
                        tf.transpose(samples[0]))))  # <== you can see the inf in the result

So, you can either standardize df_nc - brms actually does that in:

transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X
  vector[Kc] means_X;  // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}

or use a prior with a much narrower standard deviation so that the sampled values do not make the prediction blow up. Alternatively, you can supply all zeros as the initial value to the sampler: initial_betas = tf.zeros_like(m.sample(n_chains)[0]). I think brms uses all zeros as initial values for unconstrained random variables; that's why it does not have the same problem.
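
A sketch of that brms-style centering on the NumPy side (leaving the intercept column alone):

X = df_nc.copy()
means = X[:, 1:].mean(axis=0)
X[:, 1:] = X[:, 1:] - means           # center the predictors
# X[:, 1:] /= X[:, 1:].std(axis=0)    # optionally also scale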

junpenglao commented 4 years ago

(Close this for now but please feel free to follow up if you have another question!)

skeydan commented 4 years ago

Thanks again! I had tried all of the above, also in combinations of two, but never all of them together. Then, when in addition I removed the bijector, it worked! :-) :-) :-)

Updated the post - it should be acceptable to statisticians (as opposed to DL people, who are used to alchemy ;-)) now as well :-)

Thanks again!

junpenglao commented 4 years ago

Glad it helped! There is quite a bit of preprocessing/preconditioning in brms (and many other similar libraries) so that inference can run automatically; when you are using lower-level libraries like TFP or PyMC3, you need to spell out that boilerplate explicitly - it's not always trivial :-)

skeydan commented 4 years ago

Yeah! My main error in thinking here: since I had first used normalization in the non-exp model, but noticed that TFP seemed to handle the substantial differences in scale between the predictors without problems, I assumed the exp model should work fine without it as well - especially given that I had experimented with standard deviations of 0.001 and was using 0 for the initial values...

But using all three together made the difference :-)