AbhiPawar5 opened this issue 10 months ago
Hello!
SetFit will train using pairs of samples, rather than individual samples themselves. Consequently, you can imagine that the trainer may create a completely unreasonable number of pairs from 800k samples. My recommendation is to set `max_steps` and `eval_max_steps` fairly low; this should limit the number of pairs created and prevent the crash. You can then most likely also increase the batch size, as that shouldn't be what's causing the issue.
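For a rough sense of scale (a back-of-the-envelope calculation, not SetFit code): the number of unique pairs grows quadratically with the dataset size, so 800k samples yield on the order of 10^11 candidate pairs.

```python
# Number of unique pairs with replacement from n samples is n * (n + 1) / 2.
n = 800_000
num_pairs = n * (n + 1) // 2
print(f"{num_pairs:,}")  # 320,000,400,000 -> roughly 3.2e11 pairs
```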
Hi @tomaarsen, thanks for replying so quickly! I did set `max_steps` and `eval_max_steps` to 1, but it didn't help. I also set `num_iterations` to 1 to avoid keeping a large number of pairs in memory.
Please find the code below:
```python
from setfit import TrainingArguments

args = TrainingArguments(
    batch_size=1,
    num_epochs=1,
    max_steps=1,
    eval_max_steps=1,
    num_iterations=1,
)
```
Can you think of any other reason that is causing the memory issue?
Hmm, it certainly should be working if `max_steps` is small. It's probably because the datasets themselves are just absolutely massive. I think the only solution is to also cut down the datasets passed to the `Trainer`:
```python
from setfit import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"].select(range(10000)),
)
```
(10k is just an example; most SetFit models are trained with far fewer samples.)
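One possible refinement (an assumption on my part, since the CSVs may be ordered by label): with 130+ classes, taking the first rows could drop rare classes entirely, so shuffling before selecting keeps the subset representative. Both `shuffle` and `select` are standard 🤗 `datasets` methods:

```python
# Shuffle first so a 10k subset still covers all classes, then select.
small_train = dataset["train"].shuffle(seed=42).select(range(10000))
```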
Cool. Thanks for your help.
It seems the issue resides in how the sampling pairs are generated (in full :O): see `shuffle_combinations`, which considers the WHOLE dataset. It should be a true generator throughout, rather than generating all paired indices up front followed by permuted sampling.
```python
def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    """Generates shuffled pair combinations for any iterable data provided.

    Args:
        iterable: data to generate pair combinations from
        replacement: enable to include combinations of same samples,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)  # <-- this line
```
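To see why that line is the problem: `np.triu_indices(n, k)` materializes every pair index up front, which for n around 800k is far beyond what 32 GB of RAM can hold. A rough estimate of the index array alone:

```python
import numpy as np

n = 800_000
num_pairs = n * (n + 1) // 2                              # ~3.2e11 upper-triangular pairs
bytes_needed = num_pairs * 2 * np.dtype(np.int64).itemsize  # two int64 indices per pair
print(f"{bytes_needed / 1e12:.1f} TB")                    # roughly 5 TB for the indices alone
```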
Hi @Jordy-VL, I'm not sure I understand your comment. How do I use this in my code? Thanks
I am not providing a solution, just pointing to where the issue resides. A solution would require being smarter about how pairs are created for a large dataset, for example by creating a custom `ContrastiveDataset` class that uses minimal compute to generate the paired samples, with more control over the generation process.
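As an illustration of that idea (a minimal sketch only; the function name and random-sampling strategy are assumptions, not SetFit's actual `ContrastiveDataset`), pairs can be drawn lazily by index instead of materializing the full upper-triangular index array:

```python
import random
from typing import Iterator, Tuple

def sample_pairs(num_samples: int, num_pairs: int, seed: int = 42) -> Iterator[Tuple[int, int]]:
    """Lazily yield random index pairs instead of building all n*(n+1)/2 combinations.

    This trades exhaustive coverage for O(1) memory, which is the kind of
    trade-off a custom contrastive dataset for 800k samples would need to make.
    """
    rng = random.Random(seed)
    for _ in range(num_pairs):
        yield rng.randrange(num_samples), rng.randrange(num_samples)

# Example: draw 1M pairs from an 800k-sample dataset without a giant index array.
for i, j in sample_pairs(800_000, 1_000_000):
    pass  # look up the actual samples / labels by index here
```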
Hi team, I am using SetFit for a multiclass classification problem (130+ classes). I have ~800,000 labelled samples as the training set and ~200,000 as the test set. My kernel crashes even with a batch size of 1 on my MacBook M1 Pro with 32 GB of RAM.
My train and test CSVs have the same labels.
Code to reproduce: