huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Enhance Generality of Trainer Class #29972

Open · xjw-nlp opened 6 months ago

xjw-nlp commented 6 months ago

Feature request

In the current version of the transformers library, the Trainer class can only report a single total loss during the training and evaluation stages. However, in practical applications we often want to observe how the different loss terms in multi-task learning evolve. For example, we might add a contrastive loss on top of the MLE loss. I believe that adding the ability to record the different losses separately would enhance the generality of the code.

Motivation

Multi-task learning is a common technique for improving model performance and generalization. To assess the effectiveness of each objective function at a finer granularity, we need to record the individual losses in addition to the total loss.

Your contribution

Yes, I can submit a PR.

ArthurZucker commented 6 months ago

Thanks for the request. I think callbacks can help with part of this, and you can always override the compute_loss function, no?
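For concreteness, here is a minimal sketch of that suggestion; the contrastive_loss helper and the 0.1 weight are placeholders for illustration, not anything from the library:

```python
import torch
from transformers import Trainer

def contrastive_loss(outputs, inputs):
    # hypothetical placeholder -- substitute a real contrastive objective
    return torch.tensor(0.0, device=outputs.loss.device)

class MultiLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        mle_loss = outputs.loss                       # loss computed by the model itself
        ctr_loss = contrastive_loss(outputs, inputs)  # extra objective
        # Trainer.log broadcasts a dict of floats to every attached logger,
        # but it fires on every call rather than on the logging_steps schedule
        self.log({"mle_loss": mle_loss.item(),
                  "contrastive_loss": ctr_loss.item()})
        total_loss = mle_loss + 0.1 * ctr_loss        # the 0.1 weight is arbitrary
        return (total_loss, outputs) if return_outputs else total_loss
```

The catch, which the rest of this thread digs into, is that self.log fires immediately on every step instead of following the logging_steps schedule.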

ArthurZucker commented 6 months ago

an example: https://github.com/huggingface/transformers/blob/c1bd53e79833d2fea1ccfdcc96e8e833a1c02ec0/docs/source/en/tasks/knowledge_distillation_for_image_classification.md#L56-L94
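Condensed, the pattern in that doc looks roughly like the following (a sketch; model setup and device placement are omitted, and the exact code at the link may differ):

```python
import torch
import torch.nn.functional as F
from transformers import Trainer

class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None,
                 temperature=2.0, lambda_param=0.5, *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = student(**inputs)
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)
        # soften both distributions with the temperature
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)
        distillation_loss = F.kl_div(soft_student, soft_teacher,
                                     reduction="batchmean") * self.temperature ** 2
        student_target_loss = student_output.loss  # plain supervised loss on labels
        # only this weighted sum is ever reported by the Trainer
        loss = ((1.0 - self.lambda_param) * student_target_loss
                + self.lambda_param * distillation_loss)
        return (loss, student_output) if return_outputs else loss
```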

xjw-nlp commented 6 months ago

> an example: https://github.com/huggingface/transformers/blob/c1bd53e79833d2fea1ccfdcc96e8e833a1c02ec0/docs/source/en/tasks/knowledge_distillation_for_image_classification.md#L56-L94

This example actually highlights the current constraints of the Trainer class. Specifically, there is no elegant way to separately log 'student_target_loss' and 'distillation_loss', or to print them, along with the other training information, to the terminal. Currently, during the training phase, the logging of training information, including the loss, happens inside Trainer._maybe_log_save_evaluate.

ArthurZucker commented 6 months ago

cc @muellerzr who knows how to handle this better than I do

muellerzr commented 5 months ago

Re:

> The question I would like to raise at the moment is whether it's possible to use different keywords in the logs variable to record different losses, rather than relying on a single loss keyword to document the total loss.

In its current state it's not possible, as we create the logs dict internally (in _maybe_log_save_evaluate).

The best way would probably be to include an additional object of extra values to be logged that a callback has access to (maybe an AdvancedLoggingCallback).

_maybe_log_save_evaluate gets called after on_step_end, so your custom callback should store these values during on_step_end.
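Something along those lines (a rough sketch; the shared loss_store, and the assumption that a custom compute_loss appends values into it, are for illustration only and not a transformers API):

```python
from collections import defaultdict
from transformers import TrainerCallback

loss_store = defaultdict(list)  # a custom compute_loss appends values here each step

class SeparateLossCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        # on_log fires from _maybe_log_save_evaluate, i.e. after
        # on_step_end, so values stored during the step are visible here
        for name, values in loss_store.items():
            if values:
                print(f"step {state.global_step} | {name}: "
                      f"{sum(values) / len(values):.4f}")
        loss_store.clear()
```

You would register it with Trainer(..., callbacks=[SeparateLossCallback()]); the limitation is that these values stay outside the logs dict, so the report_to integrations never see them.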

I'll make a quick mock-up in a second and open a PR, and we can iterate :)

muellerzr commented 5 months ago

Actually, off the top of my head there's no "easy" way to modify this behavior. Essentially we would need a decent mechanism for users to register their own custom values to be logged, save those in the callback, and finally surface them as part of logs.

The other alternative is to make logs a class attribute, but that does not seem like a good solution imo
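In the meantime, one workaround that should work today is to override Trainer.log so the extra values are merged just before the logs dict fans out to the callbacks. A sketch, assuming a custom compute_loss stashes component losses on the trainer (the _extra_losses attribute is hypothetical):

```python
from transformers import Trainer

class MultiLossTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._extra_losses = {}  # filled in a custom compute_loss (not shown)

    def log(self, logs):
        # every Trainer.log call funnels through here, so the extra keys
        # reach the terminal and the report_to integrations alongside "loss"
        logs.update({k: float(v) for k, v in self._extra_losses.items()})
        super().log(logs)
```

Because the merge happens before the callbacks run, the extra keys appear in state.log_history and in the tracking backends, at the cost of also being attached to evaluation logs unless filtered.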