keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Toggle aggregated batch losses with tf.keras.callbacks on_train_batch_end #393

Open leland-hepworth opened 1 year ago

leland-hepworth commented 1 year ago

System information

TensorFlow version (you are using): v2.10 (but applicable in any version >= 2.2)
Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state

In TensorFlow 1 and TensorFlow 2 versions <= 2.1, loss is the actual loss of the current batch, while accuracy is the average accuracy of all batches up to and including the current batch within the current epoch (the same value that is printed by the progress bar of model.fit(...)). This behavior is discussed in tensorflow issue #36400, and a gist can be found here.

In TensorFlow 2 version >= 2.2, both loss and accuracy are now averages. A gist can be found here.

The feature request is a toggle that determines whether aggregated or non-aggregated loss (and possibly other metrics) are provided to on_train_batch_end.

Will this change the current api? How?

Yes, although I'm not sure of the best way to go about it. A boolean can be added somewhere (possibly in model.fit(...) or tf.keras.callbacks.CallbackList(...)) to toggle whether aggregated or non-aggregated batch statistics are used. The default value should be to use aggregated batch statistics, because that is the current behavior in TF v2.10.

Who will benefit from this feature?

Anyone using the learning rate finder callback described in Cyclical Learning Rates for Training Neural Networks by Leslie N. Smith needs the non-aggregated loss. If average losses are used instead, early batches where the learning rate is too low will continue to impact the average loss even in later batches where learning rate has increased enough for the loss to start to decrease. This shifts the suggested range of learning rate values to the right, resulting in learning rates that are much higher than the optimal range, and leading to poorer performance than when the loss for only the current batch is used correctly. Also, if the learning rate search is trained on more than one epoch, there will be jumps in the loss at the start of each epoch when the averages are reset.
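To make the lag concrete, here is a minimal sketch in plain Python (no TensorFlow involved). The per-batch losses are made-up numbers for a hypothetical learning rate sweep, and running_mean is an illustrative helper, not a Keras function. It compares where the raw loss bottoms out with where the running average bottoms out.

```python
def running_mean(values):
    """Return the cumulative mean after each element."""
    means = []
    total = 0.0
    for count, value in enumerate(values, start=1):
        total += value
        means.append(total / count)
    return means


# Hypothetical per-batch losses while the learning rate increases:
# flat while the rate is too low, then falling, then rising again.
batch_losses = [2.0, 2.0, 2.0, 2.0, 1.0, 0.5, 0.6, 0.9, 1.5]
avg_losses = running_mean(batch_losses)

# Index of the batch where each curve reaches its minimum.
raw_min_batch = batch_losses.index(min(batch_losses))
avg_min_batch = avg_losses.index(min(avg_losses))

print("raw loss minimum at batch", raw_min_batch)
print("running-average minimum at batch", avg_min_batch)
```

With these numbers the raw loss is lowest at batch index 5, while the running average keeps falling until batch index 7, because the early flat batches are still part of the average. Since the learning rate grows with the batch index in a sweep, the later minimum corresponds to a higher suggested learning rate.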

There are several online tutorials describing how to create a learning rate finder callback:

None of these make adjustments for the average losses, possibly because they were written while using a version of TensorFlow older than v2.2.

There are also several Stack Overflow questions on the topic of using the current loss instead of the average loss:

Contributing

While I don't feel comfortable attempting to edit Keras code, I do have a workaround that could be added to any custom callback needing to calculate the current batch loss instead of using the provided average loss:

from tensorflow.keras.callbacks import Callback

class CustomCallback(Callback):
    ''' This callback converts the average loss (default behavior in TF>=2.2)
        into the loss for only the current batch.
        The result is exact only if the average weights every batch equally
        (for example, when all batches have the same size).
    '''
    def on_epoch_begin(self, epoch, logs=None):
        # the average is reset every epoch, so reset the running sum too
        self.previous_loss_sum = 0

    def on_train_batch_end(self, batch, logs=None):
        # calculate loss of current batch:
        # (number of batches so far) * (average loss) = sum of batch losses
        current_loss_sum = (batch + 1) * logs['loss']
        current_loss = current_loss_sum - self.previous_loss_sum
        self.previous_loss_sum = current_loss_sum

        # use current_loss:
        # ...

It isn't efficient, because Keras calculates the average loss and this code then undoes the averaging to get the raw loss for the current batch, but it does at least produce the desired loss.
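If the batches are not all the same size (for example, a smaller final batch), the same idea can be applied with sample counts instead of batch counts. The following is a sketch in plain Python under the assumption that the reported average is weighted by the number of samples in each batch; I have not confirmed that against the Keras source here. BatchLossRecoverer and its methods are hypothetical names, and the batch size has to be supplied by the caller.

```python
class BatchLossRecoverer:
    """Recover per-batch mean losses from a sample-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Call at the start of every epoch, when the running mean restarts.
        self.samples_seen = 0
        self.previous_loss_sum = 0.0

    def update(self, running_mean_loss, batch_size):
        # running mean * samples seen = sum of per-sample losses so far
        self.samples_seen += batch_size
        current_loss_sum = running_mean_loss * self.samples_seen
        batch_loss = (current_loss_sum - self.previous_loss_sum) / batch_size
        self.previous_loss_sum = current_loss_sum
        return batch_loss


# Made-up batch losses and sizes, with a smaller final batch.
true_batch_losses = [1.0, 0.5, 0.25]
batch_sizes = [4, 4, 2]

# Build the sample-weighted running means that would be reported.
running_means = []
loss_sum, samples = 0.0, 0
for loss, size in zip(true_batch_losses, batch_sizes):
    loss_sum += loss * size
    samples += size
    running_means.append(loss_sum / samples)

recoverer = BatchLossRecoverer()
recovered = [recoverer.update(mean, size)
             for mean, size in zip(running_means, batch_sizes)]
print(recovered)
```

Inside a callback, reset() would correspond to on_epoch_begin and update() to on_train_batch_end, with the running mean taken from logs['loss'].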

mihirparadkar commented 1 year ago

@nkovela1

rchao commented 1 year ago

Thanks for raising this issue! This seems like a nice-to-have feature where we can leverage help from the community. I'll mark this as a community-welcome item.

deepNeuralNick commented 1 year ago

I can take this if it's still available.

rchao commented 1 year ago

Thanks @deepNeuralNick! I think it would make sense for us to proceed with such support, but we should discuss the best way to expose such an API to users before we start working on it. For example, having such an argument in compile may add one argument too many to an API that's already overloaded. Do you have any suggestions?

deepNeuralNick commented 1 year ago

@rchao if you think it's not supposed to be added to the compile method, then a custom callback could also be implemented like the one already recommended. From a user's perspective, though, I think loss should be aggregated the same way that accuracy is aggregated by default, with a flag on compile that switches both of them between per-batch and aggregated values. The callback is also not an obvious way for a user to handle this type of functionality.

rchao commented 1 year ago

Agreed that a callback is not an obvious way to handle this. Let's start with having a flag on compile then, and we can discuss from there.

Thanks!

deepNeuralNick commented 1 year ago

@rchao I'll do it this week in my free time :) Can you also assign me to this issue? Thanks!

abinthomasonline commented 1 year ago

Will this affect on_epoch_end() logs? If not, the model should keep track of an additional copy of metrics, right?

leland-hepworth commented 1 year ago

on_epoch_end should probably be unaffected. The cyclic learning rate finder only uses the on_train_batch_end result.

If not the model should keep track of an additional copy of metrics right?

I don't know the details of how on_epoch_end calculates metrics, but this does give me an idea for an alternate solution. Instead of toggling whether loss in the on_train_batch_end logs represents the aggregated loss so far in that epoch (current behavior) or the non-aggregated raw loss for only the current batch, maybe both metrics could be stored in the logs. Everyone who needs the current aggregated behavior can continue to access loss as normal. But anyone needing the non-aggregated raw loss for only the current batch can access that with a different key in the logs dictionary, such as loss_current, loss_raw, loss_non_aggregated, or loss_non_agg (or maybe an ordering like current_loss would be more appropriate).

Would including a new key/value pair in the logs cause anything to break?
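As a rough illustration of the idea, using plain dictionaries only: with_current_loss is a hypothetical helper and loss_current is just one of the candidate key names above, so this says nothing about how Keras builds its logs internally or whether other consumers of the logs would be affected.

```python
def with_current_loss(logs, current_loss, key="loss_current"):
    """Return a copy of logs with the non-aggregated batch loss under `key`."""
    if key in logs:
        raise KeyError(f"{key!r} is already present in logs")
    extended = dict(logs)  # leave the original logs untouched
    extended[key] = current_loss
    return extended


# Made-up aggregated logs for one batch.
aggregated_logs = {"loss": 0.75, "accuracy": 0.5}
extended_logs = with_current_loss(aggregated_logs, current_loss=0.5)

# Existing code keeps reading the aggregated value under "loss".
print(extended_logs["loss"])
# Code that needs the raw batch loss reads the new key instead.
print(extended_logs["loss_current"])
```

Code that looks up specific keys would be unaffected in this sketch, but anything that iterates over every key in the logs would also see the new entry.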