Hello!

Thank you very much for adding a code snippet that outlines how to load the pre-trained SpeechT5 weights; it is super helpful for understanding how to process the data and load the task 😊
I've been attempting to load the 'base' pre-trained weights according to the code snippet provided here:
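```python
import torch
from speecht5.tasks.speecht5 import SpeechT5Task
from speecht5.models.speecht5 import T5TransformerModel

# Load the released checkpoint: it stores the training config ('cfg') and the weights ('model')
checkpoint = torch.load('/path/to/speecht5_checkpoint')

# Point the task config at the pre-training setup and the local dictionaries
checkpoint['cfg']['task'].t5_task = 'pretrain'
checkpoint['cfg']['task'].hubert_label_dir = "/path/to/hubert_label"  # HuBERT label dict(s), dict.<label>.txt
checkpoint['cfg']['task'].data = "/path/to/tsv_file"  # text dictionary, dict.txt

# Build the task (this loads the dictionaries), then the model, then the pre-trained weights
task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])
model = T5TransformerModel.build_model(checkpoint['cfg']['model'], task)
model.load_state_dict(checkpoint['model'])
```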
Steps performed:

Create a dummy dictionary of HuBERT labels using the instructions provided here, with `n_clusters=500`:

```bash
n_clusters=500
for x in $(seq 0 $((n_clusters - 1))); do
  echo "$x 1"
done >> "$lab_dir/dict.km.txt"
```
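If you would rather create the same dummy label dictionary from Python (for example, inside the notebook that loads the checkpoint), a minimal sketch is below. It writes one `<label> 1` line per cluster, like the shell loop. `write_dummy_km_dict` is a hypothetical helper name, and unlike the `>>` redirect it overwrites the file instead of appending, so re-running it does not duplicate entries.

```python
from pathlib import Path


def write_dummy_km_dict(lab_dir: str, n_clusters: int = 500) -> Path:
    """Write lab_dir/dict.km.txt with lines '0 1', '1 1', ..., '<n_clusters-1> 1'.

    Hypothetical helper mirroring the shell loop; it opens the file in 'w'
    mode (overwrite) rather than appending.
    """
    path = Path(lab_dir) / "dict.km.txt"
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for x in range(n_clusters):
            # One '<token> <count>' entry per cluster id, with a dummy count of 1
            f.write(f"{x} 1\n")
    return path
```

For example, `write_dummy_km_dict("/path/to/hubert_label", n_clusters=500)` would place the file in the directory that the snippet above uses as `hubert_label_dir`.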
Download the dummy text dictionary from the Google Drive link provided in @Ajyy's previous response to an issue. As outlined there, the text dictionary should be placed under `data` and the HuBERT labels under `hubert_label_dir`.
Using the HuBERT labels and text dictionary, I ran the code snippet above to load the pre-trained model. Setting up the task throws an error:
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:242, in Dictionary.add_from_file(self, f)
241 try:
--> 242 line, field = line.rstrip().rsplit(" ", 1)
243 if field == "#fairseq:overwrite":
ValueError: not enough values to unpack (expected 2, got 1)
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Input In [10], in <cell line: 1>()
----> 1 task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])
File ~/SpeechT5/SpeechT5/speecht5/tasks/speecht5.py:301, in SpeechT5Task.setup_task(cls, args, **kwargs)
299 if args.t5_task == "pretrain":
300 dicts["hubert"] = [Dictionary.load(f"{args.hubert_label_dir}/dict.{label}.txt") for label in args.hubert_labels]
--> 301 dicts["text"] = Dictionary.load(op.join(args.data, "dict.txt"))
302 else:
303 if config is None:
File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:216, in Dictionary.load(cls, f)
207 """Loads the dictionary from a text file with the format:
208
209 ```
(...)
213 ```
214 """
215 d = cls()
--> 216 d.add_from_file(f)
217 return d
File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:227, in Dictionary.add_from_file(self, f)
225 try:
226 with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
--> 227 self.add_from_file(fd)
228 except FileNotFoundError as fnfe:
229 raise fnfe
File ~/SpeechT5/SpeechT5/fairseq/fairseq/data/dictionary.py:260, in Dictionary.add_from_file(self, f)
258 self.add_symbol(word, n=count, overwrite=overwrite)
259 except ValueError:
--> 260 raise ValueError(
261 "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
262 )
ValueError: Incorrect dictionary format, expected '<token> <cnt> [flags]'
```
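The failing statement in the traceback is `line.rstrip().rsplit(" ", 1)`, which splits each dictionary line once at its last space. Below is a simplified re-implementation of that parsing step for diagnosis only, not fairseq's actual code: it assumes, as line 243 of the traceback suggests, that the only special trailing field is `#fairseq:overwrite`, and that everything left of the last space is the token. `parse_dict_line` and `find_bad_lines` are hypothetical helper names.

```python
from typing import Iterable, List, Tuple

OVERWRITE_FLAG = "#fairseq:overwrite"


def parse_dict_line(line: str) -> Tuple[str, int, bool]:
    """Parse '<token> <cnt> [flags]' the way the rsplit in the traceback does.

    Raises ValueError when there is no space-separated count or the count
    is not an integer.
    """
    # Split once at the LAST space; everything before it is treated as the token.
    token, field = line.rstrip().rsplit(" ", 1)  # ValueError if the line has no space
    overwrite = field == OVERWRITE_FLAG
    if overwrite:
        # The flag was the last field, so the count is the field before it.
        token, field = token.rsplit(" ", 1)
    return token, int(field), overwrite


def find_bad_lines(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Return (1-based line number, line) for every line that fails to parse."""
    bad = []
    for i, line in enumerate(lines, start=1):
        try:
            parse_dict_line(line)
        except ValueError:
            bad.append((i, line.rstrip("\n")))
    return bad
```

For example, running `find_bad_lines` over the lines of `/path/to/tsv_file/dict.txt` would list blank lines, single-column lines and tab-separated lines, since none of them contain a space before the count. Under this parsing, appending a third column to a `<token> <cnt>` line makes the new last column the count and folds the old count into the token: `parse_dict_line("abc 12 0")` returns `("abc 12", 0, False)`.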
Do we need to add a flag column to the dummy text dictionary? I appended a column of zeros to the dummy text dictionary (so each line reads token, count, 0), after which the task loads without error:
```python
task = SpeechT5Task.setup_task(checkpoint['cfg']['task'])
```
However, loading the weights then fails:
```python
model = T5TransformerModel.build_model(checkpoint['cfg']['model'], task)
model.load_state_dict(checkpoint['model'])
```
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [17], in ()
1 model = T5TransformerModel.build_model(checkpoint['cfg']['model'], task)
----> 2 model.load_state_dict(checkpoint['model'])
File ~/SpeechT5/SpeechT5/speecht5/models/speecht5.py:1040, in T5TransformerModel.load_state_dict(self, state_dict, strict, model_cfg, args)
1036 m_state_dict = {
1037 key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
1038 }
1039 if hasattr(self, m):
-> 1040 self._modules[m].load_state_dict(m_state_dict, False)
1041 return self
File ~/venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1497, in Module.load_state_dict(self, state_dict, strict)
1492 error_msgs.insert(
1493 0, 'Missing key(s) in state_dict: {}. '.format(
1494 ', '.join('"{}"'.format(k) for k in missing_keys)))
1496 if len(error_msgs) > 0:
-> 1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1498 self.__class__.__name__, "\n\t".join(error_msgs)))
1499 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for TransformerEncoder:
size mismatch for proj.weight: copying a param with shape torch.Size([81, 768]) from checkpoint, the shape in current model is torch.Size([7, 768]).
size mismatch for proj.bias: copying a param with shape torch.Size([81]) from checkpoint, the shape in current model is torch.Size([7]).
```
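To list every parameter whose shape differs between the checkpoint and the freshly built model (rather than only the ones the submodule reports), one option is to compare shapes before calling `load_state_dict`. This is a generic PyTorch-style debugging sketch, not part of SpeechT5; `find_shape_mismatches` is a hypothetical helper that accepts any mapping whose values expose a `.shape` (as `torch.Tensor` does), so the core logic does not depend on torch itself.

```python
from typing import Any, Dict, Mapping, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_state: Mapping[str, Any], model_state: Mapping[str, Any]
) -> Dict[str, Tuple[Shape, Shape]]:
    """Return {name: (checkpoint_shape, model_shape)} for parameters present
    in both state dicts whose shapes differ."""
    mismatches: Dict[str, Tuple[Shape, Shape]] = {}
    for name, ckpt_value in checkpoint_state.items():
        if name not in model_state:
            continue  # keys missing from the model are a separate problem
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[name].shape)
        if ckpt_shape != model_shape:
            mismatches[name] = (ckpt_shape, model_shape)
    return mismatches
```

For example, calling `find_shape_mismatches(checkpoint['model'], model.state_dict())` right after `build_model` would be expected to include the encoder's `proj.weight` and `proj.bias` entries reported above, along with any other mismatches that the non-strict per-module loading might otherwise surface one module at a time.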
I would be very grateful for some insight on these two questions:

1. Do we need to process the dummy text dictionary in some additional way to add the 'flag' column?
2. Is the size mismatch error related to the saved pre-trained (PT) checkpoint?
Many thanks for your help!