huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Kernel crash due to out of memory for large dataset #472

Open AbhiPawar5 opened 10 months ago

AbhiPawar5 commented 10 months ago

Hi team, I am using SetFit for a multiclass classification problem (130+ classes). I have ~800,000 labelled samples as the training set and ~200,000 as the test set. My kernel crashes even with a tiny batch size, despite the 32GB of RAM on my MacBook M1 Pro.

My train and test CSVs have the same label set.

Code to reproduce:

import pandas as pd
from datasets import load_dataset

from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

# Read only the label column so the model knows the label names
df = pd.read_csv("Combined Train.csv", usecols=['label'])
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", labels=df.label.unique().tolist())

# Load the full train/test splits from the CSV files
dataset = load_dataset('csv', data_files={
    "train": 'Combined Train.csv',
    "test": 'Combined Test.csv'
})

# Preparing the training arguments
args = TrainingArguments(
    batch_size=2,
    num_epochs=1,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train']
)

trainer.train()

tomaarsen commented 10 months ago

Hello!

SetFit trains on pairs of samples rather than on the individual samples themselves. Consequently, you can imagine that the trainer may create a completely unreasonable number of pairs from 800k samples. My recommendation is to set max_steps and eval_max_steps fairly low; this should limit the number of pairs created and prevent the crash. You can then most likely also increase the batch size, as that shouldn't be what is causing the issue.
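
For concreteness, a minimal sketch of such a configuration (the specific values here are illustrative, not taken from the thread):

from setfit import TrainingArguments

# Illustrative values: bound the number of training and evaluation steps so that
# only a limited number of pairs is ever needed, and raise the batch size.
args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
    max_steps=500,
    eval_max_steps=100,
)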

AbhiPawar5 commented 10 months ago

Hi @tomaarsen, thanks for replying so quickly! I did set max_steps and eval_max_steps to 1, but it didn't help. I also set num_iterations to 1 to avoid holding a large number of pairs in memory.

Please find the code below:

args = TrainingArguments(
    batch_size=1,
    num_epochs=1,
    max_steps=1,
    eval_max_steps=1,
    num_iterations=1
)

Can you think of any other reason that is causing the memory issue?

tomaarsen commented 10 months ago

Hmm, it certainly should work if max_steps is small. It's probably because the datasets are just absolutely massive. I think the only solution is to also cut down the datasets passed to the Trainer:

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"].select(range(10000))
)

(10k is just an example; most SetFit models are trained with far fewer samples)
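
As a sketch of an alternative (assuming the label column is named "label"), the sample_dataset helper already imported in the reproduction code can draw a stratified subsample per class instead of taking the first rows:

from setfit import sample_dataset

# Draw up to 8 examples per class (the num_samples value is illustrative)
small_train = sample_dataset(dataset["train"], label_column="label", num_samples=8)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=small_train,
)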

AbhiPawar5 commented 10 months ago

Cool. Thanks for your help.

Jordy-VL commented 9 months ago

It seems the issue resides in how the sampling pairs are generated (in full :O):

See shuffle_combinations, which considers the WHOLE dataset. It should be a generator throughout, rather than generating all paired indices up front followed by permuted sampling.

def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    """Generates shuffled pair combinations for any iterable data provided.

    Args:
        iterable: data to generate pair combinations from
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)  # <-- this line materializes every pair index at once
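
For a rough sense of scale (a back-of-the-envelope estimate, not a measurement), at the dataset size reported above that upper-triangular index matrix alone dwarfs the 32GB of RAM mentioned earlier:

# Rough estimate of what np.triu_indices materializes for this dataset size
n = 800_000                       # training samples reported above
num_pairs = n * (n + 1) // 2      # ≈ 3.2e11 index pairs
approx_bytes = num_pairs * 2 * 8  # two int64 indices per pair ≈ 5.1 TB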

AbhiPawar5 commented 9 months ago

Hi @Jordy-VL, I'm not sure I understand your comment. How do I use this for my code? Thanks

Jordy-VL commented 9 months ago

I am not providing a solution, just pointing to where the issue resides. A solution would require being smarter about how sample pairs are created for a large dataset, for example by creating a custom ContrastiveDataset class that uses minimal compute to generate the paired samples, with more control over the generation process.
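
As an illustration only (a hypothetical sketch, not SetFit's actual ContrastiveDataset implementation), pairs could be drawn lazily with constant memory rather than by materializing all index combinations:

import random
from typing import Iterator, List, Tuple

def lazy_pair_generator(
    texts: List[str], labels: List[int], num_pairs: int, seed: int = 42
) -> Iterator[Tuple[str, str, float]]:
    """Yield (text_a, text_b, is_same_class) pairs without building all index combinations."""
    rng = random.Random(seed)
    n = len(texts)
    for _ in range(num_pairs):
        # Draw two indices on demand: O(1) memory per pair instead of O(n^2) up front
        i, j = rng.randrange(n), rng.randrange(n)
        yield texts[i], texts[j], float(labels[i] == labels[j])

Something along these lines could feed a contrastive loss in mini-batches, at the cost of giving up the exhaustive shuffled-combinations behaviour of the current implementation.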