tensorflow / swift-models

Models and examples built with Swift for TensorFlow
Apache License 2.0

Complex Loss Functions Inside TrainingLoop #720

Open xanderdunn opened 3 years ago

xanderdunn commented 3 years ago

Thanks to @xihui-wu's talk earlier today, I learned about the TrainingLoop struct. I had essentially replicated this functionality, in a messier way, in my own code, so I'm looking into whether I could replace my training loop with this cleaner implementation. I believe the only obstacle I might face is the loss function.

The current loss function takes only the model's output and the target as parameters: public typealias F = @differentiable(Output, @noDerivative Target) -> Tensor<Float> (defined here). This covers the vast majority of supervised training situations, but some situations call for more complicated loss functions. For example, how might we mask the output and the target for each sample when calculating the loss, as is done in this paper:

at each time step the model tries to predict the full, uncorrupted input vectors x_t; however, only the predictions on the masked values are considered in the Mean Squared Error loss.
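A minimal sketch of the masked MSE described in the quote, assuming plain Swift arrays in place of Tensor<Float> so it runs without Swift for TensorFlow. maskedMeanSquaredError is a hypothetical helper, not part of swift-models. With real tensors the same computation is usually written as an elementwise product with a 0/1 mask followed by a sum; the loop below spells that out. Note that the loss needs three inputs (prediction, target, mask), which is exactly what the two-argument F signature does not accommodate.

```swift
// Hypothetical helper, not part of swift-models. Plain [Float] arrays stand in
// for Tensor<Float>; `mask` holds 1 for positions that count and 0 otherwise.
func maskedMeanSquaredError(predicted: [Float], expected: [Float], mask: [Float]) -> Float {
    precondition(predicted.count == expected.count && expected.count == mask.count,
                 "predicted, expected and mask must have the same length")
    // Accumulate mask * (predicted - expected)^2 over all positions.
    var maskedSquaredError: Float = 0
    for i in predicted.indices {
        let diff = predicted[i] - expected[i]
        maskedSquaredError += mask[i] * diff * diff
    }
    // Normalize by the number of masked positions, not the full length.
    let maskedCount = mask.reduce(0, +)
    return maskedCount == 0 ? 0 : maskedSquaredError / maskedCount
}

// Positions 0 and 2 have large errors but are unmasked, so they do not
// contribute; positions 1 and 3 give (2^2 + 4^2) / 2.
let loss = maskedMeanSquaredError(
    predicted: [1.0, 2.0, 3.0, 4.0],
    expected:  [9.0, 0.0, 9.0, 0.0],
    mask:      [0.0, 1.0, 0.0, 1.0])
print(loss)  // 10.0
```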

Another situation that comes to mind is a loss that requires some third, external set of values. For example, an RL agent whose current loss is a function of its recent past losses, or a risk-adjusted metric where "risk" depends on some external value that changes over time.
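One way to fit a loss with a third, external input into a two-argument signature is to capture that input in a closure. The sketch below assumes plain Swift arrays instead of Tensor<Float>; RecentLossHistory, baselineAdjustedLoss and makeTwoArgumentLoss are hypothetical names, and subtracting the mean of recent losses is only an illustrative stand-in for an RL- or risk-style adjustment. Because the history is a class instance, updates made between steps are visible to the captured loss. Whether such a capturing closure converts to TrainingLoop's @differentiable F type is an assumption this sketch does not verify; captured state would presumably be treated as a constant rather than differentiated.

```swift
// Hypothetical external state (not a swift-models API): a sliding window of
// recent raw losses, standing in for "some external value that changes".
final class RecentLossHistory {
    private(set) var values: [Float] = []
    let capacity: Int

    init(capacity: Int) {
        precondition(capacity > 0, "capacity must be positive")
        self.capacity = capacity
    }

    // Append a loss and drop the oldest one once the window is full.
    func record(_ loss: Float) {
        values.append(loss)
        if values.count > capacity { values.removeFirst() }
    }

    // Mean of the window, or 0 when nothing has been recorded yet.
    var mean: Float {
        values.isEmpty ? 0 : values.reduce(0, +) / Float(values.count)
    }
}

// Plain mean squared error over arrays (stand-in for Tensor<Float>).
func meanSquaredError(_ output: [Float], _ target: [Float]) -> Float {
    precondition(output.count == target.count && !output.isEmpty)
    var sum: Float = 0
    for (o, t) in zip(output, target) { sum += (o - t) * (o - t) }
    return sum / Float(output.count)
}

// Hypothetical loss with a third input: raw error minus the recent-loss baseline.
func baselineAdjustedLoss(_ output: [Float], _ target: [Float],
                          history: RecentLossHistory) -> Float {
    meanSquaredError(output, target) - history.mean
}

// Adapter: capture the external state so the result has the
// two-argument (Output, Target) -> Float shape a training loop expects.
func makeTwoArgumentLoss(history: RecentLossHistory) -> ([Float], [Float]) -> Float {
    return { output, target in baselineAdjustedLoss(output, target, history: history) }
}

let history = RecentLossHistory(capacity: 3)
history.record(4.0)
history.record(2.0)
let loss = makeTwoArgumentLoss(history: history)
print(loss([1, 2], [0, 0]))  // -0.5
```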

Am I reading the code correctly that these kinds of loss functions are not currently supported? If so, could the protocol reasonably be modified to optionally support such complex loss functions?

Many thanks!

xihui-wu commented 3 years ago

@xanderdunn thanks for your post! TrainingLoop is still being iterated on so that it covers more and more use cases.

To answer your first question, about the supervised learning scenario: generally speaking, if you want to customize the loss function by manipulating the output and label, you can write your loss function the way BERT-CoLA does here. For your masking case, you might be able to do something similar, or do something inside the model, as BERT attention does here?

Second, yes: when the loss function depends on a value that depends on the training state, some work is needed to support it. RL in particular is a different world; we are considering providing separate TrainingLoop variations for it.

cc @BradLarson

xanderdunn commented 3 years ago

Thanks @xihui-wu. Writing custom loss functions is straightforward; the challenging part is making the training loop compatible with loss functions that take more than just logits and labels as parameters. In the linked example the labels have type Tensor<Int32>. Could this instead be a struct that contains multiple tensors? Such a struct could contain both the labels and the masks, for example.
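A minimal sketch of the struct idea, assuming plain Swift arrays instead of Tensor so it runs without Swift for TensorFlow. MaskedTarget, maskedLoss and MiniLoop are hypothetical names, not swift-models APIs. MiniLoop is a stand-in that, like the F typealias, is generic over Target and only forwards it to the loss, so the mask can travel inside the target without changing the two-argument signature. Whether TrainingLoop's other requirements on Target (for example, batching or device placement) would accept such a struct is not checked here.

```swift
// Hypothetical target type (not a swift-models API): labels plus a per-element
// mask. Plain [Float] arrays stand in for Tensor<Float>.
struct MaskedTarget {
    var labels: [Float]
    var mask: [Bool]
}

// Two-argument masked MSE: the mask travels inside the target, so the
// signature keeps the (Output, Target) -> Float shape.
func maskedLoss(_ output: [Float], _ target: MaskedTarget) -> Float {
    precondition(output.count == target.labels.count && output.count == target.mask.count,
                 "output, labels and mask must have the same length")
    var sum: Float = 0
    var count = 0
    for i in output.indices where target.mask[i] {
        let diff = output[i] - target.labels[i]
        sum += diff * diff
        count += 1
    }
    return count == 0 ? 0 : sum / Float(count)
}

// Hypothetical stand-in for a training loop that is generic over Target.
// It never inspects the target; it only hands it to the loss function.
struct MiniLoop<Output, Target> {
    typealias LossFunction = (Output, Target) -> Float
    let lossFunction: LossFunction

    func losses(for batches: [(output: Output, target: Target)]) -> [Float] {
        batches.map { lossFunction($0.output, $0.target) }
    }
}

// Only the first two positions are masked in, so the unmasked error at
// index 2 is ignored: (0^2 + 2^2) / 2.
let loop = MiniLoop<[Float], MaskedTarget>(lossFunction: maskedLoss)
let batchLosses = loop.losses(for: [
    (output: [1, 2, 3], target: MaskedTarget(labels: [1, 0, 0], mask: [true, true, false]))
])
print(batchLosses)  // [2.0]
```

The design choice here is that the loop stays ignorant of the target's structure; all knowledge of masks lives in MaskedTarget and maskedLoss.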

Your BERT attention link points to the same place as the BERT-CoLA example, but I think you're referring to this attention mask, which is applied here to the attention scores? That masks the inputs into the model, but it doesn't mask the outputs and the targets when calculating the loss, as is done in the Transformer representation learning paper I linked above.

xihui-wu commented 3 years ago

Thanks for correcting the link. You are welcome to try whether making it a struct, together with some changes to TrainingLoop, works. Again, our current TrainingLoop implementation isn't in its final form, and more flexibility may be needed for RL and other applications! At some point we might look at composable training loop pieces so that loops can be more flexible, but that may be some way off. We're open to expanding the current training loop to fit other use cases.