Open psyphh opened 4 years ago
Thanks psyphh! I think this is a bug on our end. While we look into it more, a workaround would be to put `x = mvn_model.sample(...)` inside the `GradientTape`; would that work for now?
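Something along these lines, for example (just a sketch; the model definition below is a guess at what `mvn_model` looks like in your code, with `mu` and `raw_scale_tril` as the trainable parameters):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dims = 2  # illustrative
# Hypothetical stand-in for the model from the report.
mu = tf.Variable(tf.zeros([dims]), name='mu')
raw_scale_tril = tf.Variable(tf.eye(dims), name='raw_scale_tril')
mvn_model = tfd.MultivariateNormalTriL(loc=mu, scale_tril=raw_scale_tril)

optimizer = tf.optimizers.Adam(learning_rate=1.)
for _ in range(100):
    with tf.GradientTape() as tape:
        # Sampling inside the tape, so the tape records the ops that produce x.
        x = mvn_model.sample([1000, 1])
        loss_value = -tf.reduce_mean(tf.reduce_sum(mvn_model.prob(x), axis=1))
    gradients = tape.gradient(loss_value, mvn_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, mvn_model.trainable_variables))

Note that this re-draws x on every iteration, so it changes the objective compared with fitting a fixed data set.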
Thanks for the reply. I found that the problem still exists. In addition, since this code is for maximum likelihood estimation, the gradient tape has to sit inside a training loop, and it may not be reasonable to resample `x` in each iteration. If there is another way to write the MLE code, please let me know. Thank you again.
This is indeed an issue and is present in all variants of MVN. `MultivariateNormalTriL` (and other MVN variants such as `MultivariateNormalDiag`) are child classes of `MultivariateNormalLinearOperator`, which is a child class of `TransformedDistribution`. What I see here in `MultivariateNormalLinearOperator`, https://github.com/tensorflow/probability/blob/4735f4948e95c139c25124c59820432bb30f82a8/tensorflow_probability/python/distributions/mvn_linear_operator.py#L207
super(MultivariateNormalLinearOperator, self).__init__(
    # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
    # shape instead of tf.zeros.
    # We use `Sample` instead of `Independent` because `Independent`
    # requires concatenating `batch_shape` and `event_shape`, which loses
    # static `batch_shape` information when `event_shape` is not statically
    # known.
    distribution=sample.Sample(
        normal.Normal(
            loc=tf.zeros(batch_shape, dtype=dtype),
            scale=tf.ones([], dtype=dtype)),
        event_shape),
    bijector=bijector,
    validate_args=validate_args,
    name=name)
is that the base distribution is constructed with `loc=tf.zeros(...)` and `scale=tf.ones(...)`. Shouldn't the `normal.Normal` here receive the user/caller-supplied `loc` and `scale`? Most likely I am not following the code flow, so I would appreciate it if you could look at this as a possible bug. I am also worried that, if this bug is real, it would affect all TFP layers as well.
Regards, Kapil
Re the original issue: it looks like the behavior occurs only when `x` is sampled outside the loop from the same distribution whose density we compute inside the loop. That is, if I slightly modify the first line of @psyphh's example to sample `x` from a different normal distribution:
# Sample x from a *different* distribution than mvn_model.
x = tfd.Normal(0., 1.).sample([1000, 1, dims])
optimizer = tf.optimizers.Adam(learning_rate=1.)
with tf.GradientTape() as tape:
    loss_value = -tf.reduce_mean(tf.reduce_sum(mvn_model.prob(x), axis=1))
print(tape.watched_variables())
gradients = tape.gradient(loss_value, mvn_model.trainable_variables)
optimizer.apply_gradients(zip(gradients, mvn_model.trainable_variables))
then the tape will provide gradients to both `loc` and `scale`.

It even works if I change the first line to `x = mvn_model.sample([1000, 1]).numpy()`, i.e., sampling from the same distribution but forcing a round trip between a `tf.EagerTensor` and an `np.array`. In fact, even adding an extra op such as `x = x + 0.` or `x = tf.identity(x)` seems to be enough to make things work.

It seems as though in the original example, TF somehow remembers that the loc `mu` was added to `x` inside the MVN `sample()` call, so that when `x - mu` is computed inside `prob`, the gradient cancels out. I'm a bit confused about how this could be happening, since the sampling occurs outside the tape (it's possible there's an underlying bug in TF). For now, does the `x = tf.identity(x)` workaround work for you?
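A minimal sketch of that workaround (assuming `mvn_model` is the MVN built from the trainable `mu` and `raw_scale_tril` in the original example):

# x is drawn once, outside the tape, as in the original MLE setup; the
# tf.identity round-trip breaks the hidden connection back to `mu`.
x = tf.identity(mvn_model.sample([1000, 1]))

optimizer = tf.optimizers.Adam(learning_rate=1.)
for _ in range(100):
    with tf.GradientTape() as tape:
        loss_value = -tf.reduce_mean(tf.reduce_sum(mvn_model.prob(x), axis=1))
    gradients = tape.gradient(loss_value, mvn_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, mvn_model.trainable_variables))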
Thanks @davmre, your suggested workaround does make it work. I have created a notebook to show the various variants: https://colab.research.google.com/drive/1eqzgoxhg8vHFZK7sboMq_fZxFbBoJt0b?usp=sharing

Since this works, I clearly do not understand how `TransformedDistribution` behaves or how its variables are tracked. While trying to debug this, I looked at the implementation of the various MVN variants that inherit from `TransformedDistribution`. As I mentioned in my comment above, when the `distribution` is passed to the superclass (i.e. `TransformedDistribution`), a standard normal is used, i.e. `loc` and `scale` are not passed. Based on this I concluded that `loc`/`mu` is not being tracked, hence the error. See the snippet again here:
distribution=sample.Sample(
    normal.Normal(
        loc=tf.zeros(batch_shape, dtype=dtype),
        scale=tf.ones([], dtype=dtype)),
    event_shape),
As a comparison, when I look at the implementation of `LogNormal`, I see something I can understand:
super(LogNormal, self).__init__(
    distribution=normal.Normal(loc=loc, scale=scale),
    bijector=exp_bijector.Exp(),
    validate_args=validate_args,
    parameters=parameters,
    name=name)
Here, in the case of `LogNormal`, the caller-provided `loc` and `scale` are passed to the base distribution of the `TransformedDistribution`, which is not the case for the MVN variants! I am really puzzled about how the MVN variants end up with a learnable `loc` and `scale`. I would really appreciate it if you could explain the inner workings here (especially for MVN).
Regards & thanks, Kapil
MultivariateNormal distributions in TFP inherit from `TransformedDistribution`; they're defined by a linear transformation of a 'base' standard `Normal` distribution. You're looking at the definition of the base distribution, but `loc` and `scale` are stored (and tracked) by the `bijector` that represents the linear transformation, defined here.

In general, `.trainable_variables` for any Distribution or Bijector is populated using magic from the `tf.Module` superclass that recursively examines all properties of an object searching for `tf.Variable` instances. So if `dist` is an MVN instance, it will recurse through `dist.bijector` and eventually find the variables `dist.bijector.components[0].scale` and `dist.bijector.components[1].loc`, and collect them as `dist.trainable_variables`.
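To make that concrete, here's a small inspection sketch (the exact printed structure depends on the TFP version, and the variables here are just hypothetical stand-ins):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mu = tf.Variable([0., 0.], name='mu')
scale_tril = tf.Variable([[1., 0.], [0.5, 1.]], name='scale_tril')
dist = tfd.MultivariateNormalTriL(loc=mu, scale_tril=scale_tril)

# The base distribution is a standard Normal wrapped in Sample ...
print(dist.distribution)
# ... while the user-supplied loc/scale live on the bijector of the
# TransformedDistribution.
print(dist.bijector)
# tf.Module's recursive tracking still surfaces both variables:
print(dist.trainable_variables)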
Makes sense now. Can sleep better tonight :)
Much appreciated.
Regards Kapil
Dear developers, recently I have been trying to write code for computing an MLE via TFP. I found that TFP does not track the `loc` parameter of a multivariate normal when using `GradientTape`. Here is the example code:
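(The original snippet is not preserved in this copy of the thread; the following is a minimal reconstruction, where only the names `mu`, `raw_scale_tril`, and `mvn_model` come from the report and everything else is assumed.)

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dims = 2  # illustrative
mu = tf.Variable(tf.zeros([dims]), name='mu')
raw_scale_tril = tf.Variable(tf.eye(dims), name='raw_scale_tril')
mvn_model = tfd.MultivariateNormalTriL(loc=mu, scale_tril=raw_scale_tril)

# Data for the MLE, drawn once from the model itself, outside any tape.
x = mvn_model.sample([1000, 1])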
The `mu` and `raw_scale_tril` are both learnable, which can be checked with `print(mvn_model.trainable_variables)`. However, in the training process with the following code,
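(again a reconstruction, mirroring the loop quoted earlier in the thread)

optimizer = tf.optimizers.Adam(learning_rate=1.)
with tf.GradientTape() as tape:
    loss_value = -tf.reduce_mean(tf.reduce_sum(mvn_model.prob(x), axis=1))
print(tape.watched_variables())   # per the report, mu is missing here
gradients = tape.gradient(loss_value, mvn_model.trainable_variables)
optimizer.apply_gradients(zip(gradients, mvn_model.trainable_variables))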
I found that `mu` is no longer watched by the tape. As a result, the training process cannot be completed successfully. Similar code works for a univariate normal. If the developers could figure out whether this is a bug or something I am missing, that would be really helpful.

Best,