serp-ai / bark-with-voice-clone

🔊 Text-prompted Generative Audio Model - With the ability to clone voices
https://serp.ai/tools/bark-text-to-speech-ai-voice-clone-app

Confused model loading behavior, LORA is not used at all? #48

Open renxiangnan opened 1 year ago

renxiangnan commented 1 year ago

Hi guys, in generation.py I noticed the following code snippet. It looks like LoRA is not used for inference at all, or is there something I missed? Thank you.

    unwanted_prefix = "_orig_mod."
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    unwanted_suffixes = [
        "lora_right_weight",
        "lora_left_weight",
        "lora_right_bias",
        "lora_left_bias",
    ]
    for k, v in list(state_dict.items()):
        for suffix in unwanted_suffixes:
            if k.endswith(suffix):
                state_dict.pop(k)
    # super hacky - should probably refactor this
    if state_dict.get('lm_head.0.weight', None) is not None:
        state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight')
    if state_dict.get('lm_heads.0.0.weight', None) is not None:
        state_dict['lm_heads.0.weight'] = state_dict.pop('lm_heads.0.0.weight')
    if state_dict.get('lm_heads.1.0.weight', None) is not None:
        state_dict['lm_heads.1.weight'] = state_dict.pop('lm_heads.1.0.weight')
    if state_dict.get('lm_heads.2.0.weight', None) is not None:
        state_dict['lm_heads.2.weight'] = state_dict.pop('lm_heads.2.0.weight')
    if state_dict.get('lm_heads.3.0.weight', None) is not None:
        state_dict['lm_heads.3.weight'] = state_dict.pop('lm_heads.3.0.weight')
    if state_dict.get('lm_heads.4.0.weight', None) is not None:
        state_dict['lm_heads.4.weight'] = state_dict.pop('lm_heads.4.0.weight')
    if state_dict.get('lm_heads.5.0.weight', None) is not None:
        state_dict['lm_heads.5.weight'] = state_dict.pop('lm_heads.5.0.weight')
    if state_dict.get('lm_heads.6.0.weight', None) is not None:
        state_dict['lm_heads.6.weight'] = state_dict.pop('lm_heads.6.0.weight')

francislabountyjr commented 1 year ago

The LoRA in this case is just used for training and is then merged back into the model. Do you want LoRA for plug-and-play adaptors? If so, I can rework it to go that route instead when I can get to it.
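In other words, the low-rank factors are folded into the base weights before saving, which is why the loader can simply drop the `lora_*` keys. A minimal sketch of such a merge, assuming the delta is the product of the two factors named by the checkpoint's `lora_left_weight`/`lora_right_weight` keys (the exact composition and scaling here are assumptions, not the repo's code):

```python
import torch

def merge_lora(base_weight: torch.Tensor,
               lora_left: torch.Tensor,
               lora_right: torch.Tensor,
               scale: float = 1.0) -> torch.Tensor:
    # W_merged = W + scale * (left @ right); after this, the lora_*
    # tensors carry no extra information and can be discarded.
    return base_weight + scale * (lora_left @ lora_right)
```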

renxiangnan commented 1 year ago

> The LoRA in this case is just used for training and then is merged back to the model. Do you want LoRA for the plug and play adaptors? If so I can rework it to go that route instead when I can get to it

I appreciate your detailed explanation, which helped me gain a clearer understanding of the topic. It would be excellent if you could incorporate the plug-and-play adaptors as part of this setup. If I understand correctly (please correct me if I am wrong), given that LoRA is designed to prevent catastrophic forgetting, it might be beneficial to consider not merging the weights back into the model and instead retaining them as a separate adaptor. Doing so could also lead to less hacky model-loading code.
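Keeping the adaptor unmerged means the base weights stay frozen and the low-rank delta is applied at forward time. A sketch of what such a plug-and-play layer could look like (the class and initialization scheme are hypothetical, not this repo's implementation):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen Linear with a trainable low-rank update (sketch)."""
    def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # frozen base -> no catastrophic forgetting
        # Zero-init one factor so the wrapped layer starts identical to the base.
        self.lora_left = nn.Parameter(torch.zeros(base.out_features, rank))
        self.lora_right = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base output plus the low-rank delta, applied at inference time
        # instead of being merged into base.weight.
        return self.base(x) + self.scale * (x @ self.lora_right.T @ self.lora_left.T)
```

Swapping adaptors then amounts to loading different `lora_left`/`lora_right` tensors while the base checkpoint stays untouched.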