The IndependentReparametrizationSampler.sample method creates a tf.Variable the first time it's called:
if self._eps is None:
    self._eps = tf.Variable(sample_eps())
This works, but it interacts badly with TensorFlow compilation: when a tf.Variable is created inside a tf.function, the function is traced (compiled) twice in a row on its first call, which slows things down.
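A minimal sketch that reproduces the double trace, assuming TensorFlow 2. `LazyEpsSampler` and its `trace_count` counter are hypothetical names for illustration, and the `sample_eps()` helper is replaced by an inline `tf.random.normal` call. The counter is a Python side effect, so it increments once per trace rather than once per call.

```python
import tensorflow as tf


class LazyEpsSampler:
    """Hypothetical sampler mimicking the lazy pattern: eps is created on the first call."""

    def __init__(self, sample_size: int):
        self._sample_size = sample_size
        self._eps = None
        self.trace_count = 0

    @tf.function
    def sample(self, mean: tf.Tensor, stddev: tf.Tensor) -> tf.Tensor:
        self.trace_count += 1  # Python side effect: runs once per trace, not per call
        if self._eps is None:
            # The eps shape depends on the latent dimension, only known once inputs arrive,
            # so the tf.Variable is created inside the tf.function.
            eps_shape = [self._sample_size, mean.shape[-1]]
            self._eps = tf.Variable(tf.random.normal(eps_shape))
        # Reparametrization: [N, L] + [N, L] * [S, 1, L] broadcasts to [S, N, L].
        return mean + stddev * self._eps[:, None, :]


sampler = LazyEpsSampler(sample_size=5)
mean, stddev = tf.zeros([3, 2]), tf.ones([3, 2])
samples = sampler.sample(mean, stddev)
print(sampler.trace_count)
```

In recent TensorFlow 2 releases the first call should leave `trace_count` at 2 (one trace that creates the variable, then a second trace of the variable-free version), and a repeat call with same-shaped inputs should not increase it.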
I think the variable is created lazily because the shape of eps depends on the number of latent model dimensions, but it would be better if this variable creation could be avoided. One suggestion is that the model event shapes could be part of the ProbabilisticModel interface. Something like:
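A minimal sketch of the idea, with hypothetical names throughout: `HasEventShape`, its `event_shape` property, `ToyModel` and `EagerEpsSampler` are illustrations, not existing trieste APIs. If the latent dimension is available from the model up front, eps can be allocated eagerly in `__init__`, outside any tf.function, so `sample` only reads the variable.

```python
from typing import Protocol

import tensorflow as tf


class HasEventShape(Protocol):
    # Hypothetical interface addition; not part of the existing ProbabilisticModel.
    @property
    def event_shape(self) -> tf.TensorShape: ...


class ToyModel:
    # Hypothetical stand-in for a model with two latent output dimensions.
    @property
    def event_shape(self) -> tf.TensorShape:
        return tf.TensorShape([2])


class EagerEpsSampler:
    """Hypothetical sampler that creates eps eagerly, so sample() never creates a tf.Variable."""

    def __init__(self, sample_size: int, model: HasEventShape):
        self.trace_count = 0
        # The event shape is known here, so eps is allocated outside any tf.function.
        eps_shape = [sample_size] + model.event_shape.as_list()
        self._eps = tf.Variable(tf.random.normal(eps_shape))

    @tf.function
    def sample(self, mean: tf.Tensor, stddev: tf.Tensor) -> tf.Tensor:
        self.trace_count += 1  # Python side effect: runs once per trace, not per call
        # Reparametrization: [N, L] + [N, L] * [S, 1, L] broadcasts to [S, N, L].
        return mean + stddev * self._eps[:, None, :]


sampler = EagerEpsSampler(sample_size=5, model=ToyModel())
mean, stddev = tf.zeros([3, 2]), tf.ones([3, 2])
samples = sampler.sample(mean, stddev)
```

Since no variable is created during tracing, the first call should trace once, and later calls with same-shaped inputs should reuse that trace. The trade-off is that every model would need to report its event shape.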