pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] Classification task fine-tuning #1464

Open SalmanMohammadi opened 2 weeks ago

SalmanMohammadi commented 2 weeks ago

There has been some community appetite for classification tasks (#1249, #1124). Incidentally, because classification models are used for RLHF, we already have some of the necessary components to support classification tasks. I think we're not too far off supporting this. Concretely:

In progress

1) Land #1463, which will provide generic collation utils for classification datasets (a rough sketch of such a collator follows this list).
2) Add support for classification datasets (I've mostly completed this and will put a PR up soon).
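A minimal sketch of what such a collator might look like, assuming each sample is a dict holding token ids and a single integer class label (the keys and the collate_classification name here are placeholders; the actual utils from #1463 may differ):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_classification(batch, pad_token_id: int = 0):
    # Assumed sample format: {"tokens": List[int], "label": int}.
    tokens = [torch.tensor(sample["tokens"]) for sample in batch]
    labels = torch.tensor([sample["label"] for sample in batch])
    # Right-pad token sequences to the longest sequence in the batch.
    tokens = pad_sequence(tokens, batch_first=True, padding_value=pad_token_id)
    return {"tokens": tokens, "labels": labels}
```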

TODO - if this sounds interesting to you and you'd like to help out here, please don't hesitate to comment!

Add support for a classification loss

I don't think we need a new recipe for this task; we mainly just need to refactor the label-slicing logic in _loss_step. Instead of slicing out shifted labels, we want to grab the scores from the classification model as we do in the RLHF recipe (see the sketch below).

cc @felipemello1 for some thoughts here on compatibility with chunked CE.

Hugging Face models also provide a reference for classification loss calculation: https://github.com/huggingface/transformers/blob/38d58a4427c7c5093dc7bde45613d2bb0a5dea2c/src/transformers/models/llama/modeling_llama.py#L1409.
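As a rough sketch of the loss step under these assumptions (a classifier model returning per-position scores of shape [batch_size, seq_len, num_classes] and right-padded batches; classification_loss_step is a hypothetical name, not an existing recipe method):

```python
import torch
import torch.nn.functional as F

def classification_loss_step(model, tokens, labels, pad_token_id):
    scores = model(tokens)  # [batch_size, seq_len, num_classes]
    # Pool the score at the last non-padding token, mirroring the linked
    # Hugging Face implementation; assumes right-padding only.
    seq_lens = (tokens != pad_token_id).sum(dim=-1) - 1
    batch_idx = torch.arange(scores.size(0), device=scores.device)
    pooled = scores[batch_idx, seq_lens]  # [batch_size, num_classes]
    # No label shifting/slicing as in the causal-LM loss step.
    return F.cross_entropy(pooled, labels)
```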

This should enable fine-tuning with our already-defined classification models, our classification datasets (coming soon!), and generic collation utils.

Testing this will be the most complex step here. I haven't found many existing examples to benchmark against, so if you're reading this and know of some, please chime in! At the very least, we'll need to verify correctness by seeing sensible loss scores and eval outputs.

Optional - Generalize classification models

We currently only support Mistral and Llama classifier models, and model builders for these classifiers are only provided for binary classification (or regression) tasks. If you'd like to use another model for a classification task, we should discuss a sensible way to convert an existing model into a classification model generically, without needing to define builders for each model.
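As one possible starting point for that discussion, a minimal sketch that swaps the final vocab projection for a classification head, assuming the model exposes that projection as model.output the way our TransformerDecoder does (classifier_from_lm is a hypothetical helper name):

```python
import torch.nn as nn

def classifier_from_lm(model: nn.Module, num_classes: int) -> nn.Module:
    # Assumes the final projection lives at `model.output`; a real utility
    # would need a more robust way to locate and replace the head.
    embed_dim = model.output.in_features
    model.output = nn.Linear(embed_dim, num_classes, bias=False)
    return model
```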

Thoughts, comments, criticisms, appreciation, all welcome here.

rezadnayeri commented 2 weeks ago

One quick suggestion: one of the key items for classification is the raw logits (i.e. the unnormalized outputs of the model before a final activation function such as softmax is applied; they are the direct output of the last linear layer). The logits are passed through torch.nn.functional.softmax to convert them into probabilities: softmax_probs = torch.nn.functional.softmax(torch.tensor(logits), dim=1).numpy()

Basically, we not only need the predicted classification label, we also need the probability (model confidence) with which the label is predicted. If possible, please also output the logits, not only the class labels. Thanks!
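Expanding the snippet above into a self-contained sketch (the logits here are placeholder values; in practice they would come from the classifier head):

```python
import torch
import torch.nn.functional as F

# Placeholder raw scores from a classifier head: two examples, three classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])

probs = F.softmax(logits, dim=-1)        # per-class probabilities
confidence, labels = probs.max(dim=-1)   # model confidence and predicted class
```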

felipemello1 commented 1 week ago

I don't think we need a new recipe for this task; we mainly just need to refactor the label-slicing logic in _loss_step.

I wonder if we should move the label-slicing logic to the dataloader. That way, the classification dataloader could handle labels differently, without having to add if/else branches to the recipes.
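A rough sketch of that idea, with a hypothetical shift_labels helper that pre-shifts causal-LM labels at collation time so the recipe's loss code stays branch-free:

```python
import torch

def shift_labels(tokens: torch.Tensor, ignore_idx: int = -100) -> torch.Tensor:
    # labels[i] = tokens[i + 1], with the final position masked out, so the
    # recipe can call cross-entropy on unsliced logits. A classification
    # collator would instead pass class ids through untouched.
    labels = tokens.clone()
    labels[:, :-1] = tokens[:, 1:]
    labels[:, -1] = ignore_idx
    return labels
```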

qqlabs commented 1 week ago

@SalmanMohammadi I got a model trained in a hacky way with this PR. I'm still working on a good way to use the trained model for inference, and I'm interested in your thoughts on potential bugs and on improving the implementation.

I added comments in the recipe file for key changes/open questions that I have.