huggingface / optimum-habana

Easy and lightning fast training of 🤗 Transformers on Habana Gaudi processor (HPU)

Resume from checkpoint does not work #5

Closed: MohitIntel closed this issue 2 years ago

MohitIntel commented 2 years ago

Error Message:

Traceback (most recent call last):
  File "examples/question-answering/run_qa.py", line 664, in <module>
    main()
  File "examples/question-answering/run_qa.py", line 605, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/usr/local/lib/python3.8/dist-packages/optimum/habana/trainer.py", line 517, in train
    self._load_optimizer_and_scheduler(resume_from_checkpoint)
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1795, in _load_optimizer_and_scheduler
    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 857, in persistent_load
    load_tensor(data_type, size, key, _maybe_decode_ascii(location))
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 846, in load_tensor
    loaded_storages[key] = restore_location(storage, location)
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 827, in restore_location
    return default_restore_location(storage, str(map_location))
  File "/usr/local/lib/python3.8/dist-packages/torch/serialization.py", line 178, in default_restore_location
    raise RuntimeError("don't know how to restore data location of "
RuntimeError: don't know how to restore data location of torch.FloatStorage (tagged with hpu)
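
For context, the failure happens inside torch.serialization: the optimizer state in optimizer.pt was saved with storages tagged hpu, and torch.load has no restore handler registered for that tag. A minimal sketch of a workaround, assuming a checkpoint folder like the one produced by the command below (the step number is hypothetical), is to map every storage to the CPU while loading:

import os
import torch

# Hypothetical checkpoint folder saved by the run below.
checkpoint = "./albert_xxlarge_bf16_squad/checkpoint-5000"

# Storages pickled on an HPU are tagged "hpu"; torch.load only knows the
# built-in locations (cpu, cuda, ...), so mapping everything to the CPU
# avoids the "don't know how to restore data location" lookup entirely.
optimizer_state = torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")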

Command used to run training:

python examples/question-answering/run_qa.py --model_name_or_path albert-xxlarge-v1 --dataset_name squad --do_train --do_eval --per_device_train_batch_size=12 --learning_rate=5e-06 --num_train_epochs 2 --save_steps 5000 --seed 42 --doc_stride 128 --max_seq_length 384 --per_device_eval_batch_size 2 --use_lazy_mode --use_habana --output_dir=./albert_xxlarge_bf16_squad 2>&1 | tee albert_xxlarge_bf16_squad_continued.log

Method for reproducing the issue:

  1. Run training with the above command.
  2. Halt the training after a few steps/epochs.
  3. Resume training with the same command plus the --resume_from_checkpoint flag pointing to the output directory of the previous run (full command shown after this list).
  4. The above error is encountered.
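
For illustration, step 3's resume invocation is the original command with the checkpoint flag appended, pointing at the output directory used above:

python examples/question-answering/run_qa.py --model_name_or_path albert-xxlarge-v1 --dataset_name squad --do_train --do_eval --per_device_train_batch_size=12 --learning_rate=5e-06 --num_train_epochs 2 --save_steps 5000 --seed 42 --doc_stride 128 --max_seq_length 384 --per_device_eval_batch_size 2 --use_lazy_mode --use_habana --output_dir=./albert_xxlarge_bf16_squad --resume_from_checkpoint ./albert_xxlarge_bf16_squad 2>&1 | tee albert_xxlarge_bf16_squad_continued.log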

Attached log file: albert_xxlarge_bf16_squad_continued.log

yeonsily commented 2 years ago

Actually, the issue was caused by a wrong checkpoint location. Previously we passed the location like this: --resume_from_checkpoint ./output/checkpoint-3500, but it's supposed to be just ./output.

It works fine with the correct checkpoint path. Here is an example command to verify it:

$ python run_qa.py --model_name_or_path roberta-base --gaudi_config_name ../gaudi_config.json --dataset_name squad --do_train --do_eval --per_device_train_batch_size 24 --per_device_eval_batch_size 8 --use_habana --use_lazy_mode --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir ./output/ --resume_from_checkpoint ./output/
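
One plausible reason this works is that the example scripts auto-detect the most recent checkpoint inside output_dir; transformers ships a helper for that, sketched below (the printed path is hypothetical):

from transformers.trainer_utils import get_last_checkpoint

# Scans the folder for "checkpoint-<step>" subdirectories and returns the
# one with the highest step, or None if there are none.
last_checkpoint = get_last_checkpoint("./output/")
print(last_checkpoint)  # e.g. ./output/checkpoint-3500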

yeonsily commented 2 years ago

Actually, it is supposed to work when given the last saved checkpoint folder directly, e.g. --resume_from_checkpoint ./output/checkpoint-3500.

We found that there's an issue on the trainer side.
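
For reference, both flag values end up in the same call on the Python side; a minimal sketch, assuming a trainer built as in run_qa.py:

# Resume from an explicit checkpoint folder (the step number is an example):
trainer.train(resume_from_checkpoint="./output/checkpoint-3500")

# Or pass True to let the Trainer pick the last checkpoint saved in
# args.output_dir on its own.
trainer.train(resume_from_checkpoint=True)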

MohitIntel commented 2 years ago

Currently, checkpoint resume does not work if the training run ends abruptly in the middle of an epoch. It does not pick up the globally last saved checkpoint step; instead, it picks up the last step that ended gracefully.
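
In other words, the expectation is that resume starts from the highest checkpoint-<step> folder left on disk, however the run ended. A hypothetical sketch of that expectation (the helper name and regex are illustrative, not part of the library):

import os
import re

def expected_resume_step(output_dir):
    # The globally last saved step is the largest <step> among the
    # "checkpoint-<step>" folders that survived the abrupt termination.
    steps = [int(m.group(1))
             for d in os.listdir(output_dir)
             if (m := re.match(r"checkpoint-(\d+)$", d))]
    return max(steps, default=None)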

regisss commented 2 years ago

Could you tell me if you still encounter this issue with an up-to-date version of the package?

libinta commented 2 years ago

Can't reproduce after pull request #11.