j4sonzhao opened 2 years ago
It looks like the state dictionary was not saved properly, and so it cannot be loaded.
@wnhsu I would appreciate help on how to load this! I've been trying to modify the config in the state dictionary (state["config"]) to point to the correct data dictionary, but I keep running into issues.
@j4sonzhao thanks for pointing out the bug!
To answer your question about why it needs the path: the dictionary is needed to determine the output size (how many clusters there are) when initializing the pre-trained HuBERT, even though that prediction head is not used in the fine-tuning stage.
We will fix that so it no longer loads the pre-training dictionary. In the meantime, if you want a quick fix, you can do the following:
Create a dummy dictionary file `dict.lyr9.km500.txt` in some directory, say `/tmp/hubert_labels/dict.lyr9.km500.txt`, which contains 500 lines as follows:

```
0 1
1 1
...
499 1
```
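If it helps, the dummy dictionary can be generated with a few lines of Python (a sketch; only the file name and the number of lines matter here, since the symbols themselves are not used when running the fine-tuned model):

```python
import os

label_dir = "/tmp/hubert_labels"
os.makedirs(label_dir, exist_ok=True)

# Each line is "<symbol> <count>"; 500 entries match the 500 k-means clusters.
with open(os.path.join(label_dir, "dict.lyr9.km500.txt"), "w") as f:
    for i in range(500):
        f.write(f"{i} 1\n")
```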
Then patch the checkpoint so its config points to that directory:

```python
import torch

# Point the fine-tuning config at the directory containing the dummy dictionary.
state = torch.load(old_checkpoint_path)
state['cfg']['model']['w2v_args']['task']['label_dir'] = "/tmp/hubert_labels"
torch.save(state, new_checkpoint_path)
```
Now `new_checkpoint_path` should work with `fairseq.checkpoint_utils.load_model_ensemble_and_task([new_checkpoint_path])`.
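For example, something along these lines should then load the model (a sketch; the return signature assumed here matches recent fairseq versions, and `new_checkpoint_path` is the path saved above):

```python
from fairseq import checkpoint_utils

# Load the patched fine-tuned checkpoint; `models` is a list of loaded models.
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    [new_checkpoint_path]
)
model = models[0].eval()
```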
The current branch with [this commit](https://github.com/pytorch/fairseq/commit/272c4c5197250997148fb12c0db6306035f166a4) should fix the bug of requiring the pre-training dictionary when loading a fine-tuned checkpoint, as well as several issues introduced by recent commits (missing config, etc.).
Hi there,
I am trying to load the HuBERT Large checkpoint fine-tuned on 960h, downloaded from here: https://github.com/pytorch/fairseq/tree/main/examples/hubert
I downloaded the model, stored the checkpoint, and am trying to load it with fairseq, roughly as shown below.
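(The snippet below is a sketch; the checkpoint path is a placeholder for wherever I stored the downloaded file.)

```python
from fairseq import checkpoint_utils

# Placeholder path to the downloaded fine-tuned checkpoint.
ckpt_path = "/path/to/hubert_large_ll60k_finetune_ls960.pt"

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
```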
However, I am running into an error where it tries to load a dictionary from an incorrect location. It looks like some issue related to the HubertPretrainingTask config, where the default settings are wrong.
I am confused, though: why does HuBERT need these paths just to initialize itself, and how can I change them?
In general, I am trying to run ASR decoding with HuBERT, and this is the specific issue I am running into.