ydennisy opened this issue 1 year ago
@maciejkula sorry to bother you, but I have seen that you implemented other forms of loss in lightfm. Would you be able to provide a little guidance on how one could do the same for TFRS? Also, why does TFRS come with a single loss function out of the box? Is this a design choice?
The Retrieval task accepts a loss function argument. You should be able to use your own loss function instead of the default softmax, as long as you're happy accepting a num_queries x num_candidates logit and label matrix.
For more complex changes you could implement your own Task subclass, or even dispense with Task entirely.
The in-batch softmax setup is extremely robust, and works very well in most cases. This is why we don't provide extensive examples of other losses for this task.
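For readers unfamiliar with that setup, here is a minimal numpy sketch (illustrative only, not TFRS's actual implementation) of what an in-batch softmax loss computes over that num_queries x num_candidates logit matrix:

```python
import numpy as np

def in_batch_softmax_loss(logits):
    """Sketch of an in-batch softmax loss.

    logits: (num_queries, num_candidates) score matrix where
    logits[i, i] is the score of query i's own candidate, so the
    diagonal acts as the positive labels.
    """
    # Numerically stable log-softmax over candidates for each query.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy with the diagonal (in-batch positives) as labels.
    return -np.mean(np.diag(log_probs))

scores = np.array([[5.0, 1.0, 0.0],
                   [0.5, 4.0, 0.2],
                   [0.0, 0.3, 3.0]])
print(in_batch_softmax_loss(scores))  # small: the diagonal dominates each row
```

The loss is low when each query scores its own candidate above the other in-batch candidates, and it is fully differentiable in the logits.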
As an experiment in implementing a custom loss function for a retrieval model, I implemented the following class, but I keep getting a "ValueError: No gradients provided for any variable" error when I fit the model.
import tensorflow as tf

class Top10CategoricalAccuracy(tf.keras.losses.Loss):
    def __init__(self, name='top10_categorical_accuracy'):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        y_pred = tf.keras.backend.softmax(y_pred)
        top_10_accuracy = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=10)
        loss = 1.0 - top_10_accuracy
        return loss
Creation of the retrieval task looks like:
loss_calculator = Top10CategoricalAccuracy()  # Experimental custom loss function
candidates = unique_candidate_ds.batch(128).map(lambda c: (c['item_id'], self.candidate_model(c)))
metrics = tfrs.metrics.FactorizedTopK(candidates=candidates)
batch_metrics = [tf.keras.metrics.AUC(from_logits=True, name="retrieval_auc")]
self.task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
    loss=loss_calculator,
    metrics=metrics,
    batch_metrics=batch_metrics)
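For what it's worth, the "No gradients provided" error is expected with an accuracy-style loss: top-k accuracy is a step function of the logits, so its gradient is zero (or undefined) almost everywhere and there is nothing to backpropagate. A small numpy sketch (not TFRS or Keras code) of the effect:

```python
import numpy as np

def top_k_accuracy_loss(y_true_idx, logits, k=2):
    """1 - top-k accuracy: piecewise constant in the logits.

    Small perturbations that do not change the candidate ranking
    leave it exactly unchanged, so gradient descent gets no signal.
    """
    # Indices of the k highest-scoring candidates per row.
    top_k = np.argsort(-logits, axis=1)[:, :k]
    hits = np.array([y in row for y, row in zip(y_true_idx, top_k)])
    return 1.0 - hits.mean()

labels = np.array([0, 2])
logits = np.array([[2.0, 1.0, 0.0],
                   [0.5, 0.4, 0.3]])

base = top_k_accuracy_loss(labels, logits)
nudged = top_k_accuracy_loss(labels, logits + 1e-3)  # ranking unchanged
assert base == nudged  # loss flat under small moves -> zero gradient
```

A differentiable surrogate (e.g. softmax cross-entropy, which the default Retrieval loss already uses) is the usual fix; the top-k metric can still be tracked via FactorizedTopK.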
This question is very broad and theoretical, so apologies for that!
I would love to learn more about the various loss functions which could be implemented in tfrs; when looking at other libraries such as lightfm, most support a few loss functions. I feel a tutorial on alternative loss functions would be excellent, for example the triplet loss from TensorFlow Addons: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/losses/triplet.py. Using this function naively does not work, due to a data mismatch.
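To illustrate the mismatch: a triplet loss consumes per-example (anchor, positive, negative) embeddings rather than the num_queries x num_candidates logit/label matrix that Retrieval hands its loss. A numpy sketch of a margin-based triplet loss (illustrative only, not the TensorFlow Addons implementation):

```python
import numpy as np

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Margin-based triplet loss over batches of embedding triples.

    Each argument is a (batch, dim) array; the loss pushes each anchor
    closer to its positive than to its negative by at least `margin`.
    """
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)
    return np.mean(np.maximum(pos_dist - neg_dist + margin, 0.0))

anchor = np.array([[1.0, 0.0]])
positive = np.array([[0.9, 0.1]])
negative = np.array([[-1.0, 0.0]])
print(triplet_margin_loss(anchor, positive, negative))  # 0.0: negative is well separated
```

Plugging this into Retrieval would require first mining triples from the in-batch score matrix, which is the data mismatch mentioned above.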
Thanks in advance!