Hi @seanpmorgan, at the TensorFlow testing meeting of 5th Feb, I liked the idea of saving the useful parts of tf.contrib and adding them to an add-ons repo for TF 2.0, and I would like to contribute to it :). Can I contribute even though I am not a GDE, and is this a good issue to start with?
@tabshaikh Welcome, Tabish! I think, of all those functions, `triplet_semihard_loss` and `lifted_struct_loss` are the easiest to implement. I am in the process of moving `triplet_semihard_loss` (nearly finished), and I'll create a PR tomorrow if everything goes well (I'm on holiday). I think `lifted_struct_loss` is a good starter issue; would you like to give it a try?
@facaiy Sure, I would like to give `lifted_struct_loss` a try.
Hello @seanpmorgan and @facaiy, I think I can take a look at the remaining losses. Is anyone already responsible for them? If not, I'd like to give them a try. Thank you!
Welcome, @WindQAQ. I don't know of anyone working on them, so yes, just go ahead. By the way, I think #25 is easier for a new contributor; perhaps you are interested in it as well. Anyway, thanks for your help. Just ping us (or join https://gitter.im/tensorflow/sig-addons) if you need any help :-)
Okay. I'm also happy to deal with the image functions first. Thanks for your quick reply :smiley:
Hey guys, because all losses in addons must inherit from `keras.losses.LossFunctionWrapper`, I have some API design issues here: `fn`, which is passed into `LossFunctionWrapper`, should take two positional arguments; however, `contrastive_loss`, `npairs_loss`, and `npairs_loss_multilabel` as implemented in contrib take three arguments.

My initial thought is that (taking `contrastive_loss` for example) `embeddings_anchor` and `embeddings_positive` can be stacked into one tensor, which we split manually in the addons implementation. The interface would then look like:
```python
def contrastive_loss(y_true, y_pred):
    embeddings_anchor = y_pred[0]    # suppose they are stacked along axis 0
    embeddings_positive = y_pred[1]
    # and so on

class ContrastiveLoss(LossFunctionWrapper):
    def __init__(self, name):
        super(ContrastiveLoss, self).__init__(fn=contrastive_loss, name=name)
```
One drawback is that users have to build this stacked tensor on their own, which is not very intuitive. Any feedback or thoughts would be really appreciated. Thanks!
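For example, a caller would have to do something like the following under the stacking design (a hypothetical sketch; shapes are made up for illustration):

```python
import tensorflow as tf

# Hypothetical usage under the stacking design; shapes are assumed.
embeddings_anchor = tf.random.normal([32, 128])
embeddings_positive = tf.random.normal([32, 128])
y_true = tf.random.uniform([32], maxval=2, dtype=tf.int32)

# The user must pack both embedding batches into a single tensor before
# passing it as y_pred; the loss then splits it back apart internally.
y_pred = tf.stack([embeddings_anchor, embeddings_positive], axis=0)
loss = contrastive_loss(y_true, y_pred)
```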
@martinwicke @fchollet Could you give us some advice on @WindQAQ's question?
The `Loss` class fundamentally assumes a signature with 3 tensor arguments: `y_true`, `y_pred`, and `sample_weight`. We can't get around it. I think the right pattern here isn't a `Loss` subclass but a `Layer` subclass. Here's how it would work:
```python
class ContrastiveLossLayer(keras.layers.Layer):

    def __init__(self, name=None):
        super(ContrastiveLossLayer, self).__init__(name=name)

    def __call__(self, y_true, positive, negative):
        # Pack all arguments into the first argument
        # (optional, to expose a better __call__ signature).
        return super(ContrastiveLossLayer, self).__call__(
            [y_true, positive, negative])

    def call(self, inputs):
        y_true, positive, negative = inputs[0], inputs[1], inputs[2]
        loss = ...
        inference_time_predictions = ...
        self.add_loss(loss)
        return inference_time_predictions
```
Then to use it, you just insert it in your model as you would a regular layer (see the logistic loss layer example I linked).
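A minimal sketch of that wiring (shapes and names assumed, using the `ContrastiveLossLayer` pseudocode above, so not runnable as-is):

```python
import tensorflow as tf

# Inputs for the labels and the two embedding branches; shapes are assumed.
y_true = tf.keras.Input(shape=(1,), name='y_true')
positive = tf.keras.Input(shape=(128,), name='positive')
negative = tf.keras.Input(shape=(128,), name='negative')

predictions = ContrastiveLossLayer()(y_true, positive, negative)
model = tf.keras.Model([y_true, positive, negative], predictions)

# The loss registered via add_loss is collected automatically,
# so no loss is passed to compile().
model.compile(optimizer='adam', loss=None)
```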
Thanks for the suggestion! I do think this pattern is more flexible and beneficial for future extension. @seanpmorgan and @facaiy, do you think addons can adopt this kind of design pattern instead of inheriting from `LossFunctionWrapper` as a unified interface? Or we could inherit from `LossFunctionWrapper` when possible, but inherit from `Layer` in cases like `ContrastiveLoss`. Many thanks!
I'm fine with inheriting from `Layer`, but I'm not sure which module is appropriate for this kind of loss, `tfa.layer` or `tfa.loss`. @WindQAQ Could you pick one loss, say `ContrastiveLossLayer`, and write an example to show us how to use it with a Keras model?
@facaiy It should be something like a siamese network. `ContrastiveLossLayer` can return other values, such as the euclidean distance (similarity) between the two embeddings, instead of the loss. (Just a rough script here.)
```python
import numpy as np
import tensorflow as tf


@tf.function
def contrastive_loss(y_true, embeddings_anchor, embeddings_positive,
                     margin=1.0):
    distances = tf.math.sqrt(
        tf.math.reduce_sum(
            tf.math.squared_difference(
                embeddings_anchor, embeddings_positive),
            1))
    return tf.math.reduce_mean(
        tf.cast(y_true, tf.dtypes.float32) * tf.math.square(distances) +
        (1. - tf.cast(y_true, tf.dtypes.float32)) *
        tf.math.square(tf.math.maximum(margin - distances, 0.)),
        name='contrastive_loss')


class ContrastiveLossLayer(tf.keras.layers.Layer):

    def __init__(self, margin=1.0, name=None):
        super(ContrastiveLossLayer, self).__init__(name=name)
        self._margin = margin

    def __call__(self, y_true, embeddings_anchor, embeddings_positive):
        return super(ContrastiveLossLayer, self).__call__(
            [y_true, embeddings_anchor, embeddings_positive])

    def call(self, inputs):
        loss = contrastive_loss(*inputs, margin=self._margin)
        self.add_loss(loss)
        return loss


class L2Normalization(tf.keras.layers.Layer):

    def __init__(self, name=None):
        super(L2Normalization, self).__init__(name=name)

    def call(self, inputs):
        return tf.math.l2_normalize(inputs, axis=1)


def create_base_model(input_shape=(28, 28)):
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Flatten()(inputs)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    x = L2Normalization()(x)
    return tf.keras.models.Model(inputs, x)


base_model = create_base_model()

input_a = tf.keras.layers.Input(shape=(28, 28), name="input_a")
input_b = tf.keras.layers.Input(shape=(28, 28), name="input_b")
labels = tf.keras.layers.Input(shape=(1,), name="labels")

output_a = base_model(input_a)
output_b = base_model(input_b)
outputs = ContrastiveLossLayer()(labels, output_a, output_b)

model = tf.keras.models.Model([input_a, input_b, labels], outputs=outputs)
model.compile(tf.keras.optimizers.Adam(1e-3))
model.summary()

fake_data = {
    'input_a': np.random.rand(1000, 28, 28),
    'input_b': np.random.rand(1000, 28, 28),
    'labels': np.random.randint(0, 2, size=(1000, 1)),
}

model.fit(fake_data, epochs=3)
```
@fchollet Thanks for your advice! I note that the `add_loss` documentation says:

> Note that `add_loss` is not supported when executing eagerly.

Is there a plan to support eager mode in the near future?
@WindQAQ Thank you, Tzu-Wei. I'm kind of curious whether the solution (using `layer.add_loss` and `model.compile(loss=None)`) will work as we expect, hence it would be very helpful if we could make a demo. If `ContrastiveLossLayer` is not an easy task, would you mind creating a `MeanSquaredErrorLayer` based on `MeanSquaredError`?
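For concreteness, such a `MeanSquaredErrorLayer` might look roughly like this (an illustrative sketch following the `Layer` pattern above, not an agreed implementation):

```python
import tensorflow as tf

class MeanSquaredErrorLayer(tf.keras.layers.Layer):
    """Wraps tf.keras.losses.MeanSquaredError in a Layer via add_loss."""

    def __init__(self, name=None):
        super(MeanSquaredErrorLayer, self).__init__(name=name)
        self._mse = tf.keras.losses.MeanSquaredError()

    def __call__(self, y_true, y_pred):
        # Pack both tensors into the first argument, as in ContrastiveLossLayer.
        return super(MeanSquaredErrorLayer, self).__call__([y_true, y_pred])

    def call(self, inputs):
        y_true, y_pred = inputs
        self.add_loss(self._mse(y_true, y_pred))
        # Return the predictions so the output is usable at inference time.
        return y_pred
```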
@seanpmorgan Sean, at first glance at the example provided by @WindQAQ, it seems that we should put `ContrastiveLossLayer` in the `tfa.layer` module. What do you think?
@facaiy Scripts are here. I use MNIST and `sparse_categorical_crossentropy` as an example. Also, while I was implementing the scripts, I figured out some drawbacks of inheriting from `Layer`:

- The labels should be treated as one of the inputs.
- Although we can use `self.add_metric` in `Layer` to trace some useful information such as accuracy, it seems impossible for users to use `model.compile(..., metrics=['accuracy'])` to achieve the goal.

Therefore, I think inheritance from `Layer` is much more flexible for developers, but it's not a user-friendly API.
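A sketch of the `self.add_metric` workaround mentioned above (shapes and the similarity threshold are assumed for illustration):

```python
import tensorflow as tf

class ContrastiveLossLayer(tf.keras.layers.Layer):
    """Traces accuracy manually, since compile(metrics=['accuracy'])
    cannot see labels that enter the model as inputs."""

    def __init__(self, margin=1.0, name=None):
        super(ContrastiveLossLayer, self).__init__(name=name)
        self._margin = margin

    def call(self, inputs):
        y_true, distances = inputs
        y_true = tf.cast(y_true, tf.float32)
        loss = tf.reduce_mean(
            y_true * tf.square(distances) +
            (1. - y_true) * tf.square(tf.maximum(self._margin - distances, 0.)))
        self.add_loss(loss)
        # Treat pairs closer than half the margin as "similar" and trace
        # accuracy inside the layer instead of via compile(metrics=...).
        predictions = tf.cast(distances < self._margin / 2., tf.float32)
        acc = tf.keras.metrics.binary_accuracy(y_true, predictions)
        self.add_metric(acc, aggregation='mean', name='accuracy')
        return distances
```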
Nice work, Tzu-Wei, and thank you very much!

> The labels should be treated as one of the inputs.

Yes, that's the drawback I was thinking about this morning. It looks kind of counter-intuitive for users.
I have a question after skimming the paper. Can we split `ContrastiveLoss` into two parts: 1, a layer to calculate the `distance` between embeddings; 2, a loss `contrastive_loss(labels, distances)`?
> I have a question after skimming the paper. Can we split `ContrastiveLoss` into two parts: 1, a layer to calculate the `distance` between embeddings; 2, a loss `contrastive_loss(labels, distances)`?
I think this is probably the implementation in the Keras example. In this case, `ContrastiveLoss` can take exactly two signature arguments once `embeddings_anchor` and `embeddings_positive` are pre-computed into distances. Not sure if this will meet the requirements of the addons community.
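A minimal sketch of that split (shapes assumed; the distance layer name is made up):

```python
import tensorflow as tf

class EuclideanDistance(tf.keras.layers.Layer):
    """Computes the euclidean distance between two embedding batches."""

    def call(self, inputs):
        anchor, positive = inputs
        return tf.sqrt(tf.reduce_sum(
            tf.math.squared_difference(anchor, positive),
            axis=1, keepdims=True))

def contrastive_loss(y_true, distances, margin=1.0):
    # Now a standard two-argument loss: y_pred is the precomputed distance.
    y_true = tf.cast(y_true, tf.float32)
    return tf.reduce_mean(
        y_true * tf.square(distances) +
        (1. - y_true) * tf.square(tf.maximum(margin - distances, 0.)))

# Usage: distances = EuclideanDistance()([output_a, output_b])
#        model.compile(optimizer='adam', loss=contrastive_loss)
```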
Great, looks good to me. @seanpmorgan Sean, what do you think?
Similarly, in the cases of `npairs_loss` and `npairs_loss_multilabel`, I suppose that if we leave the L2 regularization terms aside, we can split them into two functions, say `compute_similarity_matrix(embeddings_anchor, embeddings_positive)` and `npairs_loss(labels, similarity_matrix)`.
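Such a split might look roughly like this (my naming; the similarity is a plain dot product, and the cross-entropy form follows the contrib implementation without the L2 term):

```python
import tensorflow as tf

def compute_similarity_matrix(embeddings_anchor, embeddings_positive):
    # Pairwise dot-product similarity: [batch, batch].
    return tf.matmul(embeddings_anchor, embeddings_positive, transpose_b=True)

def npairs_loss(labels, similarity_matrix):
    # labels: [batch] integer class ids; matching pairs share a label.
    labels = tf.reshape(labels, [-1, 1])
    targets = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32)
    targets /= tf.reduce_sum(targets, axis=1, keepdims=True)
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=targets, logits=similarity_matrix))
```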
We finally finished it. Thanks, everyone, for your support!
Per the RFC, we need to move metric_losses from contrib to addons. This will involve inheriting from the base Keras `Loss`, modifying the code to match those APIs, and modifying the test cases to run in TF2.
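As an illustration of that migration pattern (a sketch under the thread's final design, where distances are precomputed so the two-argument `Loss` signature fits; names assumed):

```python
import tensorflow as tf

class ContrastiveLoss(tf.keras.losses.Loss):
    """Sketch: subclass the base Keras Loss with call(y_true, y_pred)."""

    def __init__(self, margin=1.0, name='contrastive_loss'):
        super(ContrastiveLoss, self).__init__(name=name)
        self.margin = margin

    def call(self, y_true, y_pred):
        # y_pred holds precomputed pairwise distances, per the split above;
        # return per-sample losses and let Loss apply the reduction.
        y_true = tf.cast(y_true, tf.float32)
        return (y_true * tf.square(y_pred) +
                (1. - y_true) * tf.square(tf.maximum(self.margin - y_pred, 0.)))
```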