huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.

Resuming training fails #105

Open hidoba opened 2 months ago

hidoba commented 2 months ago

I removed the --overwrite_output_dir flag so that training can resume from the last checkpoint, but I'm getting the following error:

04/01/2024 00:30:01 - INFO - __main__ - max_steps is given, it will override any value given in num_train_epochs
04/01/2024 00:30:04 - INFO - __main__ - ***** Running training *****
04/01/2024 00:30:04 - INFO - __main__ -   Num examples = 4800000
04/01/2024 00:30:04 - INFO - __main__ -   Instantaneous batch size per device = 8
04/01/2024 00:30:04 - INFO - __main__ -   Gradient accumulation steps = 1
04/01/2024 00:30:04 - INFO - __main__ -   Total train batch size (w. parallel & distributed) = 8
04/01/2024 00:30:04 - INFO - __main__ -   Total optimization steps = 600000
Train steps ... :   0%|                                                                               | 0/600000 [00:00<?, ?it/s]
04/01/2024 00:30:04 - INFO - accelerate.accelerator - Loading states from ./checkpoint-5000-epoch-0
Traceback (most recent call last):
  File "/home/vlad/distil-whisper/training/run_distillation.py", line 1682, in <module>
    main()
  File "/home/vlad/distil-whisper/training/run_distillation.py", line 1484, in main
    accelerator.load_state(checkpoint)
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2966, in load_state
    load_accelerator_state(
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/accelerate/checkpointing.py", line 205, in load_accelerator_state
    models[i].load_state_dict(state_dict, **load_model_func_kwargs)
  File "/home/vlad/distil-whisper/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for WhisperForConditionalGeneration:
        Missing key(s) in state_dict: "proj_out.weight". 

At the same time, the evaluation script works just fine with the same checkpoint.

I'm using Ubuntu 22 with an RTX 3090 Ti.

hidoba commented 2 months ago

I've also noticed this warning in the log:

04/01/2024 00:35:47 - WARNING - accelerate.utils.other - Removed shared tensor {'proj_out.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading
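
For context, here is a minimal sketch of the underlying mechanics (the hypothetical TiedModel stands in for Whisper, whose proj_out.weight is tied to the decoder embedding): safetensors refuses to serialize aliased tensors, so the duplicate key is dropped at save time, and a plain load_file followed by a strict load_state_dict then reports it as missing.

import torch
from safetensors.torch import save_file, load_file

class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)
        self.proj_out = torch.nn.Linear(4, 10, bias=False)
        self.proj_out.weight = self.embed.weight  # tied, as in Whisper

model = TiedModel()
# Drop the aliased key before saving, as accelerate does -- this is
# exactly the "Removed shared tensor" warning above.
state = {k: v for k, v in model.state_dict().items() if k != "proj_out.weight"}
save_file(state, "model.safetensors")

reloaded = load_file("model.safetensors")
model.load_state_dict(reloaded)  # RuntimeError: Missing key(s): "proj_out.weight"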
Gusreis7 commented 2 months ago

Any updates on this? I'm facing the same problem.

George0828Zhang commented 2 months ago

Here's a temporary fix, based on https://huggingface.co/docs/safetensors/torch_shared_tensors.

Modify load_accelerator_state(): https://github.com/huggingface/accelerate/blob/main/src/accelerate/checkpointing.py#L153

-from safetensors.torch import load_file
+from safetensors.torch import load_model
 ...
         if input_model_file.exists():
-            state_dict = load_file(input_model_file, device=str(map_location))
+            # load_model re-ties shared tensors stripped at save time
+            # (e.g. proj_out.weight) instead of failing on the missing key
+            load_model(models[i], input_model_file, device=str(map_location), **load_model_func_kwargs)
         else:
             # Load with torch
             input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
             state_dict = torch.load(input_model_file, map_location=map_location)
-        models[i].load_state_dict(state_dict, **load_model_func_kwargs)
+            models[i].load_state_dict(state_dict, **load_model_func_kwargs)
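
As a sanity check on that approach, safetensors' model-level helpers round-trip tied weights cleanly (continuing the hypothetical TiedModel sketch from above):

from safetensors.torch import save_model, load_model

model = TiedModel()
save_model(model, "model.safetensors")     # strips the aliased key safely
restored = TiedModel()
load_model(restored, "model.safetensors")  # missing key is covered by the tie, so no error
# the tie survives: both names refer to the same storage
assert restored.proj_out.weight.data_ptr() == restored.embed.weight.data_ptr()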