Photrek / Nonlinear-Statistical-Coupling

CoupledNormal - sample_n function #7

Closed by Kevin-Chen0 3 years ago

Kevin-Chen0 commented 3 years ago

Implement the sampling function for CoupledNormal. Use the commented-out StudentT sample_n code from TFP as a foundation.

See our latest nsc code here.

See original StudentT code here:

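# Imports this function relies on in TFP's student_t.py (internal modules).
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import gamma as gamma_lib
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers


def sample_n(n, df, loc, scale, batch_shape, dtype, seed):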
  """Draw n samples from a Student T distribution.

  Note that `scale` can be negative or zero.
  The sampling method comes from the fact that if:
    X ~ Normal(0, 1)
    Z ~ Chi2(df)
    Y = X / sqrt(Z / df)
  then:
    Y ~ StudentT(df)

  Args:
    n: int, number of samples
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values.
    loc: Floating-point `Tensor`; the location(s) of the distribution(s).
    scale: Floating-point `Tensor`; the scale(s) of the distribution(s). Must
      contain only positive values.
    batch_shape: Callable to compute batch shape
    dtype: Return dtype.
    seed: Optional seed for random draw.

  Returns:
    samples: a `Tensor` with prepended dimensions `n`.
  """
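  # Split the seed so the normal and gamma draws use independent streams.
  normal_seed, gamma_seed = samplers.split_seed(seed, salt='student_t')
  shape = ps.concat([[n], batch_shape], 0)

  # X ~ Normal(0, 1), with shape [n] + batch_shape.
  normal_sample = samplers.normal(shape, dtype=dtype, seed=normal_seed)
  # Broadcast df to the full batch shape before sampling.
  df = df * tf.ones(batch_shape, dtype=dtype)
  # Z ~ Chi2(df), drawn as Gamma(concentration=df / 2, rate=1 / 2).
  gamma_sample = gamma_lib.random_gamma(
      [n], concentration=0.5 * df, rate=0.5, seed=gamma_seed)
  # Y = X / sqrt(Z / df), then apply scale and loc.
  samples = normal_sample * tf.math.rsqrt(gamma_sample / df)
  return samples * scale + loc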
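
As a starting point for the CoupledNormal version, here is a minimal sketch of the same Normal/Chi2 recipe using only the Python standard library, so it runs without TensorFlow. It is not the nsc implementation. The function name `sample_n_coupled_normal` and the `kappa` parameter are hypothetical, and the sketch assumes that a coupling `kappa > 0` corresponds to a Student T with `df = 1 / kappa` and that `kappa == 0` reduces to a plain Normal. Both assumptions should be checked against the nsc definitions before porting this back to TFP ops.

```python
import math
import random


def sample_n_coupled_normal(n, kappa, loc=0.0, scale=1.0, seed=None):
    """Draw n scalar samples using the StudentT recipe quoted above.

    ASSUMPTION (not confirmed by the issue): kappa > 0 maps to Student T
    degrees of freedom df = 1 / kappa, and kappa == 0 is a plain Normal.
    """
    if kappa < 0:
        raise ValueError("this sketch only handles kappa >= 0")
    rng = random.Random(seed)

    if kappa == 0:
        # Assumed limiting case: no coupling, so sample Normal(loc, scale).
        return [loc + scale * rng.gauss(0.0, 1.0) for _ in range(n)]

    df = 1.0 / kappa
    samples = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)              # X ~ Normal(0, 1)
        z = rng.gammavariate(0.5 * df, 2.0)  # Z ~ Chi2(df) = Gamma(shape=df/2, scale=2)
        y = x / math.sqrt(z / df)            # Y ~ StudentT(df)
        samples.append(loc + scale * y)
    return samples


draws = sample_n_coupled_normal(5, kappa=0.5, loc=1.0, scale=2.0, seed=0)
print(len(draws))
```

The gamma draw uses scale 2, which matches `rate=0.5` in the TFP code. A batched version would replace the loop with the same `samplers.normal` and `gamma_lib.random_gamma` calls, with `df` derived from the coupling parameter.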