UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Cannot fine-tune a model using a streaming / torch.utils.data.IterableDataset #2232

Open stamm1989 opened 1 year ago

stamm1989 commented 1 year ago

I'm currently trying to fine-tune the "bertje" model. I expect to have a large dataset that exceeds the working memory of the machine I'm using. After some reading, I found that torch.utils.data.IterableDataset would be the solution for this, potentially in combination with the WebDataset data format.

However, the SentenceTransformer.fit function tries to retrieve the length of the dataset a couple of times via len(dataloader), which by design isn't available here, since we do not know the length in advance:

https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py#L629
https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/model_card_templates.py#L162
https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py#L656
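For illustration, here is the failure in miniature (a standalone sketch, independent of sentence-transformers): calling len() on a DataLoader backed by an IterableDataset raises a TypeError, which is exactly what fit() runs into.

import torch

class Stream(torch.utils.data.IterableDataset):
    # Minimal stream: yields items one at a time and has no __len__.
    def __iter__(self):
        return iter(["a", "b", "c"])

loader = torch.utils.data.DataLoader(Stream())
# len(loader) raises TypeError ("object of type 'Stream' has no len()"),
# because DataLoader.__len__ defers to the underlying dataset's length.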

I've also noticed that there is a custom data loader implementation, but that does not seem to be the solution either, since it also requires loading the entire set into memory: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/datasets/NoDuplicatesDataLoader.py

Can this become supported? Or is it already supported, and should I write my own custom data loader class?

Small code snippet:

import torch
from sentence_transformers import SentenceTransformer, losses, InputExample

bert_model = SentenceTransformer('GroNLP/bert-base-dutch-cased')
train_loss = losses.MultipleNegativesRankingLoss(model=bert_model)
input_examples = [
    InputExample(texts=['some text', 'some other text']),
    InputExample(texts=['some nice text', 'some other nice text']),
]

class MyIterableDataSet(torch.utils.data.IterableDataset):
    """Streams examples one at a time; has no __len__ by design."""
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __iter__(self):
        return iter(self.data)

class MyInMemDataSet(torch.utils.data.Dataset):
    """Map-style dataset; holds everything in memory and knows its length."""
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# Doesn't work: fit() calls len(dataloader), which raises TypeError
# for a DataLoader backed by an IterableDataset.
dataset_iterable = MyIterableDataSet(input_examples)
dataloader_iterable = torch.utils.data.DataLoader(dataset_iterable)
bert_model.fit(
    train_objectives=[(dataloader_iterable, train_loss)],
    epochs=1,
)

# Does work: the map-style dataset provides __len__.
dataset_in_memory = MyInMemDataSet(input_examples)
dataloader_in_memory = torch.utils.data.DataLoader(dataset_in_memory)
bert_model.fit(
    train_objectives=[(dataloader_in_memory, train_loss)],
    epochs=1,
)
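A possible interim workaround (untested; the SizedDataLoader class and num_examples argument are names introduced here for illustration, not part of the library) is to subclass DataLoader and report an estimated length, since the call sites linked above only appear to use len(dataloader) for step counting, warmup, and the model card:

import math
import torch

class SizedDataLoader(torch.utils.data.DataLoader):
    # Wraps an IterableDataset but still answers len(dataloader) with an
    # estimated batch count, which is all fit() appears to need.
    def __init__(self, dataset, num_examples, **kwargs):
        super().__init__(dataset, **kwargs)
        self.num_examples = num_examples  # known or estimated corpus size

    def __len__(self):
        return math.ceil(self.num_examples / (self.batch_size or 1))

dataloader_sized = SizedDataLoader(dataset_iterable, num_examples=len(input_examples))
bert_model.fit(
    train_objectives=[(dataloader_sized, train_loss)],
    epochs=1,
)

If the estimate is off, the schedule is merely skewed rather than broken: fit() appears to recreate the iterator when it raises StopIteration.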
olivierr42 commented 5 months ago

With today's release of Sentence Transformers v3, is this issue fixed? I was looking into using an iterable dataset as well, with Ray Train, and was wondering whether the SentenceTransformerTrainer will work seamlessly.
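For reference, a hedged sketch of what streaming might look like with the v3 trainer, assuming SentenceTransformerTrainer accepts a datasets.IterableDataset when max_steps is set; the generator and the anchor/positive column names are illustrative, matching the pairs MultipleNegativesRankingLoss expects:

from datasets import IterableDataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("GroNLP/bert-base-dutch-cased")

def pair_generator():
    # Stand-in for streaming pairs from disk or WebDataset shards.
    yield {"anchor": "some text", "positive": "some other text"}
    yield {"anchor": "some nice text", "positive": "some other nice text"}

# A datasets.IterableDataset never materializes the full corpus in memory.
train_dataset = IterableDataset.from_generator(pair_generator)

args = SentenceTransformerTrainingArguments(
    output_dir="bertje-mnrl",
    per_device_train_batch_size=32,
    max_steps=1000,  # required: epoch length can't be inferred from a stream
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=MultipleNegativesRankingLoss(model),
)
trainer.train()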

stamm1989 commented 4 months ago

Thank you for the update. I have not been able to verify whether it works now; I have since switched to other packages.