tonysternenko opened 3 months ago
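Hello! This can be done by passing a custom `data_collator` to `SentenceTransformerTrainer`. The collator wraps every text as a `{column_name: text}` dict before tokenizing, so `Asym` routes each dataset column to the sub-module registered under that column name. For example, an `Asym` model with `query` and `doc` keys needs a dataset with `query` and `doc` columns: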
```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

import torch
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer


@dataclass
class AsymDataCollator:
    tokenize_fn: Callable
    valid_label_columns: List[str] = field(default_factory=lambda: ["label", "score"])

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        columns = list(features[0].keys())

        # We should always be able to return a loss, label or not:
        batch = {}

        # Keep the dataset name (multi-dataset training), but never tokenize it
        if "dataset_name" in columns:
            columns.remove("dataset_name")
            batch["dataset_name"] = features[0]["dataset_name"]

        # Extract the label column if it exists
        for label_column in self.valid_label_columns:
            if label_column in columns:
                batch["label"] = torch.tensor([row[label_column] for row in features])
                columns.remove(label_column)
                break

        # Tokenize each remaining column separately. Wrapping every text as
        # {column_name: text} makes Asym route it to the sub-module registered
        # under that column name, so column names must match the Asym keys.
        for i, column in enumerate(columns):
            tokenized = self.tokenize_fn([{column: row[column]} for row in features])
            for token_key, value in tokenized.items():
                batch[f"sentence{i}_{token_key}"] = value

        return batch


...
model = SentenceTransformer(...)
...
trainer = SentenceTransformerTrainer(
    ...
    data_collator=AsymDataCollator(tokenize_fn=model.tokenize),
)
```
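The sketch below shows how the collator's `sentence{i}_` keys are likely consumed. It assumes the trainer behaves as in sentence-transformers v3, which regroups every batch entry that shares a prefix into one feature dict per column. The prefix is found via the `<prefix>input_ids` key. `Asym` then reads each dict's `text_keys` to choose the sub-module. `split_batch` is a hypothetical pure-Python helper that mirrors this regrouping. It is not the library's own function.

```python
from typing import Any, Dict, List, Tuple


def split_batch(batch: Dict[str, Any]) -> Tuple[List[Dict[str, Any]], Any]:
    """Hypothetical helper: regroup collator output by its "sentence{i}_" prefix.

    Assumes each tokenized column contributes an "<prefix>input_ids" entry;
    keys without such a prefix (e.g. "label", "dataset_name") are not grouped.
    """
    features = []
    for column in batch:
        if column.endswith("_input_ids"):
            prefix = column[: -len("input_ids")]  # e.g. "sentence0_"
            features.append(
                {key[len(prefix):]: value for key, value in batch.items() if key.startswith(prefix)}
            )
    return features, batch.get("label")


# Toy batch shaped like the collator's output (plain lists instead of tensors)
batch = {
    "label": [1.0],
    "sentence0_input_ids": [[101, 2054]],
    "sentence0_attention_mask": [[1, 1]],
    "sentence0_text_keys": ["query"],
    "sentence1_input_ids": [[101, 3899]],
    "sentence1_attention_mask": [[1, 1]],
    "sentence1_text_keys": ["doc"],
}
features, labels = split_batch(batch)
```

Each regrouped dict keeps its own `text_keys`, which is why the collator tokenizes one column per `tokenize_fn` call. `Asym` does not allow mixed keys within a single call.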
Hi there. I'm wondering whether it is possible to use the `SentenceTransformerTrainer` class to train a model that includes an `Asym` module in its structure. I'm asking because, according to the documentation, `SentenceTransformerTrainer` doesn't accept a `List[InputExample]`, which is required for proper asymmetric model training.
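If the training data already exists as `InputExample`-style records whose `texts` are `{key: text}` dicts, one option is to convert them into columns named after the `Asym` keys. That is the layout the collator approach expects. The sketch below assumes each example lists the same keys in the same order, with each key appearing at most once. `Example` is a hypothetical stand-in for `InputExample`, and `examples_to_columns` is also hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Example:
    """Hypothetical stand-in for sentence_transformers.InputExample."""
    texts: List[Dict[str, str]]
    label: Optional[float] = None


def examples_to_columns(examples: List[Example]) -> Dict[str, List[Any]]:
    """Hypothetical helper: one column per Asym key, plus an optional "label" column."""
    keys = [next(iter(text)) for text in examples[0].texts]
    if len(set(keys)) != len(keys):
        # A column name doubles as the Asym key, so keys must be unique per example
        raise ValueError("each Asym key may appear only once per example")
    columns: Dict[str, List[Any]] = {key: [] for key in keys}
    has_label = examples[0].label is not None
    if has_label:
        columns["label"] = []
    for example in examples:
        example_keys = [next(iter(text)) for text in example.texts]
        if example_keys != keys:
            raise ValueError(f"expected keys {keys}, got {example_keys}")
        for text in example.texts:
            key, value = next(iter(text.items()))
            columns[key].append(value)
        if has_label:
            columns["label"].append(example.label)
    return columns


columns = examples_to_columns([
    Example([{"query": "what is x"}, {"doc": "x is y"}], 1.0),
    Example([{"query": "capital of France"}, {"doc": "Paris is in France"}], 0.0),
])
```

The resulting dict could then be turned into a training dataset with `datasets.Dataset.from_dict(columns)`.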