tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

What is the rationale behind the current implementation of default_multivariate_normal_fn? #775

Open nbro opened 4 years ago

nbro commented 4 years ago

The current implementation of default_multivariate_normal_fn looks like this (excluding docstrings):

# Imports as aliased in the TFP source, added here for context:
import tensorflow as tf
from tensorflow_probability.python.distributions import independent as independent_lib
from tensorflow_probability.python.distributions import normal as normal_lib

def default_multivariate_normal_fn(dtype, shape, name, trainable, add_variable_fn):
  del name, trainable, add_variable_fn  # unused
  dist = normal_lib.Normal(loc=tf.zeros(shape, dtype), scale=dtype.as_numpy_dtype(1))
  batch_ndims = tf.size(dist.batch_shape_tensor())
  return independent_lib.Independent(dist, reinterpreted_batch_ndims=batch_ndims)

If name, trainable and add_variable_fn are unused, why do you even require them to be passed as arguments? Furthermore, you initialize the scale parameter of the Normal distribution with dtype.as_numpy_dtype(1), which is an oddly indirect way to write a scalar 1 of the right dtype.
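For reference, dtype here is a tf.DType, and its as_numpy_dtype attribute is the matching NumPy scalar type, so the idiom just casts 1 to the layer's dtype. A minimal reproduction with plain NumPy (for tf.float32, as_numpy_dtype is np.float32):

```python
import numpy as np

# For a float32 layer, dtype.as_numpy_dtype(1) evaluates to np.float32(1):
# a scalar one cast to the layer's dtype.
scalar_one = np.float32(1)
print(scalar_one, scalar_one.dtype)  # 1.0 float32
```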

Why not simply have a method that returns a Normal initialized with 0s as means and 1s as scales, and require only the shape of that distribution?
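The simpler signature suggested above could look like the following sketch. The helper name is made up, and to keep the example self-contained it returns raw NumPy parameters; in TFP terms it would instead wrap them in Independent(Normal(loc, scale)):

```python
import numpy as np

# Hypothetical simpler prior factory: takes only the shape (and optionally
# a dtype) and returns standard-normal parameters: zeros for loc, a typed
# scalar one for scale (the same value default_multivariate_normal_fn uses).
def standard_normal_params(shape, dtype=np.float32):
    loc = np.zeros(shape, dtype=dtype)
    scale = dtype(1)
    return loc, scale

loc, scale = standard_normal_params((2, 3))
```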

Also, why does this function even exist, given that you already have default_mean_field_normal_fn? That function returns a closure def _fn(dtype, shape, name, trainable, add_variable_fn), which has the same parameters as default_multivariate_normal_fn and does more or less the same thing (except that _fn actually uses its parameters to define the loc and scale, rather than deleting more than half of them). For consistency, wouldn't it be better to just use default_mean_field_normal_fn to also initialize the prior (by providing default parameters when calling it, or maybe by having an alias, but not another function that does the same thing), or am I missing something?

nbro commented 4 years ago

For consistency, wouldn't it be better to just use default_mean_field_normal_fn to also initialize the prior (by providing default parameters when calling it, or maybe have an alias, but not another function that does the same thing), or am I missing something?

The problem is: if we use default_mean_field_normal_fn to initialize the prior, the prior will be trainable, which may be undesirable!

default_multivariate_normal_fn creates a non-trainable Gaussian (that's why it deletes the trainable parameter). However, the API is badly designed and inconsistent: there should only need to be one function that creates either a trainable or a non-trainable Gaussian (posterior or prior). There's no need for two functions that do such similar things. You could also have a default_non_trainable_normal_prior_fn, but it should just be an alias for a call to the single function with specific parameters, not a completely different implementation.
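The single-function design proposed here could be sketched as follows. This is purely illustrative: the names and the tag strings returned are made up, and a real version would build Independent(Normal(...)) distributions instead.

```python
# Illustrative sketch of the proposal: one factory produces the
# five-argument closure that tfp.layers expects, and a single flag decides
# whether the resulting distribution is trainable. The tag strings stand in
# for the actual Independent(Normal(...)) objects a real version would build.
def make_normal_fn(trainable_dist=True):
    def _fn(dtype, shape, name, trainable, add_variable_fn):
        if not trainable_dist:
            # Fixed prior: the variable-creation machinery is irrelevant.
            del name, trainable, add_variable_fn
            return ("fixed_normal", tuple(shape))
        # Trainable posterior: a real version would call add_variable_fn
        # here to create the loc and (untransformed) scale variables.
        return ("trainable_normal", tuple(shape))
    return _fn

# The proposed prior alias is then one line, not a separate implementation:
default_non_trainable_normal_prior_fn = make_normal_fn(trainable_dist=False)

kind, event_shape = default_non_trainable_normal_prior_fn(
    "float32", (4,), None, False, None)
```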