This library aims to be an all-round toolkit for attaching, training, saving, and loading new heads for transformer models.
A new head could be:
On top of that, attaching multiple heads at once can make multi-task learning easy, which makes it possible to train very general models.
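With several heads attached, each head computes its own loss against its own target. One simple way to combine these losses into a single multi-task training objective is a weighted sum. The sketch below assumes exactly that. The helper `combine_head_losses`, the head names, and the weights are hypothetical and are not part of this library's API.

```python
from typing import Dict, Optional


def combine_head_losses(
    head_losses: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
) -> float:
    """Combine per-head losses into one scalar training objective.

    Assumption: a weighted sum, where heads without an explicit weight count 1.0.
    """
    weights = weights or {}
    return sum(weights.get(name, 1.0) * loss for name, loss in head_losses.items())


# Two heads sharing one backbone, e.g. a sentiment classifier and a regression head.
losses = {"sentiment_head": 0.5, "score_head": 0.25}
print(combine_head_losses(losses))                       # 0.75
print(combine_head_losses(losses, {"score_head": 2.0}))  # 1.0
```

Because every head contributes to one scalar, a single backward pass updates all heads, and the shared backbone too if it is not frozen.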
Install from PyPI: `pip install transformer-heads`

Or clone this repo and, from the root of the repository, run: `pip install -e .`
Create head configurations
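```python
from transformers import AutoConfig
from transformer_heads.config import HeadConfig

# Width of the hidden states the head will receive
hidden_size = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf").hidden_size

head_config = HeadConfig(
    name="imdb_head_3",
    layer_hook=-3,  # Attach at the output of the third-to-last transformer block
    in_size=hidden_size,
    output_activation="linear",
    pred_for_sequence=True,  # One prediction per sequence rather than per token
    loss_fct="cross_entropy",
    num_outputs=2,
    target="label",  # The name of the ground-truth column in the dataset
)
```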
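To make the configuration fields concrete, here is a dependency-free sketch of what a head like the one above computes for a single sequence. A linear map takes an `in_size`-dimensional hidden state to `num_outputs` logits (`output_activation="linear"` means no nonlinearity is applied). Cross-entropy then scores those logits against the integer label from the `target` column. The toy sizes, weights, and function names are made up for illustration. The real head works on torch tensors inside the library.

```python
import math
from typing import List


def linear_head(hidden: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """Map an in_size hidden vector to num_outputs logits (no activation)."""
    return [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(weight, bias)]


def cross_entropy(logits: List[float], label: int) -> float:
    """Negative log-softmax probability of the true class (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


in_size, num_outputs = 4, 2            # toy sizes; the real in_size is the model's hidden size
hidden = [0.5, -1.0, 2.0, 0.0]         # hidden state taken at the hooked layer
weight = [[0.1, 0.2, 0.3, 0.4],        # one row of weights per output class
          [-0.1, 0.0, 0.5, 0.2]]
bias = [0.0, 0.1]
assert len(hidden) == in_size and len(weight) == num_outputs

logits = linear_head(hidden, weight, bias)
loss = cross_entropy(logits, label=1)  # label comes from the dataset's "label" column
print(logits, loss)
```

Subtracting the maximum logit before exponentiating leaves the result unchanged but avoids overflow for large logits.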
Create a model with your head from a pretrained transformer model
```python
from transformers import LlamaForCausalLM
from transformer_heads import load_headed

model = load_headed(
    LlamaForCausalLM,
    "meta-llama/Llama-2-7b-hf",
    head_configs=[head_config],
)
```
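The `layer_hook=-3` setting in the head configuration uses Python-style negative indexing over the transformer blocks. The toy sketch below illustrates the idea. It assumes the head simply reads the output of the selected block in a plain stack of blocks. The `run_blocks` helper and the toy blocks are hypothetical stand-ins and do not show how the library actually registers its hooks.

```python
from typing import Callable, List

Block = Callable[[List[float]], List[float]]


def run_blocks(x: List[float], blocks: List[Block], layer_hook: int):
    """Run every block; return the final output and the output of block `layer_hook`."""
    outputs = []
    for block in blocks:
        x = block(x)
        outputs.append(x)
    # Negative indices count from the end: -1 is the last block, -3 the third-to-last.
    return x, outputs[layer_hook]


# Toy "blocks": block i adds i to every element, so the outputs are easy to trace.
blocks = [lambda v, i=i: [e + i for e in v] for i in range(6)]
final, hooked = run_blocks([0.0], blocks, layer_hook=-3)
print(final, hooked)  # [15.0] [6.0]
```

For Llama-2-7b, which has 32 decoder layers, the third-to-last block is block index 29 when counting from 0.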
Train your model using, for example, the easy-to-use Hugging Face Trainer interface:
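```python
from transformers import Trainer

# args: a transformers.TrainingArguments instance
# imdb_dataset, collator: tokenized dataset and padding collator, set up as in the linear probe notebook
trainer = Trainer(
    model,
    args=args,
    train_dataset=imdb_dataset["train"],
    data_collator=collator,
)
trainer.train()
```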
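The `collator` passed to the Trainer has to batch examples of different lengths. The stdlib sketch below shows what such a collator typically does. It assumes each example holds `input_ids` and an integer `label`. The collator right-pads the token ids, builds an attention mask, and collects one label per sequence. The function `pad_collate` and the pad id 0 are illustrative assumptions. A real collator returns torch tensors, and the linear probe notebook contains a complete working setup.

```python
from typing import Dict, List


def pad_collate(examples: List[Dict], pad_id: int = 0) -> Dict[str, List]:
    """Right-pad input_ids to the longest example and build an attention mask."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids, attention_mask = [], []
    for ex in examples:
        ids = ex["input_ids"]
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    # One label per sequence, matching pred_for_sequence=True and target="label".
    labels = [ex["label"] for ex in examples]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "label": labels}


batch = pad_collate([
    {"input_ids": [5, 6, 7], "label": 1},
    {"input_ids": [8], "label": 0},
])
print(batch)
# {'input_ids': [[5, 6, 7], [8, 0, 0]], 'attention_mask': [[1, 1, 1], [1, 0, 0]], 'label': [1, 0]}
```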
For a more in-depth introduction and a fully working example, check the linear probe notebook.
This repository contains multiple Jupyter notebooks that serve as tutorials and illustrations of how to do certain things with this library. Here is an overview of which notebook you should check out depending on the use you are interested in.
At the time of writing, only a subset of loss functions is supported out of the box. Check transformer_heads/constants.py for more up-to-date information.
However, it is not hard to add and use other loss functions or models. You just need to add their information to `loss_fct_map` and `model_type_map`, which you can import from `transformer_heads.constants`. To add a loss function, add a mapping from a string to an instance of a torch loss module. To add a model, add a mapping from the model type to a 2-tuple of (the name of the attribute that holds the base model in the model class, the base model class). That may sound confusing, but it just means the following:
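```python
from transformer_heads.constants import model_type_map, loss_fct_map
import torch.nn as nn
from transformers import MistralModel

# Make loss_fct="bce" usable in head configurations (nn.BCELoss expects probabilities in [0, 1])
loss_fct_map["bce"] = nn.BCELoss()
# Mistral task models keep their base MistralModel under the attribute "model"
model_type_map["mistral"] = ("model", MistralModel)
```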
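Conceptually, the string given as `loss_fct` in a head configuration acts as a key into `loss_fct_map`. The stdlib-only sketch below illustrates this registry pattern, with plain functions standing in for torch loss modules. `resolve_loss`, `loss_registry`, and the toy losses are hypothetical. They only illustrate why registering a new key makes a loss usable by name. They do not show how the library reports unsupported names.

```python
from typing import Callable, Dict, List

LossFn = Callable[[List[float], List[float]], float]


def mse(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred: List[float], target: List[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Registry of losses available by name, analogous to loss_fct_map.
loss_registry: Dict[str, LossFn] = {"mse": mse}


def resolve_loss(name: str) -> LossFn:
    """Look up a loss by the string used in a head configuration."""
    try:
        return loss_registry[name]
    except KeyError:
        supported = ", ".join(sorted(loss_registry))
        raise ValueError(f"Unsupported loss {name!r}; registered: {supported}") from None


# Registering a new entry makes it resolvable by name, like loss_fct_map["bce"] above.
loss_registry["mae"] = mae
print(resolve_loss("mae")([1.0, 2.0], [0.0, 4.0]))  # 1.5
```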
One of the basic assumptions of my library is that there is a transformer class, such as Hugging Face's `LlamaForCausalLM`, with an attribute that points to a base model that outputs raw hidden states. If your transformers model is built in a similar way, adding support may be as easy as adding an entry to `model_type_map` with the name of the attribute and the class of the base model. You can do that either by importing the map from constants.py or by adding the entry to constants.py directly and creating a pull request.
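The stand-in classes below illustrate this assumption. A task model wraps a base model under an attribute (`model` in the Mistral entry above), and a `model_type_map`-style entry records that attribute name together with the base model class. The classes, `toy_model_type_map`, and `get_base_model` are hypothetical and are written without torch or transformers. They only show how such a 2-tuple is enough to locate and check the base model.

```python
from typing import Dict, Tuple, Type


class ToyBaseModel:
    """Stand-in for a base model that returns raw hidden states."""

    def __call__(self, tokens):
        return [[float(t)] for t in tokens]  # one toy "hidden state" per token


class ToyForCausalLM:
    """Stand-in for a task model that keeps its base model under `.model`."""

    def __init__(self):
        self.model = ToyBaseModel()


# model type -> (attribute holding the base model, base model class)
toy_model_type_map: Dict[str, Tuple[str, Type]] = {"toy": ("model", ToyBaseModel)}


def get_base_model(wrapper, model_type: str):
    """Fetch the base model via the registered attribute name and check its class."""
    attr_name, base_cls = toy_model_type_map[model_type]
    base = getattr(wrapper, attr_name)
    if not isinstance(base, base_cls):
        raise TypeError(f"{type(wrapper).__name__}.{attr_name} is not a {base_cls.__name__}")
    return base


base = get_base_model(ToyForCausalLM(), "toy")
print(base([1, 2]))  # [[1.0], [2.0]]
```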