huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

TFTrainingArguments #9053

Closed: tangzhy closed this issue 3 years ago

tangzhy commented 3 years ago

Environment info

@sgugger @jplu @stefan-it

Information

Model I am using (Bert, XLNet ...): Bert

The problem arises when using:
[ ] the official example scripts: (give details below)
[x] my own modified scripts: (give details below)

The task I am working on is:
[ ] an official GLUE/SQUaD task: (give the name)
[x] my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Specify the training arguments.
  2. Run the trainer.
  3. An AttributeError is raised because the trainer looks up evaluate_strategy, while the training arguments define evaluation_strategy.
from transformers import TFTrainer, TFTrainingArguments

training_args = TFTrainingArguments(
    output_dir="/root/Data/marco-passage-ranking/results",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    do_predict=False,
    evaluation_strategy="no",
    eval_steps=1000,  # > 0, so the trainer goes on to check the strategy attribute

    per_device_train_batch_size=8,  # batch size per device during training
    per_device_eval_batch_size=8,   # batch size for evaluation

    learning_rate=1e-6,

    max_steps=400000,
    warmup_steps=40000,

    logging_dir="./tmp/log",
    logging_steps=1000,
    save_steps=1000,

    fp16=False,
    xla=False,
)

trainer = TFTrainer(
    model=model,                        
    args=training_args,             
    train_dataset=train_ds.take(100000),
    eval_dataset=dev_ds.take(10000), 
    compute_metrics=compute_metrics,
)

trainer.train()

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-19-25eb465360cc> in <module>
      7 )
      8 
----> 9 trainer.train()

~/Softwares/anaconda3/envs/tf2.0/lib/python3.7/site-packages/transformers/trainer_tf.py in train(self)
    562                     if (
    563                         self.args.eval_steps > 0
--> 564                         and self.args.evaluate_strategy == EvaluationStrategy.STEPS
    565                         and self.global_step % self.args.eval_steps == 0
    566                     ):

AttributeError: 'TFTrainingArguments' object has no attribute 'evaluate_strategy'

I think this is a bug: TFTrainingArguments defines evaluation_strategy, but trainer_tf.py reads self.args.evaluate_strategy, and that naming mismatch raises the AttributeError above. Any advice?
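The mismatch can be illustrated without installing transformers. The sketch below is only a stand-in: `FakeArgs`, `should_evaluate`, and `try_check` are hypothetical names, and the `EvaluationStrategy` enum here is a simplified copy made for illustration. Only the shape of the condition is taken from the traceback. It suggests why the error appears even with evaluation_strategy="no": once eval_steps > 0 is true, Python evaluates the next operand and hits the missing attribute.

```python
from dataclasses import dataclass
from enum import Enum


class EvaluationStrategy(Enum):
    # Simplified stand-in for the enum named in the traceback (assumption).
    NO = "no"
    STEPS = "steps"


@dataclass
class FakeArgs:
    # Hypothetical stand-in for TFTrainingArguments: only the correctly
    # spelled field exists, as the AttributeError implies.
    evaluation_strategy: EvaluationStrategy = EvaluationStrategy.NO
    eval_steps: int = 1000


def should_evaluate(args, global_step):
    # Same shape as the condition shown in the traceback, including the
    # misspelled attribute name.
    return (
        args.eval_steps > 0
        and args.evaluate_strategy == EvaluationStrategy.STEPS
        and global_step % args.eval_steps == 0
    )


def try_check(args, global_step=1000):
    # Return the result of the check, or the error message if it fails.
    try:
        return should_evaluate(args, global_step)
    except AttributeError as err:
        return str(err)


print(try_check(FakeArgs()))              # error message naming evaluate_strategy
print(try_check(FakeArgs(eval_steps=0)))  # short-circuits before the bad lookup
```

In this stand-in, setting eval_steps to 0 avoids the lookup only because `and` short-circuits. Whether that is acceptable for a real training run is a separate question.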

sgugger commented 3 years ago

Oh this is a typo, do you want to open a PR to fix it?
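Until a fix is released, one possible stopgap is to expose the value under the name the trainer reads. The sketch below uses a hypothetical stand-in class (`FakeArgs`) and a hypothetical helper (`add_legacy_alias`). It assumes the arguments object accepts new attributes, which has not been checked against the real TFTrainingArguments here. The proper fix remains correcting the attribute name in trainer_tf.py.

```python
from dataclasses import dataclass


@dataclass
class FakeArgs:
    # Hypothetical stand-in for TFTrainingArguments (assumption).
    evaluation_strategy: str = "no"
    eval_steps: int = 1000


def add_legacy_alias(args):
    # Copy the correctly spelled value to the misspelled name, but only
    # when the misspelled name is missing, so a fixed version is untouched.
    if not hasattr(args, "evaluate_strategy"):
        args.evaluate_strategy = args.evaluation_strategy
    return args


args = add_legacy_alias(FakeArgs())
print(args.evaluate_strategy)
```

The alias is a snapshot: if evaluation_strategy is changed afterwards, the helper would need to be adapted, since it deliberately skips objects that already have the alias.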