jedisom opened 2 years ago
I think I found a workaround/fix for this. I wrote a wrapper subclass of `tfp.layers.DenseVariational` that implements the `get_config` method, and now the `clone_model` call in the code above seems to work if you replace `tfp.layers.DenseVariational` with `DenseVariationalFix`.
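```python
import tensorflow_probability as tfp


class DenseVariationalFix(tfp.layers.DenseVariational):
    """DenseVariational subclass that implements get_config()."""

    def __init__(self,
                 units,
                 make_posterior_fn,
                 make_prior_fn,
                 kl_weight=None,
                 kl_use_exact=False,
                 activation=None,
                 use_bias=True,
                 activity_regularizer=None,
                 **kwargs):
        super(DenseVariationalFix, self).__init__(
            units=units,
            make_posterior_fn=make_posterior_fn,
            make_prior_fn=make_prior_fn,
            kl_weight=kl_weight,
            kl_use_exact=kl_use_exact,
            activation=activation,
            use_bias=use_bias,
            activity_regularizer=activity_regularizer,
            **kwargs)
        # The parent class folds these two arguments into its KL penalty
        # function without keeping them, so store them for get_config().
        self._kl_weight = kl_weight
        self._kl_use_exact = kl_use_exact

    def get_config(self):
        # Start from the base Layer config (name, trainable, dtype, ...) so
        # the cloned layer keeps the original layer's name.
        config = super(DenseVariationalFix, self).get_config()
        # Note: the posterior/prior factories and the activation are stored
        # as Python callables; that is enough for in-memory cloning, but the
        # config is not JSON-serializable.
        config.update({
            'units': self.units,
            'make_posterior_fn': self._make_posterior_fn,
            'make_prior_fn': self._make_prior_fn,
            'kl_weight': self._kl_weight,
            'kl_use_exact': self._kl_use_exact,
            'activation': self.activation,
            'use_bias': self.use_bias,
            'activity_regularizer': self.activity_regularizer,
        })
        return config
```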
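Why overriding `get_config` should be enough: as far as I can tell, Keras's `clone_model` rebuilds each layer with `layer.__class__.from_config(layer.get_config())`, and the default `from_config` essentially calls `cls(**config)`. In TF 2.x Keras, the base `Layer.get_config` raises `NotImplementedError` when a subclass's `__init__` takes extra arguments but `get_config` is not overridden. The toy classes below (hypothetical, plain Python, no TensorFlow needed) sketch that round trip: a layer with a required `units` argument and no `get_config` override cannot be rebuilt, while a subclass that returns its constructor arguments can.

```python
# Hypothetical toy classes (not Keras) mimicking the per-layer round trip
# that clone_model performs:
#     layer.__class__.from_config(layer.get_config())


class ToyLayer:
    """Stand-in for the Keras base Layer."""

    def __init__(self, name="layer", trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # Roughly mirrors the Keras check: a subclass whose __init__ takes
        # extra arguments must override get_config, or cloning cannot work.
        cls = type(self)
        uses_default = cls.get_config is ToyLayer.get_config
        custom_init = cls.__init__ is not ToyLayer.__init__
        if uses_default and custom_init:
            raise NotImplementedError(
                f"Layer {cls.__name__} has arguments in `__init__` and "
                "therefore must override `get_config`.")
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # The Keras default is essentially cls(**config).
        return cls(**config)


class ToyDense(ToyLayer):
    """Like DenseVariational: a required argument, no get_config override."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units


class ToyDenseFix(ToyDense):
    """Like DenseVariationalFix: get_config returns the init arguments."""

    def get_config(self):
        config = super().get_config()  # base name/trainable entries
        config.update({"units": self.units})
        return config


def toy_clone(layer):
    # Rebuild a fresh layer from its config, as clone_model does per layer.
    return layer.__class__.from_config(layer.get_config())
```

With these, `toy_clone(ToyDense(4))` raises `NotImplementedError`, while `toy_clone(ToyDenseFix(4, name="dv"))` returns a new `ToyDenseFix` with `units == 4` and `name == "dv"`.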
Should I create a PR to add this to the repo instead of relying on this temporary fix?
@jedisom You can create a PR for that.
I already asked this here, but I think this issue should be fixed in the TensorFlow Probability code.
I have a TensorFlow Probability model that is built similarly to the models described in this YouTube video. Here's the code to build the model:
When I remove the TensorFlow Probability layer (the first layer) from the model, I can clone the model and copy its weights like this:
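For reference, the usual Keras pattern for duplicating a model built only from built-in layers is to clone the architecture with `tf.keras.models.clone_model` and then copy the weights with `get_weights`/`set_weights`, because the clone starts with freshly initialized weights. A minimal sketch, assuming TensorFlow 2.x and a small hypothetical model (not the model from this issue):

```python
import numpy as np
import tensorflow as tf

# Hypothetical plain-Keras model standing in for the model without the
# TFP layer; the architecture is illustrative only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])

# clone_model rebuilds the architecture from each layer's config,
# but the new layers get freshly initialized weights...
clone = tf.keras.models.clone_model(model)

# ...so copy the original weights over explicitly.
clone.set_weights(model.get_weights())
```

After `set_weights`, both models should produce identical outputs for the same input, since these layers are deterministic.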
However, when the TensorFlow Probability layer is present, I get this error:
I can see some information about how to deal with this error in this StackOverflow question, but that question involves a custom-built transformer class that can be modified. I'm trying to use Keras's `clone_model` function, which I don't directly control, and the error seems to come from the TFP `DenseVariational` layer, which doesn't override `get_config`.

Should the `DenseVariational` class be updated to override the `get_config` method? If not, how can I clone/duplicate a model, including its weights, when the model includes TensorFlow Probability layers as above?

I'm using