huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0
2.23k stars, 220 forks

Discussion on refactoring #238

Closed: tomaarsen closed this 11 months ago

tomaarsen commented 1 year ago

Hello!

Why might we want to refactor?

I've opened this issue in response to several issues that SetFit is currently tackling:

  1. #179, i.e. the desire to implement a TrainingArguments à la 🤗 Transformers and PyTorch Lightning.

  2. Difficulties with different heads as experienced (among other places) in #207 and #212. To be concrete, this project relies on if isinstance(..., torch.Tensor), if isinstance(..., nn.Module), if isinstance(..., LogisticRegression), etc. way too much.
  3. With the introduction of the differentiable head, training is no longer simply trainer.train(): to use the differentiable head correctly, you must trainer.freeze() the head, then trainer.train() for the body, then trainer.unfreeze(...) the head (but maybe keep the body frozen?), and then trainer.train(...) for the head, with a whole host of different parameters.
  4. use_differentiable_head=True is too vague, not well enough documented, and not flexible enough if we want to ever have support for more than 2 heads.
  5. Passing parameters for the head is also unintuitive, as it currently needs to be done like head_params={"out_features": num_classes}.

A natural consequence of issue 2 is that implementing a new head is quite problematic: it requires editing many different files, and may introduce bugs in other contexts (e.g. when using different heads). One of the major headaches is a type mismatch between the body output and the head input, which has already caused bugs (#237, #207).

Issue 3 is quite noteworthy, as it shows that we may want to use different hyperparameters for step 1 (embedding finetuning) than for step 2 (classifier training) in the SetFit training process.

Goals

With a refactor, I hope to achieve several primary goals:

  1. Ensure a familiar, intuitive training scheme involving Trainer and TrainingArguments classes.
  2. Additionally, ensure that the model can also be trained without a TrainingArguments class. In short, a fit method of a component (body or head) must not accept a TrainingArguments instance as a parameter, but merely the hyperparameters like learning_rate, etc.
  3. Ensure that the entire SetFit model can be trained with a single trainer.train(), regardless of the head chosen.
  4. Ensure separation of concerns by requiring that each head is moved into a separate file and implements a fixed interface. This interface requires the use of torch.Tensor to avoid type issues.
  5. Introduce a head parameter to SetFitModel.from_pretrained where users can provide implementations of the aforementioned head interface. This includes our provided logistic regression and linear heads, but also custom heads that users can come up with.
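Goals 4 and 5 can be sketched in plain Python. In this sketch the model only ever calls methods on the head, so adding a head never requires another isinstance branch. Head, MajorityHead and Model are hypothetical stand-ins, not SetFit classes. To keep the example runnable without PyTorch or scikit-learn, lists of floats stand in for the torch.Tensor embeddings the proposed interface would use, and a trivial encoder stands in for the Sentence Transformer body.

```python
from abc import ABC, abstractmethod
from collections import Counter
from typing import List, Optional, Sequence


class Head(ABC):
    """Hypothetical fixed interface that every head must implement."""

    @abstractmethod
    def fit(self, embeddings: List[List[float]], labels: Sequence[int]) -> "Head": ...

    @abstractmethod
    def predict(self, embeddings: List[List[float]]) -> List[int]: ...


class MajorityHead(Head):
    """Toy head: always predicts the most frequent training label."""

    def __init__(self) -> None:
        self.label: Optional[int] = None

    def fit(self, embeddings, labels):
        self.label = Counter(labels).most_common(1)[0][0]
        return self

    def predict(self, embeddings):
        return [self.label] * len(embeddings)


class Model:
    def __init__(self, head: Optional[Head] = None) -> None:
        # Fall back to a default head when none is passed
        # (the refactor branch defaults to a logistic regression head).
        self.head = head if head is not None else MajorityHead()

    def encode(self, texts: Sequence[str]) -> List[List[float]]:
        # Stand-in for the Sentence Transformer body: a 1-d "embedding" per text.
        return [[float(len(t))] for t in texts]

    def fit(self, texts: Sequence[str], labels: Sequence[int]) -> "Model":
        # No isinstance checks: the model simply defers to the head's interface.
        self.head.fit(self.encode(texts), labels)
        return self

    def predict(self, texts: Sequence[str]) -> List[int]:
        return self.head.predict(self.encode(texts))
```

For example, `Model().fit(["a", "b", "c"], [1, 0, 1]).predict(["x", "y"])` returns `[1, 1]`. Any object implementing `Head` can be passed as `head=` without touching `Model`.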

Structure

I imagine one of two structures: a simple one, with an interface only for the SetFit head, and a more complex one, where a fixed interface is introduced for the body, too.

Simple structure
```
└───setfit
    │   data.py
    │   integrations.py
    │   logging.py
    │   losses.py
    │   pipeline.py
    │   trainer.py
    │   trainer_distillation.py
    │   training_args.py
    │   utils.py
    │   __init__.py
    │
    ├───components
    │   │   modeling.py # Contains SetFitModel
    │   │
    │   ├───body
    │   │       model.py
    │   │       training_args.py
    │   │       utils.py
    │   │
    │   └───head
    │       │   api.py
    │       │   training_args.py
    │       │
    │       ├───linear
    │       │       model.py
    │       │       training_args.py
    │       │
    │       ├───logistic
    │       │       model.py
    │       │       training_args.py
    │       │
    │       └───sklearn
    │               model.py
    │               training_args.py
    │
    └───exporters
            onnx.py
            utils.py
            __init__.py
```
Complex structure
```
└───setfit
    │   data.py
    │   integrations.py
    │   logging.py
    │   losses.py
    │   pipeline.py
    │   trainer.py
    │   trainer_distillation.py
    │   training_args.py
    │   utils.py
    │   __init__.py
    │
    ├───components
    │   │   api.py # Contains SetFitI, the superclass of SetFitHeadI and SetFitBodyI
    │   │   modeling.py # Contains SetFitModel
    │   │
    │   ├───body
    │   │   │   api.py # Contains SetFitBodyI(SetFitI)
    │   │   │
    │   │   └───sentence_transformer
    │   │           model.py # Contains SetFitSTBody(SetFitBodyI)
    │   │           training_args.py # Provides different hyperparameter defaults for SetFitSTBody
    │   │           utils.py
    │   │
    │   └───head
    │       │   api.py # Contains SetFitHeadI(SetFitI)
    │       │
    │       ├───linear
    │       │       model.py # Contains SetFitLinearHead(SetFitHeadI)
    │       │       training_args.py # Provides different hyperparameter defaults for SetFitLinearHead
    │       │
    │       ├───logistic
    │       │       model.py # Contains SetFitLogisticHead(SetFitSklearnHead)
    │       │       training_args.py # Provides different hyperparameter defaults for SetFitLogisticHead
    │       │
    │       └───sklearn
    │               model.py # Contains SetFitSklearnHead(SetFitHeadI)
    │               training_args.py # Provides different hyperparameter defaults for SetFitSklearnHead
    │
    └───exporters
            onnx.py
            utils.py
            __init__.py
```

Training Arguments

Both of these structures have training_args.py files, which contain a dataclass with various hyperparameters. Each of the more deeply nested training_args.py files contains a subclass of TrainingArguments whose defaults are overridden with values that work well for that component and training phase. The TrainingArguments class may look like so:

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields

@dataclass
class TrainingArguments(ABC):
    learning_rate: float
    ...

    @classmethod
    @abstractmethod
    def embeddings_default(cls) -> TrainingArguments:
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def classifier_default(cls) -> TrainingArguments:
        raise NotImplementedError()

    def to_dict(self):
        # filter out fields that are defined as field(init=False)
        return {field.name: getattr(self, field.name) for field in fields(self) if field.init}
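The per-component defaults described above can be made concrete with a hypothetical subclass for the body. This sketch restates the base class so the block runs on its own. The num_epochs and device fields, the BodyTrainingArguments name and all default values are illustrative assumptions, not SetFit's actual values.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields


@dataclass
class TrainingArguments(ABC):
    learning_rate: float
    num_epochs: int = 1

    @classmethod
    @abstractmethod
    def embeddings_default(cls) -> TrainingArguments: ...

    @classmethod
    @abstractmethod
    def classifier_default(cls) -> TrainingArguments: ...

    def to_dict(self) -> dict:
        # Skip fields declared with field(init=False); they are not hyperparameters for fit().
        return {f.name: getattr(self, f.name) for f in fields(self) if f.init}


@dataclass
class BodyTrainingArguments(TrainingArguments):
    # Hypothetical default overridden for a Sentence Transformer body.
    learning_rate: float = 2e-5
    # Hypothetical bookkeeping field that should not be forwarded to fit().
    device: str = field(init=False, default="cpu")

    @classmethod
    def embeddings_default(cls) -> BodyTrainingArguments:
        # Step 1: fine-tuning the embeddings.
        return cls(learning_rate=2e-5, num_epochs=1)

    @classmethod
    def classifier_default(cls) -> BodyTrainingArguments:
        # Step 2: a smaller learning rate when the body is trained together with the head.
        return cls(learning_rate=1e-5, num_epochs=16)
```

`BodyTrainingArguments.classifier_default().to_dict()` gives `{"learning_rate": 1e-05, "num_epochs": 16}`. The `device` field is filtered out because it is declared with `init=False`.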

Crucially, we may need up to three different TrainingArguments instances for the head and body, because the training arguments for the body may differ between step 1 (embedding finetuning) and step 2 (classifier training) of SetFit:

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    embedding_body_args=TrainingArguments(...),
    classifier_body_args=TrainingArguments(...),
    classifier_head_args=TrainingArguments(...)
)
trainer.train()

Note that each of these arguments is optional. If one of them is not supplied, we can fall back to the default arguments of the relevant component (body or head) for that training step (step 1: embeddings, step 2: classifier) to get the right hyperparameters. This is best described with an example of the proposed new Trainer.train method:

Trainer.train method
```python
from typing import List, Optional

def train(
    self,
    embedding_body_args: Optional[BodyTrainingArguments] = None,
    classifier_body_args: Optional[BodyTrainingArguments] = None,
    classifier_head_args: Optional[HeadTrainingArguments] = None,
):
    # Simple example excluding column_mapping, hyperparameter tuning trials, etc.
    x_train: List[str] = self.train_dataset["text"]
    y_train: List[int] = self.train_dataset["label"]

    self.train_embeddings(x_train, y_train, embedding_body_args)
    self.train_classifier(x_train, y_train, classifier_body_args, classifier_head_args)

def train_embeddings(
    self,
    x_train: List[str],
    y_train: List[int],
    embedding_body_args: Optional[BodyTrainingArguments] = None,
):
    embedding_body_args = (
        embedding_body_args
        or self.embedding_body_args
        or self.model.body.args_class.embeddings_default()
    )
    self.model.body.fit(x_train, y_train, **embedding_body_args.to_dict())

def train_classifier(
    self,
    x_train: List[str],
    y_train: List[int],
    classifier_body_args: Optional[BodyTrainingArguments] = None,
    classifier_head_args: Optional[HeadTrainingArguments] = None,
):
    classifier_body_args = (
        classifier_body_args
        or self.classifier_body_args
        or self.model.body.args_class.classifier_default()
    )
    classifier_head_args = (
        classifier_head_args
        or self.classifier_head_args
        or self.model.head.args_class.classifier_default()
    )

    # If we don't train end-to-end, then simply get the embeddings of the training input,
    # and fit the head using it
    if not classifier_head_args.end_to_end:
        embeddings = self.model.encode(x_train)
        self.model.head.fit(embeddings, y_train)
    else:
        # Otherwise, fit the model itself, which implements a PyTorch end-to-end training loop.
        # The body_learning_rate is passed from the classifier body arguments, and for the rest
        # of the parameters we use the classifier head arguments.
        self.model.fit(
            x_train,
            y_train,
            body_learning_rate=classifier_body_args.learning_rate,
            **classifier_head_args.to_dict(),
        )
```

Note here that the training arguments are provided by unpacking a dictionary representation of the hyperparameters.
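The `train_embeddings` and `train_classifier` methods use a fallback order: the argument passed to `train()` wins, then the value given to the Trainer, then the component default. The resolved arguments are then unpacked into a `fit` that only accepts plain hyperparameters. Both steps reduce to the self-contained sketch below. `HeadArgs`, `resolve_args` and `fit_head` are hypothetical names used only for illustration.

```python
from dataclasses import asdict, dataclass
from typing import List, Optional


@dataclass
class HeadArgs:
    learning_rate: float = 1e-2
    num_epochs: int = 10

    @classmethod
    def classifier_default(cls) -> "HeadArgs":
        return cls()


def resolve_args(call_args: Optional[HeadArgs], trainer_args: Optional[HeadArgs]) -> HeadArgs:
    # Same precedence as the `a or b or default()` chains in Trainer.train above.
    return call_args or trainer_args or HeadArgs.classifier_default()


def fit_head(
    embeddings: List[List[float]], labels: List[int], *, learning_rate: float, num_epochs: int
) -> dict:
    # Goal 2: the component's fit() never sees a TrainingArguments object, only the
    # individual hyperparameters. This stand-in just reports what it received.
    return {"learning_rate": learning_rate, "num_epochs": num_epochs, "n": len(embeddings)}


args = resolve_args(None, HeadArgs(learning_rate=5e-3))
result = fit_head([[0.1], [0.2]], [0, 1], **asdict(args))
```

Here `result` is `{"learning_rate": 0.005, "num_epochs": 10, "n": 2}`. Because `fit_head` takes keyword-only hyperparameters, it can also be called directly, without any arguments class.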

Head Interface

To tackle issue 2, I propose an interface which all SetFit heads must follow. A potential implementation might be:

Example head interface (or abstract class, rather)
```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator, List, Optional, Type, Union

import torch
from torch import nn

from setfit.training_args import TrainingArguments


@dataclass
class SetFitHeadI(ABC):
    @property
    @abstractmethod
    def args_class(self) -> Type[TrainingArguments]:
        pass

    @property
    @abstractmethod
    def device(self) -> Optional[Union[str, torch.device]]:
        pass

    @abstractmethod
    def fit(self, embeddings: torch.Tensor, y_train: List[int], **kwargs) -> None:
        pass

    @abstractmethod
    def to(self, device: Union[str, torch.device]) -> SetFitHeadI:
        pass

    @abstractmethod
    def train(self, mode: bool = True) -> None:
        pass

    @abstractmethod
    def parameters(self) -> Iterator[nn.Parameter]:
        pass

    @abstractmethod
    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        pass

    @abstractmethod
    def predict_proba(self, inputs: torch.Tensor) -> torch.Tensor:
        pass
```

Note that the inputs and outputs for predict and predict_proba are torch.Tensors to prevent type issues. Then, an implementation of this interface may be:

Example head implementation
```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Type, Union

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
from torch import nn

from setfit.components.head.sklearn.training_args import SklearnHeadTrainingArguments

from ..api import SetFitHeadI


@dataclass
class SetFitSklearnHead(SetFitHeadI):
    clf: LogisticRegression = None
    multi_target_strategy: str = None

    def __post_init__(self):
        if self.multi_target_strategy is not None:
            multi_target_mapping = {
                "one-vs-rest": OneVsRestClassifier,
                "multi-output": MultiOutputClassifier,
                "classifier-chain": ClassifierChain,
            }
            if self.multi_target_strategy in multi_target_mapping:
                self.clf = multi_target_mapping[self.multi_target_strategy](self.clf)
            else:
                raise ValueError(
                    f"multi_target_strategy {self.multi_target_strategy!r} is not supported. "
                    f"Choose one from {list(multi_target_mapping.keys())!r}."
                )

    def to(self, device: Union[str, torch.device]) -> SetFitSklearnHead:
        # sklearn classifiers always run on the CPU, so moving devices is a no-op
        return self

    @property
    def args_class(self) -> Type[SklearnHeadTrainingArguments]:
        return SklearnHeadTrainingArguments

    def fit(self, embeddings: torch.Tensor, y_train: List[int]) -> SetFitSklearnHead:
        embeddings = embeddings.detach().cpu().numpy()
        self.clf.fit(embeddings, y_train)
        return self

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        X = inputs.detach().cpu().numpy()
        y_pred: np.ndarray = self.clf.predict(X)
        # as_tensor keeps the integer label dtype, whereas torch.Tensor(...) casts to float
        return torch.as_tensor(y_pred)

    def predict_proba(self, inputs: torch.Tensor) -> torch.Tensor:
        X = inputs.detach().cpu().numpy()
        y_pred: np.ndarray = self.clf.predict_proba(X)
        return torch.as_tensor(y_pred)

    @property
    def device(self) -> Optional[Union[str, torch.device]]:
        return None

    def parameters(self) -> Iterator[nn.Parameter]:
        # sklearn heads have no PyTorch parameters
        return iter([])

    def train(self, mode: bool = True) -> None:
        pass


@dataclass
class SetFitLogisticHead(SetFitSklearnHead):
    clf: LogisticRegression = field(default_factory=LogisticRegression)
```

The SetFitSklearnHead class wraps a scikit-learn classifier, e.g. LogisticRegression. SetFitLogisticHead is a simple subclass of this wrapper that uses a LogisticRegression classifier by default.

Where do I stand?

As you may have guessed from the detailed examples, I have started work on this refactor. In particular, I have set up the very first steps of the implementation in my refactor branch. It's a working MVP, but it's still far from polished. Note that so far I have attempted to implement the complex structure described above.

Using this branch, the following snippet works correctly, regardless of the head that is chosen:

SetFit Demo
```python
from datasets import load_dataset
from sklearn.linear_model import LogisticRegression

from setfit import SetFitModel, Trainer, sample_dataset
from setfit.components.head.linear.model import SetFitLinearHead
from setfit.components.head.logistic.model import SetFitLogisticHead
from setfit.training_args import TrainingArguments

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"]

# Load a SetFit model from Hub using whatever head we want.
head = SetFitLogisticHead(clf=LogisticRegression(solver="liblinear"))
# head = SetFitLogisticHead()
# head = SetFitLinearHead()
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2", head=head)

# Create trainer
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # embedding_body_args=TrainingArguments(),
    # classifier_body_args=TrainingArguments(),
    # classifier_head_args=TrainingArguments(),
    column_mapping={"sentence": "text", "label": "label"},
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
print(metrics)
```

Both SetFitLogisticHead and SetFitLinearHead are implemented. Head parameters can be passed directly to the constructors of these heads. If no head is passed to SetFitModel.from_pretrained, it simply initializes a SetFitLogisticHead. Furthermore, TrainingArguments can be passed using any of the embedding_body_args, classifier_body_args and classifier_head_args parameters.

So, where do we stand now? Let's refer back to our original issues:

  1. Trainer and TrainingArguments are implemented. However, it is uncommon to have to initialize several of them.
  2. Our heads have been split into separate files.
  3. Only one trainer.train() is needed, even if we use SetFitLinearHead. It should be easy to add freeze_body and freeze_head parameters to the TrainingArguments.
  4. use_differentiable_head=True was replaced by the more flexible head parameter.
  5. Head arguments are now applied by initializing a head before calling SetFitModel.from_pretrained, and no longer need to be passed like head_params={"out_features": num_classes}.

You'll find that my primary goals are satisfied with this approach, too. However, some issues remain:

  1. Having embedding_body_args, classifier_body_args and classifier_head_args parameters may be difficult to understand without knowledge about SetFit's model structure.
  2. Having embedding_body_args, classifier_body_args and classifier_head_args parameters may be confusing in certain situations, for example, the classifier_body_args would never be used if the head is a simple logistic regression head. Perhaps we can also move towards having only embedding_args and classifier_args? This would already simplify things a lot. After all, in the current approach we only use learning_rate from the entire classifier_body_args TrainingArguments instance.
  3. Some very important parameters such as multi_target have to be specified somewhere, preferably when the model is initialized. Perhaps a solution is to add it as a parameter to SetFitModel.__init__ and SetFitModel.from_pretrained?
  4. The end-to-end training for the linear head needs to be implemented at the level of SetFitModel rather than SetFitLinearHead, because it involves both the head and the body. This may be limiting or problematic, especially if we opt for the complex structure.
  5. I have yet to add support for the trainer distillation, and important features like hyperparameter tuning, exporting, saving and loading models have also not been implemented.
  6. There are several open PRs that might be best included in a larger refactor like this, such as #187, #203, #207 & #212. I haven't absorbed these into the refactor, with the exception of #203.

What do I want from you reading this?

This discussion is partially to share my progress, partially to gather any form of feedback, and partially to discuss the remaining issues listed above. More eyes on a problem may help create better solutions. Most crucial of all is feedback on the eventual user experience, which is the main thing I intend to improve. In other words: do you like the updated approach from the last demo snippet?

I'm also considering switching from the complex structure to the simple structure, as it simplifies some things here and there. The complex structure only pays off if we ever actually implement a different body, while it introduces extra complexity from a development and maintenance perspective.

Furthermore, if you have different goals that you value more highly (e.g. having only one TrainingArguments instance or something), then please let me know, too. That way we can reach a solution that satisfies us all.

I'll slowly continue working on this from Monday onwards, and hope to fully implement it to my satisfaction.

cc: @lewtun @blakechi you may be interested, too :)

tomaarsen commented 1 year ago

Somewhat related to this, I will be looking further into the Hugging Face Trainer and TrainingArguments classes, in the hopes that perhaps a good solution would be to rely on them. I'll also want to have a look at skops for safely saving and loading scikit-learn model heads.

tomaarsen commented 1 year ago

Further experimentation has made me quite dislike having multiple TrainingArguments instances. I'll try to see if I can get away with just the one, i.e. exactly the familiar structure of

args = TrainingArguments(...)
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=args,
    ...
)

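The primary issue is that some parameters, like learning_rate, can be used in three separate places when the head is trained end-to-end: for the body during embedding finetuning (step 1), for the body during classifier training (step 2), and for the head during classifier training (step 2).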

If instead we use a logistic regression head, or any head that doesn't involve learning the model end-to-end, then we need one learning rate for fitting the head.

In other words, I want users to be able to specify two or three learning rates using the TrainingArguments instance. I can't think of a very intuitive approach for this. Perhaps something along the lines of:

from dataclasses import dataclass

@dataclass
class TrainingArguments:
    body_learning_rate_embedding: float = 2e-5
    body_learning_rate_classifier: float = 1e-5
    head_learning_rate: float = 1e-2

However, this is not super intuitive if a logistic head is used, as body_learning_rate_classifier would never be used. Another solution is:

from dataclasses import dataclass
from typing import Tuple, Union

@dataclass
class TrainingArguments:
    body_learning_rate: Union[float, Tuple[float, float]] = (2e-5, 1e-5)
    head_learning_rate: float = 1e-2

    def __post_init__(self):
        if isinstance(self.body_learning_rate, float):
            self.body_learning_rate = (self.body_learning_rate, self.body_learning_rate)

Then, body_learning_rate can be set to a single float if the learning rate should be the same in both cases (or if the body is only trained once, i.e. when using a logistic head), or to a tuple if different learning rates are desired.
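With the tuple-valued body_learning_rate, the trainer still has to map the arguments onto the places where a learning rate is used. Below is a minimal sketch, assuming a hypothetical `learning_rates()` helper that returns (body in step 1, body in step 2, head in step 2). It also normalizes an int, which the `isinstance(..., float)` check above would leave untouched.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class TrainingArguments:
    body_learning_rate: Union[float, Tuple[float, float]] = (2e-5, 1e-5)
    head_learning_rate: float = 1e-2

    def __post_init__(self) -> None:
        # A single number means "use the same learning rate for both body phases".
        if isinstance(self.body_learning_rate, (int, float)):
            self.body_learning_rate = (self.body_learning_rate, self.body_learning_rate)

    def learning_rates(self) -> Tuple[float, float, float]:
        # Hypothetical helper: (body in step 1, body in step 2, head in step 2).
        body_embedding, body_classifier = self.body_learning_rate
        return body_embedding, body_classifier, self.head_learning_rate
```

With a logistic regression head the body is only trained once, so only the first element of body_learning_rate would be read.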

Another possibility is this:

from dataclasses import dataclass
from typing import Tuple, Union

@dataclass
class TrainingArguments:
    embedding_learning_rate: float = 2e-5
    classifier_learning_rate: Union[float, Tuple[float, float]] = (1e-5, 1e-2)

    def __post_init__(self):
        if isinstance(self.classifier_learning_rate, float):
            self.classifier_learning_rate = (self.embedding_learning_rate, self.classifier_learning_rate)

Then, the classifier_learning_rate tuple represents the learning rate for the body and the learning rate for the head. If a single float is supplied, then the learning rate from the embedding_learning_rate is used for the body. Like the previous example, this is more intuitive in the case of a logistic regression head.
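The same mapping onto (body in step 1, body in step 2, head in step 2) can be sketched for this variant. The `learning_rates()` helper is again hypothetical, and ints are normalized alongside floats.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class TrainingArguments:
    embedding_learning_rate: float = 2e-5
    classifier_learning_rate: Union[float, Tuple[float, float]] = (1e-5, 1e-2)

    def __post_init__(self) -> None:
        # A single number is the head's learning rate; the body reuses the embedding rate.
        if isinstance(self.classifier_learning_rate, (int, float)):
            self.classifier_learning_rate = (
                self.embedding_learning_rate,
                self.classifier_learning_rate,
            )

    def learning_rates(self) -> Tuple[float, float, float]:
        # Hypothetical helper: (body in step 1, body in step 2, head in step 2).
        body_classifier, head = self.classifier_learning_rate
        return self.embedding_learning_rate, body_classifier, head
```

For example, `TrainingArguments(classifier_learning_rate=5e-3).learning_rates()` yields `(2e-05, 2e-05, 0.005)`.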

I'm leaning towards one of the latter two approaches.

kgourgou commented 1 year ago

Hi Tom, I have a question.

A natural consequence of issue 2 is that implementing a new head is quite problematic: it requires editing many different files, and may introduce bugs in other contexts (e.g. when using different heads).

I'm still using an older version of SetFit for my tasks, but I haven't found the implementation of a new head to be an issue. I could see how this could be a pain for end-to-end training. Is that what you mean?

tomaarsen commented 1 year ago

Hello @kgourgou,

Thank you for the question! I'm primarily referring to sections of code such as this: https://github.com/huggingface/setfit/blob/6ee9b9d9efbffb2bcc364ba521007816ee89b6cf/src/setfit/modeling.py#L221-L252 https://github.com/huggingface/setfit/blob/6ee9b9d9efbffb2bcc364ba521007816ee89b6cf/src/setfit/modeling.py#L342-L345

We either fit using PyTorch or sklearn in this example. However, what if we or a user wants to introduce a third variant that doesn't work with either of these branches? Then this if-else would need another branch, i.e. code for SetFitModel would need to be modified, which I want(ed) to avoid. My intention was to refactor such that only code in the relevant (new) head would need to be updated for correct behaviour: the idea being that the methods are simply deferred to methods on the head. However, after some experimentation I now realise that the end-to-end fitting can't be implemented on the head, as it involves both the head and the body.

That said, I'm still conceptually interested in creating an interface or abstract class that clearly defines what must be implemented by a head in order for it to work successfully in SetFit. This may also help clear up some issues around types (in particular torch Tensor vs numpy array).

A consequence of all this is that methods such as to would be implemented like so:

    self.model_body.to(device)
    self.model_head.to(device)

rather than https://github.com/huggingface/setfit/blob/6ee9b9d9efbffb2bcc364ba521007816ee89b6cf/src/setfit/modeling.py#L342-L345

which means that all heads must implement a method called to (and train, eval, freeze, unfreeze), as these are all deferred to the heads. As you can imagine, having to add a bunch of methods that are just

    def to(self, device: Union[str, torch.device]) -> SetFitI:
        return self

isn't great...

Perhaps a solution is to not add those methods to the abstract class (i.e. not require their implementation), and instead use something like inspect.getmembers to check whether the method is implemented on the head. I'm not a great fan of that either, though.
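The inspect.getmembers idea could look roughly like the sketch below: the model only forwards `to` to the head when the head actually defines it. `Body`, `SklearnLikeHead`, `TorchLikeHead`, `Model` and `implemented_methods` are hypothetical stand-ins. Devices are plain strings here so the example runs without PyTorch.

```python
import inspect
from typing import Any, Set


def implemented_methods(obj: Any) -> Set[str]:
    # Names of all bound methods available on the object, found via inspect.getmembers.
    return {name for name, _ in inspect.getmembers(obj, predicate=inspect.ismethod)}


class Body:
    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "Body":
        self.device = device
        return self


class SklearnLikeHead:
    """No device handling at all, like a wrapped LogisticRegression."""

    def fit(self, X, y):
        return self


class TorchLikeHead:
    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "TorchLikeHead":
        self.device = device
        return self


class Model:
    def __init__(self, body: Body, head: Any) -> None:
        self.model_body = body
        self.model_head = head

    def to(self, device: str) -> "Model":
        self.model_body.to(device)
        # Only defer to the head if it implements `to`; otherwise skip it silently.
        if "to" in implemented_methods(self.model_head):
            self.model_head.to(device)
        return self
```

The drawback is that the check is implicit: a head whose method is misnamed (say `too` instead of `to`) is silently skipped rather than rejected at construction time.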

I would love to hear your thoughts on an abstract class like this, as I'm open to scaling back my changes to the heads and focusing on the Trainer/TrainingArguments.

lewtun commented 1 year ago

Thanks for this very detailed and well-thought-out issue @tomaarsen 🔥!

I need some time to go through it in detail, but a few quick comments:

I'll respond with more thoughts after the vacations :)

tomaarsen commented 1 year ago

Hello Lewis,

Thank you for your comments! I'll go over them briefly:

  1. In my comment https://github.com/huggingface/setfit/issues/238#issuecomment-1359959752 I also expressed a desire to divert from my original plans and go for a single TrainingArguments class. I'm glad we agree here.
  2. I'm happy to see that we agree on the importance of having a single train() method.
  3. Changing the default model head when initializing a SetFitModel from a SentenceTransformer should be straightforward, and I can understand the reasoning. Another potential advantage of a purely Torch approach is that we may be able to adopt features from the transformers Trainer more easily. The sklearn support will keep this tricky, though. I'm hesitant to fully remove sklearn support due to the much faster fitting time of the sklearn heads. I believe that one of SetFit's strengths is its training speed, and I'm reluctant to mess with that.

tomaarsen commented 11 months ago

Will be resolved by #439.

tomaarsen commented 11 months ago

Closed via #439