pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] Classification task fine-tuning #1464

Open SalmanMohammadi opened 2 months ago

SalmanMohammadi commented 2 months ago

There has been some community appetite for classification tasks (#1249, #1124). Because we already use classification models for RLHF, we have some of the necessary components in place, so I think we're not far from supporting this. Concretely:

In-progress

1) Land #1463, which will provide generic collation utils for classification datasets.
2) Add support for classification datasets (I've mostly completed this and will put a PR up soon).
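As an illustration of what collation for a classification dataset involves, here is a minimal sketch. It is not the API from #1463: the function name padded_collate_classification and the sample format (a dict with a "tokens" list and a single integer "label") are hypothetical. The key difference from SFT collation is that each sequence carries one class index instead of one label per token.

```python
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence


def padded_collate_classification(
    batch: List[Dict[str, object]], padding_idx: int = 0
) -> Dict[str, torch.Tensor]:
    """Hypothetical collate fn: right-pad tokens, stack one label per sample."""
    tokens = pad_sequence(
        [torch.tensor(sample["tokens"], dtype=torch.long) for sample in batch],
        batch_first=True,  # output shape [batch, max_seq_len]
        padding_value=padding_idx,
    )
    # One class index per sequence, unlike the per-token labels used for SFT.
    labels = torch.tensor([sample["label"] for sample in batch], dtype=torch.long)
    return {"tokens": tokens, "labels": labels}


batch = [{"tokens": [5, 6, 7], "label": 1}, {"tokens": [8, 9], "label": 0}]
out = padded_collate_classification(batch)
print(out["tokens"].tolist())  # [[5, 6, 7], [8, 9, 0]]
print(out["labels"].tolist())  # [1, 0]
```

The padding id matters downstream: a loss step that pools the score at each sequence's last real token has to know which positions are padding.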

TODO - if this sounds interesting to you and you'd like to help out here, please don't hesitate to comment!

Add support for a classification loss

I don't think we need a new recipe for this task; we mainly need to refactor the label-slicing logic in _loss_step. Rather than slicing the labels for next-token prediction, we want to grab the scores from the classification model, as we do in the RLHF recipe.
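A minimal sketch of what the classification path through a loss step could look like, assuming right-padded inputs and a classifier whose output has shape [batch, seq_len, num_classes]. The helper names last_token_scores and classification_loss_step are hypothetical, not torchtune APIs; the RLHF recipe performs a similar "score at the last real token" lookup, but its actual helpers may differ.

```python
import torch
import torch.nn.functional as F


def last_token_scores(
    scores: torch.Tensor, tokens: torch.Tensor, padding_idx: int
) -> torch.Tensor:
    """Pick each sequence's scores at its last non-padding position.

    scores: [batch, seq_len, num_classes]; tokens: [batch, seq_len].
    """
    # Count real tokens per row. Assumes right padding, at least one real
    # token per row, and that padding_idx never appears as a real token.
    seq_lens = (tokens != padding_idx).sum(dim=-1)
    batch_idx = torch.arange(tokens.shape[0], device=tokens.device)
    # Advanced indexing selects scores[b, seq_lens[b] - 1, :] for every b.
    return scores[batch_idx, seq_lens - 1]


def classification_loss_step(
    scores: torch.Tensor,
    tokens: torch.Tensor,
    labels: torch.Tensor,
    padding_idx: int = 0,
) -> torch.Tensor:
    # No label shifting: each sequence has exactly one target class.
    pooled = last_token_scores(scores, tokens, padding_idx)  # [batch, num_classes]
    return F.cross_entropy(pooled, labels)


torch.manual_seed(0)
tokens = torch.tensor([[5, 6, 7], [8, 9, 0]])  # second row is right-padded
scores = torch.randn(2, 3, 4)  # stand-in for classifier output with 4 classes
labels = torch.tensor([1, 3])
pooled = last_token_scores(scores, tokens, padding_idx=0)
loss = classification_loss_step(scores, tokens, labels)
print(pooled.shape, loss.item())
```

Only one position per sequence contributes to the loss here, so chunking the output over the sequence dimension (as chunked CE does for language modelling) would likely not be needed on this path.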

cc @felipemello1 for thoughts on compatibility with chunked cross-entropy (CE).

Hugging Face models also provide a reference for the classification loss calculation: https://github.com/huggingface/transformers/blob/38d58a4427c7c5093dc7bde45613d2bb0a5dea2c/src/transformers/models/llama/modeling_llama.py#L1409.
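The linked Hugging Face code infers the problem type from the number of labels and the label dtype, then picks a loss. The sketch below mirrors that dispatch as we read it; sequence_classification_loss is a hypothetical name, and it assumes the logits have already been pooled to one row per sequence.

```python
import torch
import torch.nn.functional as F


def sequence_classification_loss(
    pooled_logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
    """pooled_logits: [batch, num_labels]; the labels format depends on the task."""
    num_labels = pooled_logits.shape[-1]
    if num_labels == 1:
        # Regression (or a single reward-style score): mean squared error.
        return F.mse_loss(pooled_logits.squeeze(-1), labels.to(pooled_logits.dtype))
    if labels.dtype in (torch.long, torch.int):
        # Single-label classification: one integer class index per sequence.
        return F.cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1))
    # Multi-label classification: an independent 0/1 float target per class.
    return F.binary_cross_entropy_with_logits(
        pooled_logits, labels.to(pooled_logits.dtype)
    )


reg = sequence_classification_loss(torch.tensor([[0.5], [1.5]]), torch.tensor([1.0, 1.0]))
single = sequence_classification_loss(torch.zeros(2, 4), torch.tensor([0, 3]))
multi = sequence_classification_loss(
    torch.zeros(2, 3), torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
)
print(reg.item(), single.item(), multi.item())  # 0.25, ln(4) ~ 1.386, ln(2) ~ 0.693
```

With all-zero logits, cross-entropy reduces to ln(num_labels) and binary cross-entropy to ln(2), which makes these cases easy to check by hand.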

This should enable fine-tuning with our already defined classification models, our classification datasets (coming soon!) and generic collation utils.

Testing this will be the most complex step here. I haven't found many existing examples to benchmark against, so if you're reading this and know of some, please chime in! At the very least, we'll need to verify correctness by checking for sensible loss values and eval outputs.
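Two cheap, generic sanity checks for "sensible loss values", sketched here with a toy linear head on random features rather than a real model (both the features and the head are stand-ins). First, a uniform prediction over C classes has cross-entropy exactly ln(C), so a freshly initialized head should start roughly there. Second, a model should be able to drive the loss down on a single fixed small batch.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch_size, hidden_dim = 4, 8, 16
labels = torch.randint(0, num_classes, (batch_size,))

# Check 1: all-zero logits give a uniform softmax, so the loss is exactly ln(C).
uniform_loss = F.cross_entropy(torch.zeros(batch_size, num_classes), labels)
print(f"{uniform_loss.item():.4f} vs ln(C) = {math.log(num_classes):.4f}")

# Check 2: train a head on one fixed batch; its loss should fall well below ln(C).
features = torch.randn(batch_size, hidden_dim)  # stand-in for pooled hidden states
head = torch.nn.Linear(hidden_dim, num_classes)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
for step in range(200):
    optimizer.zero_grad()
    loss = F.cross_entropy(head(features), labels)
    loss.backward()
    optimizer.step()
print(f"loss after overfitting one batch: {loss.item():.4f}")
```

Neither check proves the recipe is correct. However, an initial loss far from ln(C), or a loss that will not drop on a single batch, usually points to a labelling or pooling bug.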

Optional - Generalize classification models

We currently only support Mistral and Llama classifier models, and their model builders are only provided for binary classification (or regression) tasks. If you'd like to use another model for a classification task, we should discuss a sensible way to add generic support for converting an existing model into a classification model, without needing to define builders for each model.
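One possible shape for generic conversion, sketched under a loud assumption: the model stores its final hidden-to-vocab projection as a plain nn.Linear attribute (called output here). The function to_classifier and the toy TinyDecoder are hypothetical. Models with tied embeddings or a custom output module would need extra handling.

```python
import torch
from torch import nn


def to_classifier(model: nn.Module, num_classes: int, head_attr: str = "output") -> nn.Module:
    """Hypothetical helper: swap the vocab projection for a num_classes-way head."""
    old_head = getattr(model, head_attr)
    if not isinstance(old_head, nn.Linear):
        raise TypeError(
            f"expected nn.Linear at model.{head_attr}, got {type(old_head).__name__}"
        )
    # Match the old head's input size, bias setting, device and dtype.
    new_head = nn.Linear(
        old_head.in_features,
        num_classes,
        bias=old_head.bias is not None,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )
    setattr(model, head_attr, new_head)
    return model


class TinyDecoder(nn.Module):
    """Toy stand-in for a decoder: embeddings followed by a vocab projection."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))


model = to_classifier(TinyDecoder(), num_classes=3)
scores = model(torch.tensor([[1, 2, 3]]))
print(scores.shape)  # torch.Size([1, 3, 3])
```

With this approach, num_classes becomes an ordinary argument rather than something baked into a per-model builder.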

Thoughts, comments, criticisms, appreciation, all welcome here.

rezadnayeri commented 2 months ago

One quick suggestion: one of the key outputs for classification is the raw logits (i.e., the unnormalized outputs of the model before a final activation function such as softmax is applied; they are the direct output of the last linear layer). The logits are passed through torch.nn.functional.softmax to convert them into probabilities: softmax_probs = torch.nn.functional.softmax(torch.tensor(logits), dim=1).numpy()

Basically, we need not only the predicted class label but also the probability (model confidence) of that prediction. If possible, please output the logits as well, not just the class labels. Thanks!
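A small sketch of returning logits, probabilities, predicted labels and confidences together, assuming pooled logits of shape [batch, num_classes]. The function name predict_with_confidence is hypothetical. Normalizing over the last dimension (dim=-1) is equivalent to dim=1 for 2-D logits.

```python
import torch
import torch.nn.functional as F


def predict_with_confidence(logits: torch.Tensor) -> dict:
    """logits: [batch, num_classes] raw (unnormalized) classifier outputs."""
    probs = F.softmax(logits, dim=-1)  # normalize over the class dimension
    confidence, predicted = probs.max(dim=-1)  # top probability and its class index
    return {"logits": logits, "probs": probs, "predicted": predicted, "confidence": confidence}


logits = torch.tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 3.0]])
out = predict_with_confidence(logits)
print(out["predicted"].tolist())  # [0, 2]
print(out["confidence"])
```

Keeping the raw logits in the output lets users apply their own calibration or thresholds later without rerunning the model.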

felipemello1 commented 2 months ago

I don't think we need a new recipe for this task, we mainly just need to refactor the label-slicing logic in _loss_step.

I wonder if we should move the label-slicing logic to the dataloader. That way, the classification dataloader could handle it differently, without having to add if/else branches to the recipes.
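A sketch of what moving the shift into the data pipeline could look like for the language-modelling case. The function shift_labels_for_lm is hypothetical, and -100 is used because it is the default ignore_index of torch.nn.functional.cross_entropy. A classification collate would skip this step and emit one label per sequence.

```python
from typing import Dict

import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def shift_labels_for_lm(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Hypothetical collate post-processing step for next-token prediction."""
    labels = batch["labels"]
    pad = torch.full(
        (labels.shape[0], 1), IGNORE_INDEX, dtype=labels.dtype, device=labels.device
    )
    # Position t now targets token t + 1; the final position has no target.
    return {**batch, "labels": torch.cat([labels[:, 1:], pad], dim=1)}


lm_batch = {"tokens": torch.tensor([[1, 2, 3]]), "labels": torch.tensor([[1, 2, 3]])}
shifted = shift_labels_for_lm(lm_batch)
print(shifted["labels"].tolist())  # [[2, 3, -100]]
```

With labels already aligned to logit positions, the recipe would no longer slice labels itself.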

qqlabs commented 2 months ago

@SalmanMohammadi I got a model trained in a hacky way with this PR. I'm still working on a good way to use the trained model for inference. I'd be interested in your thoughts on potential bugs and on improving the implementation.

I added comments in the recipe file for key changes/open questions that I have.

rezadnayeri commented 2 months ago

@SalmanMohammadi and the team: any updates on this RFC? I wish I could help, but I have limited knowledge to contribute. Looking forward to your update. Thanks!

funtion commented 1 month ago

Do you plan to support multi-class classification tasks? The logic should be very similar to binary classification. It would be great if the num_classes parameter were configurable in the YAML config.

qqlabs commented 1 month ago

My PR above is a proof of concept for multiclass classification, though it currently just overfits during training. I may need to experiment with the final output layer (e.g., instead of overriding the last layer, add another layer that maps to num_classes). I'm working with just prompt completion for now and may revisit this after I finish my current iterations.
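The alternative mentioned above (keep the existing last layer and add another layer that maps to num_classes) could be sketched as follows. StackedClassificationHead is a hypothetical name, and the base projection here is a toy nn.Linear standing in for a pretrained hidden-to-vocab layer. Nothing in the thread shows whether this reduces overfitting.

```python
import torch
from torch import nn


class StackedClassificationHead(nn.Module):
    """Keep the existing projection and map its vocab-sized output to num_classes."""

    def __init__(self, base_head: nn.Linear, num_classes: int):
        super().__init__()
        self.base_head = base_head  # existing hidden -> vocab projection, left in place
        self.classifier = nn.Linear(base_head.out_features, num_classes)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.base_head(hidden))


base = nn.Linear(8, 32, bias=False)  # toy stand-in for a pretrained vocab projection
head = StackedClassificationHead(base, num_classes=5)
out = head(torch.randn(2, 3, 8))  # [batch, seq_len, hidden_dim]
print(out.shape)  # torch.Size([2, 3, 5])
```

The trade-off compared with replacing the layer is an extra vocab_size x num_classes weight matrix. In exchange, the pretrained projection is preserved and could optionally be frozen.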

rezadnayeri commented 1 month ago

Hello, I was wondering if this RFC is still being worked on. Thanks!

SalmanMohammadi commented 1 month ago

Hi @rezadnayeri. Sorry for the late update on this! Unfortunately, I haven't had the bandwidth to work on this in a while. Once I have another chance to look at it, I can provide an update. I wish I could give guarantees on when that would be.

rezadnayeri commented 1 month ago

Hi @SalmanMohammadi, understood. Thank you for your selfless service to all of us; I appreciate it.