tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

NotImplementedError: Layer DenseVariational has arguments in `__init__` and therefore must override `get_config`. #1627

Open jedisom opened 2 years ago

jedisom commented 2 years ago

I already asked this here, but I think this is an issue that should be fixed in the TensorFlow Probability code.

I have a TensorFlow Probability model built similarly to the models described in this YouTube video. Here's the code to build the model:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from typing import Any

def posterior_mean_field(kernel_size: int, bias_size: int, dtype: Any) -> tf.keras.Model:
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.))
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(2 * n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(tfd.Normal(loc=t[..., :n],
                                                                           scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
                                                                reinterpreted_batch_ndims=1)),
    ])

def prior_trainable(kernel_size: int, bias_size: int, dtype: Any) -> tf.keras.Model:
    n = kernel_size + bias_size
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t, scale=1),
            reinterpreted_batch_ndims=1)),
    ])

def build_model():
    model = keras.Sequential([
        tfp.layers.DenseVariational(64, activation='relu', input_shape=[10],
                                    make_posterior_fn=posterior_mean_field,
                                    make_prior_fn=prior_trainable),
        layers.Dense(64, activation='relu'),
        layers.Dense(1),
    ])
    optimizer = tf.keras.optimizers.RMSprop(0.001)
    model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
    return model

model = build_model()
model.build((3, 10))

When I remove the TensorFlow Probability layer (1st layer) in the model, I can clone the model and copy its weights like this:

import copy
from tensorflow.keras.models import clone_model
model_weights = copy.deepcopy(model.get_weights())
model_copy = clone_model(model)
model_copy.set_weights(model_weights)

However, when the TensorFlow Probability layer is present I get this error:

Traceback (most recent call last):
  File "/Users/jisom/opt/miniconda3/envs/ic-hours/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-90d09fdd3673>", line 211, in <cell line: 211>
    model_copy = clone_model(model)
  File "/Users/jisom/opt/miniconda3/envs/ic-hours/lib/python3.8/site-packages/keras/models.py", line 448, in clone_model
    return _clone_sequential_model(
  File "/Users/jisom/opt/miniconda3/envs/ic-hours/lib/python3.8/site-packages/keras/models.py", line 326, in _clone_sequential_model
    if isinstance(layer, InputLayer) else layer_fn(layer))
  File "/Users/jisom/opt/miniconda3/envs/ic-hours/lib/python3.8/site-packages/keras/models.py", line 56, in _clone_layer
    return layer.__class__.from_config(layer.get_config())
  File "/Users/jisom/opt/miniconda3/envs/ic-hours/lib/python3.8/site-packages/keras/engine/base_layer.py", line 727, in get_config
    raise NotImplementedError('Layer %s has arguments in `__init__` and '
NotImplementedError: Layer DenseVariational has arguments in `__init__` and therefore must override `get_config`.

This StackOverflow question has some information about the error, but it involves a custom-built transformer class that the asker can modify. I'm calling the Keras clone_model function, which I don't control, and the error seems to come from the TFP DenseVariational layer, which doesn't override get_config. Should the DenseVariational class be updated to override get_config? If not, how can I clone a model, including its weights, when it contains TensorFlow Probability layers as above?
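For readers following the traceback, here is a minimal pure-Python sketch of the mechanism it suggests: cloning calls `layer.__class__.from_config(layer.get_config())`, and the default `get_config` refuses to run when the constructor takes extra arguments. The classes `ToyLayer`, `ToyVariational`, and `ToyVariationalFix` and the helper `clone_layer` are hypothetical stand-ins written for illustration. They are not the Keras or TFP implementations, and the real check in Keras may differ in detail.

```python
import inspect


class ToyLayer:
    """Toy stand-in for a Keras-style layer; NOT the real Keras base class."""

    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()

    def get_config(self):
        # Constructor parameters beyond the ones this base class knows about.
        params = inspect.signature(type(self).__init__).parameters
        extra = [p for p in params if p not in ("self", "name", "kwargs")]
        overridden = type(self).get_config is not ToyLayer.get_config
        if extra and not overridden:
            raise NotImplementedError(
                f"Layer {type(self).__name__} has arguments in `__init__` "
                "and therefore must override `get_config`.")
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def clone_layer(layer):
    # Same shape as the `_clone_layer` line shown in the traceback.
    return layer.__class__.from_config(layer.get_config())


class ToyVariational(ToyLayer):
    """Adds a constructor argument but inherits the default get_config."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units


class ToyVariationalFix(ToyVariational):
    """Overrides get_config so the constructor arguments can be recovered."""

    def get_config(self):
        config = super().get_config()  # keeps "name"
        config["units"] = self.units
        return config
```

In this sketch, `clone_layer(ToyVariational(64))` raises `NotImplementedError`, while `clone_layer(ToyVariationalFix(64, name="dv"))` returns a new `ToyVariationalFix` with the same `units` and `name`.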

I'm using

jedisom commented 2 years ago

I think I found a workaround for this. I wrote a subclass of tfp.layers.DenseVariational that adds a get_config method. With it, the clone_model call in the code above seems to work if you replace tfp.layers.DenseVariational with DenseVariationalFix.

import tensorflow_probability as tfp

class DenseVariationalFix(tfp.layers.DenseVariational):
    def __init__(self,
                 units,
                 make_posterior_fn,
                 make_prior_fn,
                 kl_weight=None,
                 kl_use_exact=False,
                 activation=None,
                 use_bias=True,
                 activity_regularizer=None,
                 **kwargs):
        super(DenseVariationalFix, self).__init__(
            units=units,
            make_posterior_fn=make_posterior_fn,
            make_prior_fn=make_prior_fn,
            kl_weight=kl_weight,
            kl_use_exact=kl_use_exact,
            activation=activation,
            use_bias=use_bias,
            activity_regularizer=activity_regularizer,
            **kwargs)
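        # Keep copies of these arguments so that get_config can return them.
        self._kl_weight = kl_weight
        self._kl_use_exact = kl_use_exact

    def get_config(self):
        # The functions and the activation are returned as Python objects,
        # so this config is not JSON-serializable.
        config = {
            'units': self.units,
            'make_posterior_fn': self._make_posterior_fn,
            'make_prior_fn': self._make_prior_fn,
            'kl_weight': self._kl_weight,
            'kl_use_exact': self._kl_use_exact,
            'activation': self.activation,
            'use_bias': self.use_bias,
        }
        return config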

Should I open a PR to add this to the repo, rather than relying on this temporary fix?
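One thing worth checking before a PR: the get_config in the workaround returns seven keys, while the constructor takes eight named arguments plus **kwargs. A small sketch of a completeness check follows. The helper `missing_config_keys` and the class `ToyDenseVariationalFix` are hypothetical names introduced here. The toy class only copies the constructor signature and config keys of the workaround so the check can run without TensorFlow, and it is not the TFP layer.

```python
import inspect


def missing_config_keys(layer):
    """Return named __init__ parameters that get_config() does not report."""
    params = inspect.signature(type(layer).__init__).parameters
    named = [
        name for name, p in params.items()
        if name != "self"
        and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    config = layer.get_config()
    return [name for name in named if name not in config]


class ToyDenseVariationalFix:
    """Hypothetical stand-in with the workaround's signature and config keys."""

    def __init__(self, units, make_posterior_fn, make_prior_fn,
                 kl_weight=None, kl_use_exact=False, activation=None,
                 use_bias=True, activity_regularizer=None, **kwargs):
        self.units = units
        self._make_posterior_fn = make_posterior_fn
        self._make_prior_fn = make_prior_fn
        self._kl_weight = kl_weight
        self._kl_use_exact = kl_use_exact
        self.activation = activation
        self.use_bias = use_bias
        self.activity_regularizer = activity_regularizer

    def get_config(self):
        # Same seven keys as the workaround's get_config.
        return {
            'units': self.units,
            'make_posterior_fn': self._make_posterior_fn,
            'make_prior_fn': self._make_prior_fn,
            'kl_weight': self._kl_weight,
            'kl_use_exact': self._kl_use_exact,
            'activation': self.activation,
            'use_bias': self.use_bias,
        }


layer = ToyDenseVariationalFix(64, make_posterior_fn=None, make_prior_fn=None)
missing = missing_config_keys(layer)
```

Here `missing` contains only 'activity_regularizer'. Anything passed through **kwargs (for example a layer name or input_shape) is also absent from the config, so a layer rebuilt from it would presumably not receive those values. A common Keras pattern is to start from `super().get_config()` and add the subclass keys to it; whether that fits a TFP change is a call for the maintainers.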

Frightera commented 1 year ago

@jedisom You can create a PR for that.