huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

Error using sampling_strategy="unique" #465

Closed traumgedanken closed 10 months ago

traumgedanken commented 11 months ago

Python version: 3.11.4
Dependencies installed:

setfit==1.0.1
torch==2.1.2
torchvision==0.16.2

I am using the example code from the README file; the only change is that sampling_strategy is set to "unique":

from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=["negative", "positive"],
)

args = TrainingArguments(
    batch_size=16,
    num_epochs=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    sampling_strategy="unique",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={
        "sentence": "text",
        "label": "label",
    },  # Map dataset columns to text/label expected by trainer
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)

This code raises the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/example.py", line 40, in <module>
    trainer.train()
  File "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", line 410, in train
    self.train_embeddings(*full_parameters, args=args)
  File "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", line 443, in train_embeddings
    train_dataloader, loss_func, batch_size = self.get_dataloader(
                                              ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", line 513, in get_dataloader
    dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/setfit/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 307, in __init__
    raise ValueError(
ValueError: DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle=True

In "/opt/conda/envs/setfit/lib/python3.11/site-packages/setfit/trainer.py", lines 511-513, I see that shuffle is always True when sampling_strategy="unique":

shuffle_sampler = True if args.sampling_strategy == "unique" else False
batch_size = min(args.embedding_batch_size, len(data_sampler))
dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)

At the same time, in "/opt/conda/envs/setfit/lib/python3.11/site-packages/torch/utils/data/dataloader.py", lines 302-308, I see that shuffle may only be True for datasets that inherit from IterDataPipe:

if isinstance(dataset, IterDataPipe):
    if shuffle is not None:
        dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
# We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
elif shuffle not in {False, None}:
    raise ValueError(
        f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}")

But setfit uses ContrastiveDataset, which does not inherit from IterDataPipe:

from torch.utils.data import IterDataPipe
from setfit.sampler import ContrastiveDataset

isinstance(ContrastiveDataset([], True), IterDataPipe)
# False
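
ContrastiveDataset is, however, an iterable-style dataset (the traceback's "DataLoader with IterableDataset" message implies it subclasses torch.utils.data.IterableDataset), which is why the elif branch above rejects shuffle=True. A quick check, assuming that inheritance:

from torch.utils.data import IterableDataset
from setfit.sampler import ContrastiveDataset

# Assumed from the traceback: ContrastiveDataset subclasses IterableDataset,
# so DataLoader treats it as iterable-style and refuses shuffle=True.
issubclass(ContrastiveDataset, IterableDataset)
# True
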
tomaarsen commented 10 months ago

Hello!

Thanks a bunch for the detailed issue! I'll remove the shuffle_sampler = ... lines, which should solve this problem.
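
Roughly, the idea is to drop the shuffle argument in get_dataloader and let the DataLoader defaults apply (a sketch, not the final patch):

batch_size = min(args.embedding_batch_size, len(data_sampler))
# No shuffle argument: ContrastiveDataset is iterable-style and controls its
# own sample order, so leaving shuffle unset is what we want here.
dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False)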