keras-team / keras

Deep Learning for humans
Apache License 2.0

Save custom loss functions with external parameters. #7993

Closed: Junyoungpark closed this issue 7 years ago.

Junyoungpark commented 7 years ago


I'm currently working on Mixture Density Networks, so I need to implement a custom negative log-likelihood loss function. To make it work with the Keras training module, which calls a loss with only (y_true, y_pred), I wrapped the loss in a factory function that captures the extra parameters:

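import tensorflow as tf
from keras import backend as K

def GMMloss(kernelDim, numMixture, eps=1e-8):
    # Factory: kernelDim, numMixture and eps are captured by the closure below,
    # because Keras calls a loss with only (y_true, y_pred).
    def nll(y_true, params):
        # Network output layout: all means, then all standard deviations, then mixture weights.
        mu, sigma, pi = tf.split(params, [kernelDim * numMixture, kernelDim * numMixture, numMixture], 1)
        if kernelDim >= 2:
            raise NotImplementedError('only kernelDim == 1 is supported')
        # Density of y_true under each univariate Gaussian component
        loss = tf.contrib.distributions.Normal(loc=mu, scale=sigma).prob(y_true)
        # Weight by the mixture coefficients and sum over components
        loss = tf.multiply(loss, pi)
        loss = K.sum(loss, axis=1, keepdims=True)
        # Negative log-likelihood; eps guards against log(0)
        loss = -K.log(loss + eps)
        return K.mean(loss)
    return nll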
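To sanity-check what nll computes, here is a framework-free sketch in plain Python for the kernelDim == 1 case. It assumes params is laid out exactly as the tf.split call implies (all means, then all standard deviations, then all mixture weights) and that the network already guarantees sigma > 0 and weights summing to 1. The helper names split_params, gaussian_pdf and gmm_nll are hypothetical, not part of Keras or TensorFlow.

```python
import math

def split_params(params, kernel_dim, num_mixture):
    """Split one flat output row into (mu, sigma, pi), mirroring the tf.split call."""
    n = kernel_dim * num_mixture
    expected = 2 * n + num_mixture
    if len(params) != expected:
        raise ValueError("expected {} outputs, got {}".format(expected, len(params)))
    return params[:n], params[n:2 * n], params[2 * n:]

def gaussian_pdf(y, mu, sigma):
    """Density of a univariate normal N(mu, sigma^2) evaluated at y."""
    return math.exp(-0.5 * ((y - mu) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))

def gmm_nll(y_batch, params_batch, num_mixture, eps=1e-8):
    """Mean negative log-likelihood over a batch (kernel_dim == 1 only)."""
    losses = []
    for y, params in zip(y_batch, params_batch):
        mu, sigma, pi = split_params(params, 1, num_mixture)
        # sum_k pi_k * N(y | mu_k, sigma_k), like tf.multiply followed by K.sum(axis=1)
        likelihood = sum(p * gaussian_pdf(y, m, s) for m, s, p in zip(mu, sigma, pi))
        # eps guards against log(0), as in the Keras version
        losses.append(-math.log(likelihood + eps))
    return sum(losses) / len(losses)
```

The split sizes also fix the width of the final layer: with kernelDim = 1 and numMixture = 3, the network must output 2 * 1 * 3 + 3 = 9 values per sample.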

However, after saving the model, I was not able to load it back properly:

from keras.models import load_model
model.save('test1.h5')
model2 = load_model('test1.h5', custom_objects={"FeedForwardAttention": FeedForwardAttention,
                                               "GMMloss": GMMloss(kernelDim, numMixture)})

This fails with the error ValueError: Unknown loss function:nll. How can I pass the external parameters to the loaded model properly?

Thanks in advance,
Junyoung Park

Added stack trace:

ValueError                                Traceback (most recent call last)
<ipython-input-12-5d55702e2df6> in <module>()
      4 model2 = load_model('test1.h5', custom_objects={"FeedForwardAttention":FeedForwardAttention,
      5                                                "gaussianMixture":gaussianMixture,
----> 6                                                "GMMloss": GMMloss(kernelDim, numMixture, eps=1e-8)})
      7                                                #"nll": nll})

/home/jyp/tensorflow/local/lib/python2.7/site-packages/keras/models.pyc in load_model(filepath, custom_objects, compile)
    268                       metrics=metrics,
    269                       loss_weights=loss_weights,
--> 270                       sample_weight_mode=sample_weight_mode)
    272         # Set optimizer weights.

/home/jyp/tensorflow/local/lib/python2.7/site-packages/keras/engine/training.pyc in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
    654             loss_functions = [losses.get(l) for l in loss]
    655         else:
--> 656             loss_function = losses.get(loss)
    657             loss_functions = [loss_function for _ in range(len(self.outputs))]
    658         self.loss_functions = loss_functions

/home/jyp/tensorflow/local/lib/python2.7/site-packages/keras/losses.pyc in get(identifier)
    100     if isinstance(identifier, six.string_types):
    101         identifier = str(identifier)
--> 102         return deserialize(identifier)
    103     elif callable(identifier):
    104         return identifier

/home/jyp/tensorflow/local/lib/python2.7/site-packages/keras/losses.pyc in deserialize(name, custom_objects)
     92                                     module_objects=globals(),
     93                                     custom_objects=custom_objects,
---> 94                                     printable_module_name='loss function')

/home/jyp/tensorflow/local/lib/python2.7/site-packages/keras/utils/generic_utils.pyc in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    157             if fn is None:
    158                 raise ValueError('Unknown ' + printable_module_name +
--> 159                                  ':' + function_name)
    160         return fn
    161     else:

ValueError: Unknown loss function:nll
Junyoungpark commented 7 years ago

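I've found a hacky solution. When saving, Keras records the custom loss under the name of the function object it was given, which here is the inner function nll, not GMMloss. When loading, it looks that name up in custom_objects. Therefore, by passing "nll" as the key rather than "GMMloss", I was able to solve the problem.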
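The following plain-Python sketch illustrates that mechanism. It assumes, as the error message suggests, that Keras stores the loss by the returned function's __name__ and later resolves that string through custom_objects. lookup_loss is a hypothetical simplification of the deserialize_keras_object step in the traceback, not Keras's actual code, and the GMMloss stub keeps only the closure structure.

```python
def GMMloss(kernelDim, numMixture, eps=1e-8):
    # Stub of the factory above: only the closure structure matters here.
    def nll(y_true, params):
        raise NotImplementedError("loss body omitted in this sketch")
    return nll

def lookup_loss(saved_name, custom_objects):
    """Hypothetical simplification: the saved name must be a key of custom_objects."""
    fn = custom_objects.get(saved_name)
    if fn is None:
        raise ValueError("Unknown loss function:" + saved_name)
    return fn

loss_fn = GMMloss(1, 3)
# The name that gets saved is the inner function's name, not the factory's.
saved_name = loss_fn.__name__  # 'nll'
```

With {"GMMloss": loss_fn} the lookup raises the same ValueError as in the traceback; with {"nll": loss_fn} it succeeds.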

from keras.models import load_model
model.save('test1.h5')
model2 = load_model('test1.h5', custom_objects={"FeedForwardAttention": FeedForwardAttention,
                                               "nll": GMMloss(kernelDim, numMixture)})
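A slightly less brittle variant of the same fix derives each key from the object's own __name__, so the key cannot drift from the inner function's name. custom_objects_for is a hypothetical helper, and the stubs below stand in for the real FeedForwardAttention layer and GMMloss factory.

```python
def custom_objects_for(*objects):
    """Hypothetical helper: build a custom_objects dict keyed by each object's __name__.
    Works for classes (custom layers) and for functions such as the nll closure."""
    return {obj.__name__: obj for obj in objects}

# Stubs standing in for the real layer and loss factory.
class FeedForwardAttention(object):
    pass

def GMMloss(kernelDim, numMixture, eps=1e-8):
    def nll(y_true, params):
        raise NotImplementedError("loss body omitted in this sketch")
    return nll

custom_objects = custom_objects_for(FeedForwardAttention, GMMloss(1, 3))
```

It would be used as load_model('test1.h5', custom_objects=custom_objects_for(FeedForwardAttention, GMMloss(kernelDim, numMixture))). Since only the name appears to be saved, the external parameters are not restored from the file: the same kernelDim and numMixture must be passed again at load time.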
BehnamTaki commented 5 years ago

What is kernelDim?