chrisociepa / allamo

Simple, hackable and fast implementation for training/finetuning medium-sized LLaMA-based models
MIT License

text classification? #4

Closed: arpitest closed this issue 1 year ago

arpitest commented 1 year ago

I've trained a small model for the Hungarian language for 5 days, and text generation is working well. Is it possible to use this model for text classification too (after fine-tuning on a 100k text+class pairs dataset somehow)? Where/how can the classification hidden layer be added/connected to this model?

chrisociepa commented 1 year ago

Take a look at this tutorial from HF. I see the following steps:

  1. fine-tune the model with pairs: text (input) and class (output token)
  2. inference (see the sketch after this list):
     with torch.no_grad():
         logits, _ = model(**inputs)
     predicted_class_id = logits.argmax().item()
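
A minimal end-to-end sketch of step 2, assuming a nanoGPT-style forward that returns (logits, loss) and that each class was fine-tuned as a single output token. encode() and class_token_ids are placeholders for your tokenizer and label set, not part of allamo:

    import torch

    def classify(model, encode, text, class_token_ids, device="cuda"):
        # Encode the prompt exactly as it was formatted during fine-tuning.
        input_ids = torch.tensor(encode(text), dtype=torch.long, device=device)[None, ...]
        model.eval()
        with torch.no_grad():
            logits, _ = model(input_ids)  # assumption: forward returns (logits, loss)
        # Prediction for the token that follows the prompt, restricted to
        # the token ids that represent classes.
        class_logits = logits[0, -1, class_token_ids]
        best = class_logits.argmax().item()
        return class_token_ids[best]  # the predicted class token id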

Another way could be using embeddings (model.embeddings(inputs) produces embeddings) and training a classification NN on top of them (instead of using model.lm_head, which produces logits). A sketch of this follows below.
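
A minimal sketch of such a classification head, assuming you pool the token embeddings into one vector per example; embedding_dim and num_classes are placeholders to adapt to your setup:

    import torch
    import torch.nn as nn

    class Classifier(nn.Module):
        def __init__(self, embedding_dim: int, num_classes: int):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(embedding_dim, 256),
                nn.ReLU(),
                nn.Linear(256, num_classes),
            )

        def forward(self, emb: torch.Tensor) -> torch.Tensor:
            return self.net(emb)

    # Usage sketch (assumption: mean-pool token embeddings per example):
    # emb = model.embeddings(inputs).mean(dim=1)   # (batch, embedding_dim)
    # logits = classifier(emb)
    # loss = nn.functional.cross_entropy(logits, labels)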

You can also convert your model to HF using export_to_hf.py and then fine-tune using their trainers and methods.
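
A hedged sketch of that route with the HF Trainer; the model path, label count, and CSV files (with text/label columns) are assumptions for illustration:

    from datasets import load_dataset
    from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                              Trainer, TrainingArguments)

    # "path/to/exported_model" is a placeholder for your export_to_hf.py output.
    tokenizer = AutoTokenizer.from_pretrained("path/to/exported_model")
    if tokenizer.pad_token is None:            # LLaMA tokenizers usually lack one
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForSequenceClassification.from_pretrained(
        "path/to/exported_model", num_labels=10)  # set num_labels to your label count
    model.config.pad_token_id = tokenizer.pad_token_id

    ds = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"})
    ds = ds.map(lambda b: tokenizer(b["text"], truncation=True,
                                    padding="max_length", max_length=512),
                batched=True)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="clf-out",
                               per_device_train_batch_size=8,
                               num_train_epochs=3),
        train_dataset=ds["train"],
        eval_dataset=ds["test"],
    )
    trainer.train()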

arpitest commented 1 year ago

thank you! I didn't notice the generate_embeddings(text) in sample.py earlier... that output can be fed into my classifier NN layers.

also, the HF export is a good idea. I already converted the model and was able to load it using transformers' from_pretrained(), but I'd like to implement it myself...