wulikai1993 opened this issue 3 years ago
You should pass a loss instance into your task that has the reduction you want. For example:

```python
tfrs.tasks.Retrieval(
    loss=tf.keras.losses.CategoricalCrossentropy(reduction="mean")
)
```
Sorry, when I use `reduction="mean"`, the following error happens:
```
ValueError                                Traceback (most recent call last)
<ipython-input-13-05774e3029ba> in <module>
      1 task = tfrs.tasks.Retrieval(
      2     metrics=metrics,
----> 3     loss=tf.keras.losses.CategoricalCrossentropy(reduction="mean")
      4 )

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __init__(self, from_logits, label_smoothing, reduction, name)
    657         reduction=reduction,
    658         from_logits=from_logits,
--> 659         label_smoothing=label_smoothing)
    660
    661

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __init__(self, fn, reduction, name, **kwargs)
    233       **kwargs: The keyword arguments that are passed on to `fn`.
    234       """
--> 235     super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
    236     self.fn = fn
    237     self._fn_kwargs = kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __init__(self, reduction, name)
     97       name: Optional name for the op.
     98       """
---> 99     losses_utils.ReductionV2.validate(reduction)
    100     self.reduction = reduction
    101     self.name = name

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/loss_reduction.py in validate(cls, key)
     66   def validate(cls, key):
     67     if key not in cls.all():
---> 68       raise ValueError('Invalid Reduction Key %s.' % key)

ValueError: Invalid Reduction Key mean.
```
The `reduction` arg doesn't seem to have a `mean` value: https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy#args
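As the traceback's `cls.all()` check suggests, you can list the valid keys yourself; a minimal sketch (my addition, not from the thread):

```python
import tensorflow as tf

# Prints the valid reduction keys the validator checks against:
# auto, none, sum, sum_over_batch_size. There is no "mean" key.
print(tf.keras.losses.Reduction.all())
```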
You're right, I was writing this off the top of my head. You can use the reduction values from the API page you link.
Sorry, which value calculates the mean loss? They all look like sum values to me.
It's `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE`.
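Put together with the earlier example, that would look roughly like this (a sketch; `metrics` is assumed to be defined as in the tutorial, and `from_logits=True` is assumed because the task scores candidates with raw logits):

```python
import tensorflow as tf
import tensorflow_recommenders as tfrs

# Mean (per-example) loss instead of a summed loss.
task = tfrs.tasks.Retrieval(
    metrics=metrics,  # assumed defined as in the basic_retrieval tutorial
    loss=tf.keras.losses.CategoricalCrossentropy(
        from_logits=True,  # the retrieval task's scores are logits
        reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
    ),
)
```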
I tried `SUM_OVER_BATCH_SIZE`, but the batch loss still dropped at the end of each epoch.
I used the basic_retrieval.ipynb and added a TensorBoard callback:

```python
log_dir = "retrieval_log/"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, update_freq='batch')

model.fit(cached_train, epochs=100, callbacks=[tensorboard_callback])
```
Are the batch dimensions matching evenly? You can test with `drop_remainder=True`: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch
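For the basic_retrieval tutorial's pipeline, that would look roughly like this (a sketch based on my recollection of the tutorial's batching code, not from the thread):

```python
# Drop the final, smaller batch so every batch has exactly the same size.
cached_train = train.shuffle(100_000).batch(8192, drop_remainder=True).cache()
```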
The last batch dimension is smaller. I set `drop_remainder=True` and the loss curve became normal.
But I'm confused, because the loss should be the average over the elements in a batch, which should have nothing to do with the batch size.
This might be a consequence of the in-batch softmax loss used for the `Retrieval` task: it tries to pick the clicked item out of all the other items in the batch. If a batch is smaller, the task is easier, and so the loss may be lower.
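To make that concrete, here is a small illustration (my own sketch, not from the thread): with in-batch softmax, a batch of size N is an N-way classification problem, so even the per-example mean loss depends on N. For a completely uninformative model the mean loss is log(N):

```python
import tensorflow as tf

loss_fn = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
)

for n in (2, 8, 4096):
    labels = tf.eye(n)          # each example's positive is itself
    scores = tf.zeros((n, n))   # uniform scores: an uninformative model
    # Mean per-example loss is log(n): ~0.69, ~2.08, ~8.32.
    print(n, float(loss_fn(labels, scores)))
```

So even with `SUM_OVER_BATCH_SIZE`, a smaller final batch poses an easier classification problem and hence a lower loss.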
When training the retrieval model, the loss drops suddenly at the end step of each epoch in TensorBoard. In my opinion, this is because the last data batch is smaller than the normal batch size, and the loss in the retrieval task is a sum loss. (`tfrs.tasks.Retrieval` uses `tf.keras.losses.CategoricalCrossentropy` as the default.) So how can I calculate the mean loss instead of the sum loss in the retrieval task? Thanks in advance!