tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

training bijectors with tf.optimizers.Adam in tensorflow 2.0 #628

Closed: Char-Aznable closed this issue 5 years ago

Char-Aznable commented 5 years ago

Hi, I want to train a normalizing flow using Adam. My model looks like this:

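```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# define bijectors here...
# (`bijectors`, `nMer` and `system` are defined elsewhere, not shown)
flow = tfb.Chain(list(reversed(bijectors)))
baseDist = tfd.Independent(
    distribution=tfd.Uniform(low=tf.zeros([nMer, 3], dtype=tf.float32),
                             high=tf.ones([nMer, 3], dtype=tf.float32)),
    reinterpreted_batch_ndims=2)
dist = tfd.TransformedDistribution(
    distribution=baseDist,
    bijector=flow)

optimizerReverse = tf.optimizers.Adam(1e-2)
# Evaluated eagerly: lossReverse is a Tensor, not a callable.
lossReverse = tfp.vi.monte_carlo_variational_loss(
    system.log_prob,
    dist,
    discrepancy_fn=tfp.vi.kl_reverse,
    sample_size=128)
# Raises TypeError: minimize() calls its first argument, but lossReverse is a Tensor.
trainReverse = optimizerReverse.minimize(lossReverse, dist.trainable_variables)
```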

This gives the error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-18-05774bebf465> in <module>
      5     discrepancy_fn=tfp.vi.kl_reverse,
      6     sample_size=128)
----> 7 trainReverse = optimizerReverse.minimize(lossReverse, dist.trainable_variables)

~/.linuxbrew/opt/python/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py in minimize(self, loss, var_list, grad_loss, name)
    315     """
    316     grads_and_vars = self._compute_gradients(
--> 317         loss, var_list=var_list, grad_loss=grad_loss)
    318 
    319     return self.apply_gradients(grads_and_vars, name=name)

~/.linuxbrew/opt/python/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py in _compute_gradients(self, loss, var_list, grad_loss)
    349       if not callable(var_list):
    350         tape.watch(var_list)
--> 351       loss_value = loss()
    352     if callable(var_list):
    353       var_list = var_list()

TypeError: 'tensorflow.python.framework.ops.EagerTensor' object is not callable
```
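
The traceback shows the mechanism: `_compute_gradients` opens a `GradientTape` and calls `loss()` inside it. The sketch below is a simplified re-creation of that step, not TensorFlow's actual code; the helper `compute_gradients` and the toy variable `x` are hypothetical stand-ins for the optimizer internals and the flow's parameters. It shows why a precomputed Tensor cannot serve as the loss: it was computed before the tape started recording, so there is no gradient path back to the variable.

```python
import tensorflow as tf

# Toy stand-in (hypothetical) for the flow's trainable parameters.
x = tf.Variable(3.0)

# Simplified version of what the traceback shows minimize() doing:
# open a GradientTape and *call* the loss inside it.
def compute_gradients(loss_fn, var_list):
    with tf.GradientTape() as tape:
        loss_value = loss_fn()
    return tape.gradient(loss_value, var_list)

# A zero-argument callable is re-evaluated under the tape, so gradients flow.
grad_from_callable = compute_gradients(lambda: (x - 1.0) ** 2, x)

# A Tensor computed beforehand was not recorded by the tape, so there is
# no path from the variable to the loss and the gradient comes back as None.
precomputed = (x - 1.0) ** 2
with tf.GradientTape() as tape:
    loss_value = precomputed
grad_from_tensor = tape.gradient(loss_value, x)

print(grad_from_callable)  # tf.Tensor(4.0, shape=(), dtype=float32)
print(grad_from_tensor)    # None
```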

I was able to define a similar training op with the TensorFlow 1.x-style tf.compat.v1.train.AdamOptimizer. It seems TensorFlow 2.0's Adam optimizer requires explicitly specifying the trainable variables, and I don't know how to get them from TensorFlow Probability. Both flow and dist have empty trainable_variables:

```
In [22]: flow.trainable_variables
Out[22]: ()

In [23]: dist.trainable_variables
Out[23]: ()
```

csuter commented 5 years ago

The minimize function expects a zero-argument callable as its first argument, but tfp.vi.monte_carlo_variational_loss returns a Tensor (the output of the loss calculation). I think you need to pass

lambda: tfp.vi.monte_carlo_variational_loss(...)

instead. This will be a function (by way of closure) of the trainable variables in your model, and will re-sample from the variational model on each call in the optimizer loop.

davmre commented 5 years ago

Note that tfp.vi.fit_surrogate_posterior exists as lightweight sugar implementing the solution Chris described.

Dave

(Sent by email in reply to Christopher Suter's comment of Mon, Oct 28, 2019.)

Char-Aznable commented 5 years ago

@csuter @davmre Thanks for the information! I managed to get it running by passing lambda: tfp.vi.monte_carlo_variational_loss(...) to the minimize() call. However, I observe a 100x slowdown compared to the TensorFlow 1.x-compatible version. I opened a new issue, #629, for that, and I'm closing this one.