UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0
14.89k stars 2.44k forks

Using SentenceTransformerTrainer in case of Asym module usage. #2742

Open tonysternenko opened 3 months ago

tonysternenko commented 3 months ago

Hi there. I'm wondering: is it possible to use the SentenceTransformerTrainer class to train a model that includes an Asym module in its structure?

I'm asking because, according to the documentation, the SentenceTransformerTrainer class doesn't accept List[InputExample] instances, which are required for proper asymmetric model training.
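For illustration, the Trainer-style data I'd expect to use would be rows in a Hugging Face Dataset rather than InputExample objects. A minimal sketch with plain dicts (the column names "query", "doc", and "score" are hypothetical examples, chosen to match the keys of an Asym module):

```python
# Hypothetical example rows: with the Trainer API, each training example
# is a row in a datasets.Dataset rather than an InputExample. Plain dicts
# are shown here purely for illustration.
rows = [
    {"query": "how do I train an asym model?",
     "doc": "Use a custom data collator.",
     "score": 1.0},
    {"query": "what is sbert?",
     "doc": "A framework for sentence embeddings.",
     "score": 0.5},
]

# A row's keys are the columns that a data collator would see:
print(sorted(rows[0].keys()))  # → ['doc', 'query', 'score']
```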

BohdanBilonoh commented 3 months ago

Hello! It can be done with a custom data_collator for SentenceTransformerTrainer:

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

import torch


@dataclass
class AsymDataCollator:

    tokenize_fn: Callable
    valid_label_columns: List[str] = field(default_factory=lambda: ["label", "score"])

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        columns = list(features[0].keys())

        # We should always be able to return a loss, label or not:
        batch = {}

        if "dataset_name" in columns:
            columns.remove("dataset_name")
            batch["dataset_name"] = features[0]["dataset_name"]

        # Extract the label column if it exists
        for label_column in self.valid_label_columns:
            if label_column in columns:
                batch["label"] = torch.tensor([row[label_column] for row in features])
                columns.remove(label_column)
                break

        # Tokenize each remaining column on its own, wrapped in a dict keyed
        # by its column name, so the Asym module can route each input to the
        # right submodule. Iterating over the pruned `columns` list keeps
        # dataset_name and label columns out of tokenization:
        for i, key in enumerate(columns):
            tokenized = self.tokenize_fn([{key: row[key]} for row in features])
            for token_key, value in tokenized.items():
                batch[f"sentence{i}_{token_key}"] = value

        return batch
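To sanity-check the grouping step in isolation, here is a standalone sketch of just the per-column loop, with a stub standing in for model.tokenize (the stub and its "texts" key are hypothetical, for demonstration only):

```python
from typing import Any, Dict, List


def stub_tokenize(batch: List[Dict[str, str]]) -> Dict[str, Any]:
    # Hypothetical stand-in for model.tokenize: echoes the raw texts under
    # a single key so the column-wise grouping is easy to inspect.
    key = next(iter(batch[0]))
    return {"texts": [row[key] for row in batch]}


features = [
    {"query": "q1", "doc": "d1"},
    {"query": "q2", "doc": "d2"},
]

# Same grouping as in the collator: tokenize each column separately and
# prefix its outputs with a positional "sentence{i}_" key.
batch = {}
for i, key in enumerate(features[0].keys()):
    for token_key, value in stub_tokenize([{key: row[key]} for row in features]).items():
        batch[f"sentence{i}_{token_key}"] = value

print(batch)
# → {'sentence0_texts': ['q1', 'q2'], 'sentence1_texts': ['d1', 'd2']}
```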

...

model = SentenceTransformer(...)

...

trainer = SentenceTransformerTrainer(
    ...
    data_collator=AsymDataCollator(tokenize_fn=model.tokenize),
)