microsoft / SpeechT5

Unified-Modal Speech-Text Pre-Training for Spoken Language Processing
MIT License

Difficulties loading pre-trained weights! #5

Closed: sanchit-gandhi closed this issue 2 years ago

sanchit-gandhi commented 2 years ago

Hello!

Thank you very much for adding a code snippet that outlines how to load the pre-trained SpeechT5 weights. It's super helpful for understanding how to process the data and load the task 😊

I've been attempting to load the 'base' pre-trained weights according to the code snippet provided here:

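```python
import torch
from speecht5.tasks.speecht5 import SpeechT5Task
from speecht5.models.speecht5 import T5TransformerModel

# The checkpoint holds the training config ('cfg') and the weights ('model')
checkpoint = torch.load('/path/to/speecht5_checkpoint')

# Override the task config: pre-training mode plus local paths
checkpoint['cfg']['task'].t5_task = 'pretrain'
checkpoint['cfg']['task'].hubert_label_dir = "/path/to/hubert_label"  # holds dict.{label}.txt, e.g. dict.km.txt
checkpoint['cfg']['task'].data = "/path/to/tsv_file"  # a directory; setup_task reads dict.txt from here

# Build the task (this loads the dictionaries), then the model, then the weights
task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])
model = T5TransformerModel.build_model(checkpoint['cfg']['model'], task)
model.load_state_dict(checkpoint['model'])
```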
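Since setup_task opens both dictionaries as soon as it runs, it can help to confirm they exist before calling it. The sketch below mirrors the two lookups visible in the traceback further down ({hubert_label_dir}/dict.{label}.txt for each label, and {data}/dict.txt). missing_dict_files is a hypothetical helper, not part of SpeechT5, and the default label list ("km",) is an assumption based on the dict.km.txt file created in step 2.

```python
import os
from typing import Iterable, List


def missing_dict_files(data_dir: str, hubert_label_dir: str,
                       hubert_labels: Iterable[str] = ("km",)) -> List[str]:
    """Hypothetical helper: list dictionary files the 'pretrain' task would
    try to open but that do not exist on disk.

    Assumes the lookups shown in the traceback:
      {hubert_label_dir}/dict.{label}.txt for each HuBERT label, and
      {data_dir}/dict.txt for the text dictionary.
    The default label ("km",) is an assumption; pass your own labels.
    """
    expected = [os.path.join(hubert_label_dir, f"dict.{label}.txt")
                for label in hubert_labels]
    expected.append(os.path.join(data_dir, "dict.txt"))
    return [path for path in expected if not os.path.isfile(path)]


if __name__ == "__main__":
    # Use the same values assigned to checkpoint['cfg']['task'] above
    for path in missing_dict_files("/path/to/tsv_file", "/path/to/hubert_label"):
        print("missing:", path)
```

An empty result only means the files are present; it says nothing about whether their contents parse.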

Steps performed:

  1. Downloaded the fine-tuned base checkpoint from https://github.com/microsoft/SpeechT5#pre-trained-models
  2. Created a dummy dictionary of HuBERT labels using the instructions provided here, with n_clusters=500:
     ```shell
     for x in $(seq 0 $((n_clusters - 1))); do
         echo "$x 1"
     done >> $lab_dir/dict.km.txt
     ```
  3. Downloaded the dummy text dictionary from the Google Drive link in @Ajyy's response to a previous issue. As outlined there, the text dictionary should be placed under data and the HuBERT labels under hubert_label_dir.
  4. Ran the code snippet above with these HuBERT labels and text dictionary to load the pre-trained model. Loading the task throws an error:
task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:242, in Dictionary.add_from_file(self, f)
    241 try:
--> 242     line, field = line.rstrip().rsplit(" ", 1)
    243     if field == "#fairseq:overwrite":

ValueError: not enough values to unpack (expected 2, got 1)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])

File ~/SpeechT5/SpeechT5/speecht5/tasks/speecht5.py:301, in SpeechT5Task.setup_task(cls, args, **kwargs)
    299 if args.t5_task == "pretrain":
    300     dicts["hubert"] = [Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt") for label in args.hubert_labels]
--> 301     dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
    302 else:
    303     if config is None:

File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:216, in Dictionary.load(cls, f)
    207 """Loads the dictionary from a text file with the format:
    208 
    209 ```
   (...)
    213 ```
    214 """
    215 d = cls()
--> 216 d.add_from_file(f)
    217 return d

File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:227, in Dictionary.add_from_file(self, f)
    225 try:
    226     with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
--> 227         self.add_from_file(fd)
    228 except FileNotFoundError as fnfe:
    229     raise fnfe

File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:260, in Dictionary.add_from_file(self, f)
    258     self.add_symbol(word, n=count, overwrite=overwrite)
    259 except ValueError:
--> 260     raise ValueError(
    261         "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
    262     )

ValueError: Incorrect dictionary format, expected '<token> <cnt> [flags]'

  5. Do we need to add a flag column to the dummy text dict? I appended a column of zeros to the dummy text dict (giving token, count, 0 on each line), after which the task loads: task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])
  6. Loading the weights then fails:
     model = T5TransformerModel.build_model(checkpoint['cfg']['model'], task)
     model.load_state_dict(checkpoint['model'])
     ```
     ---------------------------------------------------------------------------
     RuntimeError                              Traceback (most recent call last)
     Input In [17], in <cell line: 2>()
           1 model = T5TransformerModel.build_model(checkpoint['cfg']['model'], task)
     ----> 2 model.load_state_dict(checkpoint['model'])

     File ~/SpeechT5/SpeechT5/speecht5/models/speecht5.py:1040, in T5TransformerModel.load_state_dict(self, state_dict, strict, model_cfg, args)
        1036 m_state_dict = {
        1037     key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
        1038 }
        1039 if hasattr(self, m):
     -> 1040     self._modules[m].load_state_dict(m_state_dict, False)
        1041 return self

     File ~/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1497, in Module.load_state_dict(self, state_dict, strict)
        1492     error_msgs.insert(
        1493         0, 'Missing key(s) in state_dict: {}. '.format(
        1494             ', '.join('"{}"'.format(k) for k in missing_keys)))
        1496 if len(error_msgs) > 0:
     -> 1497     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        1498                        self.__class__.__name__, "\n\t".join(error_msgs)))
        1499 return _IncompatibleKeys(missing_keys, unexpected_keys)

     RuntimeError: Error(s) in loading state_dict for TransformerEncoder:
         size mismatch for proj.weight: copying a param with shape torch.Size([81, 768]) from checkpoint, the shape in current model is torch.Size([7, 768]).
         size mismatch for proj.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([7]).
     ```
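The error in step 4 is raised by fairseq's Dictionary.add_from_file, which, per the traceback, splits each line on its last space and expects <token> <cnt>, optionally followed by the #fairseq:overwrite flag. The sketch below approximates that check in plain Python so that dict.txt or dict.km.txt can be validated before setup_task runs, and it also writes the step-2 dummy HuBERT dictionary without the shell loop. parse_dict_line, check_dict_file and write_dummy_km_dict are hypothetical helpers (not part of SpeechT5 or fairseq), and the parser is reconstructed only from the lines shown in the traceback.

```python
import os
from typing import List, Tuple

OVERWRITE_FLAG = "#fairseq:overwrite"  # the only flag visible in the traceback


def parse_dict_line(line: str) -> Tuple[str, int, bool]:
    """Approximate the '<token> <cnt> [flags]' parsing for one line.

    Splits on the last space and raises ValueError for a malformed line.
    """
    parts = line.rstrip().rsplit(" ", 1)
    if len(parts) != 2:
        raise ValueError(f"expected '<token> <cnt> [flags]', got {line!r}")
    token, field = parts
    overwrite = field == OVERWRITE_FLAG
    if overwrite:
        # The flag was the last field; split again to separate token and count
        parts = token.rsplit(" ", 1)
        if len(parts) != 2:
            raise ValueError(f"expected '<token> <cnt> [flags]', got {line!r}")
        token, field = parts
    return token, int(field), overwrite  # int() raises ValueError for a non-numeric count


def check_dict_file(path: str) -> List[str]:
    """Describe every line that would fail to parse (empty list if all lines parse)."""
    problems = []
    with open(path, "r", encoding="utf-8") as fd:
        for lineno, line in enumerate(fd, start=1):
            try:
                parse_dict_line(line)
            except ValueError as err:
                problems.append(f"line {lineno}: {err}")
    return problems


def write_dummy_km_dict(lab_dir: str, n_clusters: int = 500) -> str:
    """Python version of the step-2 loop: one '<cluster_id> 1' line per cluster.

    Opens the file with 'w' rather than appending, so re-running it does not
    duplicate entries.
    """
    path = os.path.join(lab_dir, "dict.km.txt")
    with open(path, "w", encoding="utf-8") as fd:
        for x in range(n_clusters):
            fd.write(f"{x} 1\n")
    return path
```

Under this parsing, appending a 0 column to a line that already reads token count does not add a flag: the 0 becomes the count and "token count" becomes the token, e.g. parse_dict_line("hello 12 0") returns ("hello 12", 0, False).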

I'd be very grateful for some insight on these two questions:

  1. Do we need to process the dummy text data in an additional way to add the 'flag' column?
  2. Is the size-mismatch error related to the saved pre-trained (PT) checkpoint?

Many thanks for your help!

sanchit-gandhi commented 2 years ago

The issue was that my dict.txt file was a plain text file rather than a binary! I'm now able to load the state dict as expected.
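The size mismatch is consistent with the dictionary being the culprit. fairseq's Dictionary reserves four special symbols by default (<s>, <pad>, </s>, <unk>), so the loaded vocabulary is those plus one entry per dict.txt line. Assuming the encoder's proj layer has one output row per text-dictionary symbol (an assumption; the traceback only shows its shapes) and that the task may add extra symbols after loading (n_extra, also an assumption), the sketch below compares a local dict.txt with the row count stored in a checkpoint. check_vocab and VocabCheck are hypothetical helpers.

```python
from typing import NamedTuple

# fairseq's Dictionary reserves these four symbols by default: <s>, <pad>, </s>, <unk>
DEFAULT_SPECIALS = 4


class VocabCheck(NamedTuple):
    dict_lines: int       # lines in the local dict.txt
    implied_rows: int     # proj rows this dict.txt would produce (under the assumptions below)
    checkpoint_rows: int  # rows of proj.weight stored in the checkpoint

    @property
    def matches(self) -> bool:
        return self.implied_rows == self.checkpoint_rows


def check_vocab(dict_path: str, checkpoint_rows: int,
                n_special: int = DEFAULT_SPECIALS, n_extra: int = 0) -> VocabCheck:
    """Hypothetical helper: compare a local dict.txt with a checkpoint's proj size.

    Assumes proj has one output row per text-dictionary symbol, i.e.
    rows = n_special + (lines in dict.txt) + n_extra, where n_extra counts any
    symbols the task may add after loading (an assumption; adjust as needed).
    """
    with open(dict_path, "r", encoding="utf-8") as fd:
        dict_lines = sum(1 for _ in fd)
    return VocabCheck(dict_lines, n_special + dict_lines + n_extra, checkpoint_rows)
```

The row count can be read from the checkpoint itself; since the traceback does not show the full key, candidates can be listed with [k for k in checkpoint['model'] if k.endswith('proj.weight')]. For the numbers in the traceback, 81 - 7 = 74, so whatever fixed set of special symbols both sides share, the checkpoint's text dictionary had 74 more entries than the dict.txt that was loaded.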

StephennFernandes commented 2 years ago

@sanchit-gandhi are you working on converting SpeechT5 to Hugging Face?