huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

KeyError when using Trainer to fine-tune on a dataset #9636

Closed XiaoYang66 closed 3 years ago

XiaoYang66 commented 3 years ago

Environment info

Who can help

@sgugger

Information

Model I am using (Bert, XLNet ...): bert-base-uncased

The problem arises when using:

The tasks I am working on is:

To reproduce

Steps to reproduce the behavior:

error:

File "train.py", line 69, in <module>
    trainer.train()
  File "/home/pliu3/projects/anaconda3/envs/calibration/lib/python3.7/site-packages/transformers/trainer.py", line 784, in train
    for step, inputs in enumerate(epoch_iterator):
  File "/home/pliu3/projects/anaconda3/envs/calibration/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/home/pliu3/projects/anaconda3/envs/calibration/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 475, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/pliu3/projects/anaconda3/envs/calibration/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/pliu3/projects/anaconda3/envs/calibration/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
KeyError: 2

code

from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    Trainer,
    TrainingArguments,
)

dataset_name = 'sem_eval_2014_task_1'
num_labels_size = 3
batch_size = 4
model_checkpoint = 'bert-base-uncased'
number_train_epoch = 5

def tokenize(batch):
    # Tokenize the premise/hypothesis pairs; padding is handled later by the
    # Trainer's data collator, since the tokenizer is passed to Trainer below.
    return tokenizer(batch['premise'], batch['hypothesis'], truncation=True)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='micro')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

model = BertForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels_size)
tokenizer = BertTokenizerFast.from_pretrained(model_checkpoint)  # BertTokenizerFast is already the fast tokenizer; use_fast is an AutoTokenizer argument

train_dataset = load_dataset(dataset_name, split='train')
test_dataset = load_dataset(dataset_name, split='test')

train_encoded_dataset = train_dataset.map(tokenize, batched=True)
test_encoded_dataset = test_dataset.map(tokenize, batched=True)

args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=number_train_epoch,
    weight_decay=0.01,
    do_predict=True
)
trainer = Trainer(
    model=model,
    args=args,
    compute_metrics=compute_metrics,
    train_dataset=train_encoded_dataset,
    eval_dataset=test_encoded_dataset,
    tokenizer=tokenizer
)

trainer.train()
trainer.evaluate()
XiaoYang66 commented 3 years ago

I found this issue is caused by this description: "Here we have the loss since we passed along labels" (url: https://huggingface.co/transformers/main_classes/output.html). So if the dataset object does not have a 'label' column (or if the column that represents the label has another name, like 'entailment_judgment'), the Trainer cannot recognize that column.
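A quick way to confirm this is to inspect the dataset's columns before training (a minimal sketch; column_names is a standard attribute of a datasets Dataset):

from datasets import load_dataset

train_dataset = load_dataset('sem_eval_2014_task_1', split='train')
# For this dataset the label column is named 'entailment_judgment' (per the
# report above), so there is no 'label' column for the Trainer to pick up.
print(train_dataset.column_names)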

XiaoYang66 commented 3 years ago

So I added a helper like this:

def change_transformers_dataset_2_right_format(dataset, label_name):
    return dataset.map(
        lambda example: {'label': example[label_name]},
        remove_columns=[label_name],
    )

It works fine.
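For what it's worth, the same renaming can also be done with the rename_column method (a sketch assuming a reasonably recent version of the datasets library):

# Rename 'entailment_judgment' to the 'label' column the Trainer expects.
train_encoded_dataset = train_encoded_dataset.rename_column('entailment_judgment', 'label')
test_encoded_dataset = test_encoded_dataset.rename_column('entailment_judgment', 'label')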

XiaoYang66 commented 3 years ago

I found that in a lot of datasets uploaded by users, the column that represents the label has a different name! Maybe it would be better to unify a standard, either on the datasets side or on the Trainer side.

XiaoYang66 commented 3 years ago

Also, I cannot visit your forum, and I do not know why; this is weird. Can you please help me? Thanks a lot!

sgugger commented 3 years ago

The script is not meant to work out of the box on any dataset; it is an example. If the columns are named differently than in the usual GLUE datasets, it's logical that you have to change one line.

Please do not post the same issue several times.

XiaoYang66 commented 3 years ago

OK, thanks for your reply. And do you know why I cannot visit your forum? Is there some special setting in your firewall for the forum? @sgugger

sgugger commented 3 years ago

I'm not aware of any firewall problem; you're the first user reporting an issue connecting to it, to be honest.

hg0428 commented 1 year ago

I have this same problem.

from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate

training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

def train(model, train, eval, **kwargs):
  print('Training model...')
  trainer = Trainer(
      model=model,
      args=training_args,            # pass the TrainingArguments defined above (it was otherwise unused)
      train_dataset=train,           # dataset to train it with
      eval_dataset=eval,             # dataset to test it with
      compute_metrics=compute_metrics,
      **kwargs
  )
  trainer.train()
  trainer.save_model('adkai')
  print('Trained!')

model.train(True)
train(model, {
  '#print Hello World':'stdout.write("Hello World\n")',
  '#print hello World':'stdout.write("hello World\n")',
  '# print Hello world':'stdout.write("Hello world\n")',
  '#print hello world':'stdout.write("hello world\n")',
  '#print Hello World!':'stdout.write("Hello World!\n")',
  '# print hello World!':'stdout.write("hello World!\n")',
  '#print goodbye World!':'stdout.write("goodbye World!\n")',
  '# write Hello World':'stdout.write("Hello World\n")',
  '#write hello World':'stdout.write("hello World\n")',
  '# write Hello world':'stdout.write("Hello world\n")',
  '#write hello world':'stdout.write("hello world\n")',
  '# write Hello World!':'stdout.write("Hello World!\n")',
  'set x = 5\n#print x':'stdout.write(x, "\n")',
  'set x = "Go home"\n#output x':'stdout.write(x, "\n")',
  'set xyz = "Hello"# output xyz':'stdout.write(xyz, "\n")', 
  'set Whatever = "nothing"\n#output Whatever':'stdout.write(Whatever, "\n")',
  '#output Whatever':'stdout.write("Whatever\n")',
  '':'',
  '':''
}, {
  '#write Hello world!':'stdout.write("Hello world!\n")',
  '':'',
  '# output Hello World!':'stdout.write("Hello World!\n")',
})

(only partial code)

Please help. This is the error:

Traceback (most recent call last):
  File "main.py", line 18, in <module>
    train.train(model, {
  File "/home/runner/AdkAI/train.py", line 23, in train
    trainer.train()
  File "/home/runner/AdkAI/venv/lib/python3.8/site-packages/transformers/trainer.py", line 1500, in train
    return inner_training_loop(
  File "/home/runner/AdkAI/venv/lib/python3.8/site-packages/transformers/trainer.py", line 1716, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/home/runner/AdkAI/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/runner/AdkAI/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/runner/AdkAI/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/runner/AdkAI/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
KeyError: 2
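
For anyone else landing here: this KeyError has the same shape as the original one. The Trainer wraps train_dataset in a PyTorch DataLoader, which indexes the dataset by integer position (dataset[0], dataset[1], ...). A plain Python dict is indexed by its keys instead, so dataset[2] raises KeyError: 2. Below is a minimal sketch of one way to restructure the data, assuming the datasets library is available and using hypothetical column names 'input' and 'output'; the texts would still need to be tokenized into model inputs before training.

from datasets import Dataset

pairs = {
    '#print Hello World': 'stdout.write("Hello World\n")',
    '#write hello World': 'stdout.write("hello World\n")',
}

# Build a Dataset that supports integer indexing, which the Trainer's
# DataLoader requires; a plain dict only supports lookup by key.
train_dataset = Dataset.from_dict({
    'input': list(pairs.keys()),
    'output': list(pairs.values()),
})

print(train_dataset[0])  # {'input': '#print Hello World', 'output': 'stdout.write("Hello World\n")'}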