I am using the example code from the README file; the only change is that `sampling_strategy` is set to `"unique"`:
# Bug reproduction: the SetFit README quick-start example, unchanged except for
# sampling_strategy="unique" in TrainingArguments below.
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")
# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-mpnet-base-v2",
labels=["negative", "positive"],
)
args = TrainingArguments(
batch_size=16,
num_epochs=1,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
# The only deviation from the README example; this setting triggers the error.
sampling_strategy="unique",
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
metric="accuracy",
column_mapping={
"sentence": "text",
"label": "label",
}, # Map dataset columns to text/label expected by trainer
)
# Train and evaluate
# NOTE: with sampling_strategy="unique", trainer.train() fails inside
# Trainer.get_dataloader with:
#   ValueError: DataLoader with IterableDataset: expected unspecified shuffle
#   option, but got shuffle=True
trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)
This code raises the following exception:
Traceback (most recent call last):
File "/home/ubuntu/example.py", line 40, in <module>
trainer.train()
File "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", line 410, in train
self.train_embeddings(*full_parameters, args=args)
File "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", line 443, in train_embeddings
train_dataloader, loss_func, batch_size = self.get_dataloader(
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", line 513, in get_dataloader
dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/setfit/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 307, in __init__
raise ValueError(
ValueError: DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle=True
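In `/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py`, lines 511-513, I see that `shuffle` is always `True` for `sampling_strategy="unique"`; that value reaches `DataLoader` as `shuffle=shuffle_sampler` (see the last setfit frame of the traceback above).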
Meanwhile, in `/opt/conda/envs/setfit/lib/python3.11/site-packages/torch/utils/data/dataloader.py`, lines 302-308, `shuffle` may be set for an `IterableDataset` only if the dataset also inherits from `IterDataPipe`; otherwise, any value other than `False` or `None` raises a `ValueError`:
if isinstance(dataset, IterDataPipe):
if shuffle is not None:
dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
# We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
elif shuffle not in {False, None}:
raise ValueError(
f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}")
However, setfit uses `ContrastiveDataset`, which does not inherit from `IterDataPipe`:
from torch.utils.data import IterDataPipe
from setfit.sampler import ContrastiveDataset

# setfit's ContrastiveDataset is not an IterDataPipe, so DataLoader does not
# take the IterDataPipe branch that accepts a shuffle option; it reaches the
# branch that raises ValueError for shuffle=True instead.
# Expected output: False
print(isinstance(ContrastiveDataset([], True), IterDataPipe))
Python version:
3.11.4
Dependencies installed:

I am using the example code from the README file; the only change is that `sampling_strategy` is set to `"unique"`. This code raises the following exception:
In `/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py`, lines 511-513, I see that `shuffle` is always `True` for `sampling_strategy="unique"`. At the same time, in `/opt/conda/envs/setfit/lib/python3.11/site-packages/torch/utils/data/dataloader.py`, lines 302-308, I see that `shuffle` can be `True` only for datasets that inherit from `IterDataPipe`. But setfit uses `ContrastiveDataset`, which does not inherit from `IterDataPipe`.