Open sausage-333 opened 1 year ago
Thanks for your attention to our work! @sausage-333
We used a setting similar to that of HuBERT. You can refer to fairseq's `examples/hubert/config/finetune/base_10h.yaml` with the corresponding instructions. Additionally, you need to set:

- `task.normalize: true`, since LightHuBERT is trained with normalized waveforms;
- `w2v_path: /path/to/lighthubert`, specifying the pre-trained LightHuBERT checkpoint.

To load the model successfully, some modifications are needed, and we will include them in this repo soon. But here we provide a temporary but quick solution: you can make the following changes to `HubertEncoder`'s `__init__` in `fairseq/models/hubert/hubert_asr.py`:
```python
# --- original HuBERT loading code, commented out ---
# pretrain_task = tasks.setup_task(w2v_args.task)
# if state is not None and "task_state" in state:
#     # This will load the stored "dictionaries" object
#     pretrain_task.load_state_dict(state["task_state"])
# else:
#     pretrain_task.load_state_dict(task.state_dict())
#
# model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
# if state is not None and not cfg.no_pretrained_weights:
#     # set strict=False because we omit some modules
#     model.load_state_dict(state["model"], strict=False)
# model.remove_pretraining_modules()
#
# super().__init__(pretrain_task.source_dictionary)
# d = w2v_args.model.encoder_embed_dim

# --- build LightHuBERT from the checkpoint config instead ---
from lighthubert import LightHuBERT, LightHuBERTConfig

lighthubert_cfg = LightHuBERTConfig(state['cfg']['model'])
lighthubert_cfg.supernet_type = "base"
model = LightHuBERT(lighthubert_cfg)
model.load_state_dict(state["model"], strict=False)
model.remove_pretraining_modules()

# select the Base subnet from the supernet
subnet = {
    'layer_num': 12,
    'embed_dim': 640,
    'heads_num': [10] * 12,
    'atten_dim': [640] * 12,
    'ffn_embed': [2560] * 12,
    'slide_wsz': ['global'] * 12,
}
model.set_sample_config(subnet)
total_params = model.calc_sampled_param_num()
print(f"target pre-trained subnet ({total_params:,} Params): {subnet}")

super().__init__(None)
d = subnet['embed_dim']
```
Above is the Base subnet setting. You can specify a different `supernet_type` and `subnet` by passing additional arguments.
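
As a quick sanity check before wiring this into fairseq, you can load the checkpoint standalone and confirm that the Base subnet resolves to the expected parameter count. This is only a sketch reusing the calls from the snippet above; the checkpoint path is a placeholder:

```python
import torch

from lighthubert import LightHuBERT, LightHuBERTConfig

# placeholder path to the pre-trained supernet checkpoint
ckpt = torch.load("/path/to/lighthubert.pt", map_location="cpu")

cfg = LightHuBERTConfig(ckpt["cfg"]["model"])
cfg.supernet_type = "base"  # "base" here; other supernet types can be set the same way

model = LightHuBERT(cfg)
model.load_state_dict(ckpt["model"], strict=False)
model.eval()

# the same Base subnet as in the HubertEncoder patch above
subnet = {
    'layer_num': 12,
    'embed_dim': 640,
    'heads_num': [10] * 12,
    'atten_dim': [640] * 12,
    'ffn_embed': [2560] * 12,
    'slide_wsz': ['global'] * 12,
}
model.set_sample_config(subnet)
print(f"Base subnet params: {model.calc_sampled_param_num():,}")
```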
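
Regarding `task.normalize: true`: with that flag, fairseq layer-normalizes each waveform before it reaches the encoder, matching the normalized input LightHuBERT saw during pre-training. Below is a minimal sketch of that normalization on a dummy 16 kHz waveform; it assumes `model` is the instance built above and that `extract_features` mirrors HuBERT's API (features returned as the first element):

```python
import torch
import torch.nn.functional as F

# dummy 16 kHz utterance; length and content are arbitrary
wav = torch.randn(16000)

# what fairseq applies per utterance when task.normalize is true
with torch.no_grad():
    wav = F.layer_norm(wav, wav.shape)

# add a batch dimension, then extract features (assumed HuBERT-style API)
wav = wav.unsqueeze(0)
with torch.no_grad():
    features = model.extract_features(wav)[0]
print(features.shape)  # expected: (batch, frames, embed_dim)
```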
Hello, I have a question about the 10-hour ASR fine-tuning in your paper.
Could you share the procedure for this experiment (or a link I can refer to)? I would like to run my own 10-hour ASR fine-tuning experiments using fairseq.
Thanks!