tensorflow / recommenders

TensorFlow Recommenders is a library for building recommender system models using TensorFlow.
Apache License 2.0

How to calculate mean loss instead of sum loss in retrieval task? #177

Open wulikai1993 opened 3 years ago

wulikai1993 commented 3 years ago

When training the retrieval model, the loss drops suddenly at the last step of each epoch in TensorBoard. I think this is because the last batch is smaller than the normal batch size, and the retrieval task's loss is a sum rather than a mean (tfrs.tasks.Retrieval uses tf.keras.losses.CategoricalCrossentropy by default). So how can I calculate the mean loss instead of the sum loss in the retrieval task? Thanks in advance!

maciejkula commented 3 years ago

You should pass your task a loss instance configured with the reduction you want.

For example:

tfrs.tasks.Retrieval(
  loss=tf.keras.losses.CategoricalCrossentropy(reduction="mean")
)
wulikai1993 commented 3 years ago

Sorry, when I use reduction="mean", the following error is raised:

ValueError                                Traceback (most recent call last)
<ipython-input-13-05774e3029ba> in <module>
      1 task = tfrs.tasks.Retrieval(
      2   metrics=metrics,
----> 3   loss=tf.keras.losses.CategoricalCrossentropy(reduction="mean")
      4 )

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __init__(self, from_logits, label_smoothing, reduction, name)
    657         reduction=reduction,
    658         from_logits=from_logits,
--> 659         label_smoothing=label_smoothing)
    660 
    661 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __init__(self, fn, reduction, name, **kwargs)
    233       **kwargs: The keyword arguments that are passed on to `fn`.
    234     """
--> 235     super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
    236     self.fn = fn
    237     self._fn_kwargs = kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __init__(self, reduction, name)
     97       name: Optional name for the op.
     98     """
---> 99     losses_utils.ReductionV2.validate(reduction)
    100     self.reduction = reduction
    101     self.name = name

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/loss_reduction.py in validate(cls, key)
     66   def validate(cls, key):
     67     if key not in cls.all():
---> 68       raise ValueError('Invalid Reduction Key %s.' % key)

ValueError: Invalid Reduction Key mean.
wulikai1993 commented 3 years ago

The reduction argument does not seem to accept a mean value: https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy#args

maciejkula commented 3 years ago

You're right, I was writing this off the top of my head. You can use the reduction values from the API page you link.

wulikai1993 commented 3 years ago

Sorry, which value computes the mean loss? They all look like sums to me.

maciejkula commented 3 years ago

It's tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE.
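A minimal sketch of the corrected loss construction, assuming TF 2.x. The string "sum_over_batch_size" is the value behind tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE and, unlike "mean", passes the reduction validation shown in the traceback above. Passing from_logits=True is an assumption: the retrieval task feeds raw query-candidate scores to the loss, so the loss should treat them as logits (as far as I know the task's default loss does the same). The toy logits are hypothetical.

```python
import tensorflow as tf

# Hypothetical in-batch scores for 3 queries x 3 candidates; the positive
# candidate for query i is candidate i, so the labels are the identity matrix.
logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 1.5, 0.3],
                      [-0.5, 0.2, 1.0]])
labels = tf.eye(3)

# "sum" adds up the per-query losses; "sum_over_batch_size" also divides by
# the number of queries in the batch, i.e. it reports the mean.
sum_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction="sum")
mean_loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction="sum_over_batch_size")

total = float(sum_loss(labels, logits))
average = float(mean_loss(labels, logits))
print(total, average)  # average equals total / 3 up to float rounding

# Wiring it into the task (requires tensorflow_recommenders; not run here):
# task = tfrs.tasks.Retrieval(loss=mean_loss, metrics=metrics)
```

Dividing by the batch size removes the dependence on how many examples a batch holds, but it does not change what each per-example loss measures.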

wulikai1993 commented 3 years ago

I tried SUM_OVER_BATCH_SIZE, but the batch loss still drops at the end of each epoch. I used basic_retrieval.ipynb and added a TensorBoard callback:

log_dir = "retrieval_log/"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq='batch')
model.fit(cached_train, epochs=100, callbacks=[tensorboard_callback])

Flipper-afk commented 3 years ago

Do your batches divide the dataset evenly? You can test this with drop_remainder=True: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch
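A small sketch of what drop_remainder changes, using a hypothetical 10-element dataset batched in fours rather than the notebook's real data. Without it, the final batch holds only the leftover elements; with it, that partial batch is discarded.

```python
import tensorflow as tf

# Ten toy examples batched in fours: 10 = 4 + 4 + 2, so the last batch is short.
ds = tf.data.Dataset.range(10)

ragged_sizes = [int(batch.shape[0]) for batch in ds.batch(4)]
even_sizes = [int(batch.shape[0]) for batch in ds.batch(4, drop_remainder=True)]

print(ragged_sizes)  # [4, 4, 2]
print(even_sizes)    # [4, 4]
```

In the notebook, the flag would go on the .batch(...) call that builds cached_train. The trade-off is that the leftover examples are skipped in that epoch; with shuffling, a different subset is dropped each time.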

wulikai1993 commented 3 years ago

The last batch is smaller. I set drop_remainder=True and the loss curve became normal. But I'm confused: the loss should be the average over the elements of a batch, so it should not depend on the batch size.

maciejkula commented 3 years ago

This might be a consequence of the in-batch softmax loss used by the Retrieval task: it tries to pick the clicked item out of all the other items in the batch. A smaller batch has fewer competing items, so the task is easier and the loss may be lower.
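A minimal pure-Python sketch of that effect. It assumes each query's positive candidate sits at the same index in the batch and that every other in-batch candidate acts as a negative; the batch sizes 128 and 16 are hypothetical. With uninformative (all-equal) scores, each query's cross-entropy is log(B), so the mean loss shrinks with the batch size even though the "model" is unchanged.

```python
import math


def in_batch_softmax_loss(scores):
    """Mean cross-entropy where row i's positive candidate is column i."""
    total = 0.0
    for i, row in enumerate(scores):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_z - row[i]  # -log softmax(row)[i]
    return total / len(scores)  # mean over queries, like SUM_OVER_BATCH_SIZE


def uninformative_scores(batch_size):
    """Scores from a model that cannot tell candidates apart (all zeros)."""
    return [[0.0] * batch_size for _ in range(batch_size)]


full = in_batch_softmax_loss(uninformative_scores(128))  # a full batch
last = in_batch_softmax_loss(uninformative_scores(16))   # a short final batch
print(round(full, 3), round(last, 3))  # 4.852 2.773, i.e. log(128) and log(16)
```

So even a correctly averaged loss is not comparable across batch sizes for this task, which is consistent with the dip disappearing once drop_remainder=True keeps every batch the same size.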