mechanicalsea / lighthubert

LightHuBERT: Lightweight and Configurable Speech Representation Learning with Once-for-All Hidden-Unit BERT
MIT License

10 hours ASR Fine-tuning #5

Open sausage-333 opened 1 year ago

sausage-333 commented 1 year ago

Hello, I have a question about 10-hour ASR fine-tuning in your paper.

Could you describe the procedure for this experiment (or point me to a link I can refer to)? I would like to run my own 10-hour ASR fine-tuning experiments with fairseq.

Thanks!

P1ping commented 1 year ago

Thanks for your attention to our work! @sausage-333

We used a setting similar to HuBERT's. You can refer to fairseq's examples/hubert/config/finetune/base_10h.yaml and the corresponding instructions. Additionally, you need to set
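For reference, the fairseq HuBERT example fine-tunes with a Hydra command along these lines (all paths are placeholders to fill in; pointing model.w2v_path at the LightHuBERT checkpoint is my assumption here, and it only works once the loading workaround below is applied):

```shell
python fairseq_cli/hydra_train.py \
  --config-dir /path/to/fairseq/examples/hubert/config/finetune \
  --config-name base_10h \
  task.data=/path/to/data \
  task.label_dir=/path/to/transcriptions \
  model.w2v_path=/path/to/lighthubert_checkpoint.pt
```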

To load the model successfully, a few modifications are needed; we will include them in this repo soon. In the meantime, here is a quick temporary workaround: make the following changes to HubertEncoder's __init__ in fairseq/models/hubert/hubert_asr.py:

        # pretrain_task = tasks.setup_task(w2v_args.task)
        # if state is not None and "task_state" in state:
        #     # This will load the stored "dictionaries" object
        #     pretrain_task.load_state_dict(state["task_state"])
        # else:
        #     pretrain_task.load_state_dict(task.state_dict())

        # model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        # if state is not None and not cfg.no_pretrained_weights:
        #     # set strict=False because we omit some modules
        #     model.load_state_dict(state["model"], strict=False)

        # model.remove_pretraining_modules()

        # super().__init__(pretrain_task.source_dictionary)

        # d = w2v_args.model.encoder_embed_dim

        # Build LightHuBERT from the pre-trained checkpoint's model config.
        from lighthubert import LightHuBERT, LightHuBERTConfig
        lighthubert_cfg = LightHuBERTConfig(state['cfg']['model'])
        lighthubert_cfg.supernet_type = "base"
        model = LightHuBERT(lighthubert_cfg)
        # strict=False because some pre-training modules are omitted.
        model.load_state_dict(state["model"], strict=False)
        model.remove_pretraining_modules()
        # Sample the Base subnet from the supernet.
        subnet = {
            'layer_num': 12,
            'embed_dim': 640,
            'heads_num': [10,] * 12,
            'atten_dim': [640,] * 12,
            'ffn_embed': [2560,] * 12,
            'slide_wsz': ['global'] * 12,
        }
        model.set_sample_config(subnet)
        total_params = model.calc_sampled_param_num()
        print(f"target pre-trained subnet ({total_params:,} Params): {subnet}")
        # The parent encoder does not need the pre-training dictionary here.
        super().__init__(None)
        d = subnet['embed_dim']

The above is the Base subnet setting. You can specify a different supernet_type and subnet configuration by adjusting these values.
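As a rough sanity check on the subnet dict above, the transformer-encoder parameter count it implies can be estimated with plain arithmetic. This is a hypothetical helper of mine, not part of lighthubert: it counts only the attention, feed-forward, and LayerNorm weights, so it will not match calc_sampled_param_num(), which also covers the convolutional feature extractor and other modules.

```python
def estimate_encoder_params(layer_num, embed_dim, atten_dim, ffn_embed):
    """Rough per-layer parameter arithmetic for a standard transformer encoder."""
    per_layer = 0
    # Q, K, V projections: embed_dim -> atten_dim, each with a bias.
    per_layer += 3 * (embed_dim * atten_dim + atten_dim)
    # Attention output projection: atten_dim -> embed_dim, with bias.
    per_layer += atten_dim * embed_dim + embed_dim
    # Feed-forward: embed_dim -> ffn_embed -> embed_dim, with biases.
    per_layer += embed_dim * ffn_embed + ffn_embed
    per_layer += ffn_embed * embed_dim + embed_dim
    # Two LayerNorms per layer (weight + bias each).
    per_layer += 2 * (2 * embed_dim)
    return layer_num * per_layer

# Values from the Base subnet dict above.
total = estimate_encoder_params(layer_num=12, embed_dim=640,
                                atten_dim=640, ffn_embed=2560)
print(f"~{total:,} encoder params")  # ~59,082,240
```

This is just a back-of-the-envelope way to compare candidate subnet settings before loading the full supernet.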