LAION-AI / CLAP

Contrastive Language-Audio Pretraining
https://arxiv.org/abs/2211.06687
Creative Commons Zero v1.0 Universal

Acc drop after converting "HTSAT-base" type to huggingface model #126

Open happylittlecat2333 opened 1 year ago

happylittlecat2333 commented 1 year ago

Question Description

I want to use the Hugging Face model style, but I only found "laion/clap-htsat-unfused" and "laion/clap-htsat-fused" among the Hugging Face models. I would like to use a music CLAP model, such as music_speech_epoch_15_esc_89.25.pt, so I used https://github.com/huggingface/transformers/blob/main/src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py to convert the CLAP checkpoint. However, the newly released models are based on HTSAT-base, whose hidden_size and patch_embeds_hidden_size differ from the HTSAT-tiny defaults, so I revised convert_clap_original_pytorch_to_hf.py as shown below. After testing three models (both HTSAT-base and HTSAT-tiny based), I see an accuracy drop for the HTSAT-base models. Could you please help me find the problem, and perhaps convert and upload Hugging Face versions of your newly released models?
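For context on the two values overridden for HTSAT-base: in a Swin-style encoder such as HTSAT, the channel width doubles at each patch-merging step, so the final hidden_size follows from patch_embeds_hidden_size and the number of stages. A minimal sketch, assuming four stages as in the ClapAudioConfig default `depths=[2, 2, 6, 2]` (the helper name is hypothetical):

```python
def swin_final_hidden_size(patch_embeds_hidden_size: int, num_stages: int = 4) -> int:
    """Channel width after the last stage of a Swin-style encoder.

    Assumes the width doubles at each of the (num_stages - 1) patch-merging steps.
    """
    return patch_embeds_hidden_size * 2 ** (num_stages - 1)


# HTSAT-tiny style: width 96 -> 768 (the ClapAudioConfig defaults)
print(swin_final_hidden_size(96))   # 768
# HTSAT-base style: width 128 -> 1024 (the values used in CLAP_AUDIO_CONFIG_DICT)
print(swin_final_hidden_size(128))  # 1024
```

This only covers the widths. Stage depths and attention heads are separate ClapAudioConfig fields (`depths`, `num_attention_heads`), so whether HTSAT-base also differs there is worth checking against the original CLAP model definition; that is an assumption to verify, not something confirmed in this thread.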

My revised convert_clap_original_pytorch_to_hf.py

import argparse
import re

import torch
# from CLAP import create_model
from laion_clap.clap_module import create_model

from transformers import AutoFeatureExtractor, ClapConfig, ClapModel, ClapAudioConfig, ClapProcessor

KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "audio_branch": "audio_model.audio_encoder",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")

# ADDED
CLAP_AUDIO_CONFIG_DICT = {
    "HTSAT-tiny": {},
    "HTSAT-base": {
        "hidden_size": 1024,
        "patch_embeds_hidden_size": 128,
    }
}

def init_clap(checkpoint_path, amodel="HTSAT-tiny", enable_fusion=False):
    model, model_cfg = create_model(
        amodel,
        "roberta",
        checkpoint_path,
        precision="fp32",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        enable_fusion=enable_fusion,
        fusion_type="aff_2d" if enable_fusion else None,
    )
    return model, model_cfg

def rename_state_dict(state_dict):
    model_state_dict = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in state_dict.items():
        # check if any key needs to be modified
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        if re.match(sequential_layers_pattern, key):
            # replace sequential layers with list
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)

            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
        elif re.match(text_projection_pattern, key):
            projecton_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projecton_layer == 0 else 2

            key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")

        if "audio" in key and "qkv" in key:
            # split qkv into query key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            model_state_dict[key.replace("qkv", "query")] = query_layer
            model_state_dict[key.replace("qkv", "key")] = key_layer
            model_state_dict[key.replace("qkv", "value")] = value_layer
        else:
            model_state_dict[key] = value

    return model_state_dict

def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, amodel, enable_fusion=False):
    clap_model, clap_model_cfg = init_clap(checkpoint_path, amodel=amodel, enable_fusion=enable_fusion)

    clap_model.eval()
    state_dict = clap_model.state_dict()
    state_dict = rename_state_dict(state_dict)

    # ADDED
    clap_audio_config = CLAP_AUDIO_CONFIG_DICT[amodel]

    transformers_config = ClapConfig(audio_config=clap_audio_config)
    transformers_config.audio_config.enable_fusion = enable_fusion
    model = ClapModel(transformers_config)

    # strict=False skips the spectrogram embedding layer, but it also silently
    # ignores every other missing or unexpected key
    model.load_state_dict(state_dict, strict=False)

    model.save_pretrained(pytorch_dump_folder_path)
    transformers_config.save_pretrained(pytorch_dump_folder_path)
    processor.save_pretrained(pytorch_dump_folder_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the original CLAP checkpoint (.pt)")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument("--amodel", default="HTSAT-tiny", type=str, help="Audio encoder type: HTSAT-tiny or HTSAT-base")
    parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
    args = parser.parse_args()

    convert_clap_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.amodel, args.enable_fusion)
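Because the script calls `model.load_state_dict(state_dict, strict=False)`, any renamed key that has no counterpart in ClapModel is dropped without an error, and any ClapModel parameter that receives no key keeps its random initialization. In PyTorch, the value returned by `load_state_dict` exposes `missing_keys` and `unexpected_keys`, which is the most direct check. The sketch below does the same comparison on plain name-to-shape dictionaries (built, for example, with `{k: tuple(v.shape) for k, v in sd.items()}`), so it runs without loading either model. The helper name and the key names in the example are hypothetical.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(
    converted: Dict[str, Shape], target: Dict[str, Shape]
) -> Tuple[List[str], List[str], List[str]]:
    """Return (missing, unexpected, mismatched) parameter names."""
    # In target but not provided: these would stay randomly initialized.
    missing = sorted(k for k in target if k not in converted)
    # Provided but unknown to target: strict=False drops these silently.
    unexpected = sorted(k for k in converted if k not in target)
    # Same name, different shape: PyTorch raises on these even with strict=False.
    mismatched = sorted(k for k in target.keys() & converted.keys() if target[k] != converted[k])
    return missing, unexpected, mismatched


# Hypothetical key names, for illustration only.
converted = {"layers.2.blocks.5.w": (8,), "layers.2.blocks.6.w": (8,), "proj.w": (4,)}
target = {"layers.2.blocks.5.w": (8,), "layers.3.norm.w": (8,), "proj.w": (8,)}
missing, unexpected, mismatched = diff_state_dicts(converted, target)
print(missing)     # ['layers.3.norm.w']
print(unexpected)  # ['layers.2.blocks.6.w']
print(mismatched)  # ['proj.w']
```

A non-empty `unexpected` list for the audio encoder would mean some converted weights never reach the Hugging Face model.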

Conversion commands:

python convert_clap_original_pytorch_to_hf.py \
    --pytorch_dump_folder_path ./clap-htsat-base-unfused-music-audioset \
    --checkpoint_path ./pretrained_model/music_audioset_epoch_15_esc_90.14.pt \
    --config_path ./clap-htsat-base-unfused-music-audioset/config.json \
    --amodel HTSAT-base

python convert_clap_original_pytorch_to_hf.py \
    --pytorch_dump_folder_path ./clap-htsat-base-unfused-music-speech-audioset \
    --checkpoint_path ./pretrained_model/music_speech_audioset_epoch_15_esc_89.98.pt \
    --config_path ./clap-htsat-base-unfused-music-speech-audioset/config.json \
    --amodel HTSAT-base 

python convert_clap_original_pytorch_to_hf.py \
    --pytorch_dump_folder_path ./630k-audioset-best \
    --checkpoint_path ./pretrained_model/630k-audioset-best.pt \
    --config_path ./630k-audioset-best/config.json \
    --amodel HTSAT-tiny
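Before running the full ESC-50 evaluation, a quicker parity check is to embed the same few clips with the original model (for example with laion_clap's `get_audio_embedding_from_filelist`, as in the commented-out line of the evaluation script) and with the converted model (`get_audio_features`), then compare the results. A minimal numpy sketch, assuming both embeddings are arrays of shape `[n_clips, dim]`; the helper name is hypothetical:

```python
import numpy as np


def embedding_parity(ref: np.ndarray, conv: np.ndarray) -> dict:
    """Compare per-clip embeddings from the original (ref) and converted (conv) models."""
    ref_n = ref / np.linalg.norm(ref, axis=1, keepdims=True)
    conv_n = conv / np.linalg.norm(conv, axis=1, keepdims=True)
    cos = np.sum(ref_n * conv_n, axis=1)  # cosine similarity for each clip
    return {
        "min_cosine": float(cos.min()),
        "max_abs_diff": float(np.abs(ref - conv).max()),
    }


ref = np.array([[1.0, 0.0], [0.0, 2.0]])
conv = np.array([[1.0, 0.0], [0.0, -2.0]])
print(embedding_parity(ref, conv))  # {'min_cosine': -1.0, 'max_abs_diff': 4.0}
```

For a faithful conversion the per-clip cosine similarity should be close to 1. Differences in feature extraction (padding or truncation settings) can also lower it, so a low value narrows the search rather than proving the weights are wrong.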

My evaluation on ESC-50 (adapted from your eval code)

import glob
import json
import torch
import numpy as np
from transformers import ClapModel, ClapProcessor
import librosa

device = torch.device('cuda:0')

# download https://drive.google.com/drive/folders/1scyH43eQAcrBz-5fAw44C6RNBhC3ejvX?usp=sharing and extract ./ESC50_1/test/0.tar to ./ESC50_1/test/
esc50_test_dir = './ESC50_1/test/*/'
class_index_dict_path = './class_labels/ESC50_class_labels_indices_space.json'

# Load the model (for different converted model)
pretrained_model_path = "./clap-htsat-base-unfused-music-speech-audioset"
# pretrained_model_path = "./clap-htsat-base-unfused-music-audioset"
# pretrained_model_path = "./630k-audioset-best"
# pretrained_model_path = "laion/clap-htsat-unfused"
processor = ClapProcessor.from_pretrained(pretrained_model_path)
model = ClapModel.from_pretrained(pretrained_model_path)

# Get the class index dict
class_index_dict = json.load(open(class_index_dict_path))  # class name -> class index

# Get all the data
audio_files = sorted(glob.glob(esc50_test_dir + '**/*.flac', recursive=True))
json_files = sorted(glob.glob(esc50_test_dir + '**/*.json', recursive=True))

print("audio_files: ", len(audio_files))
print("json_files: ", len(json_files))

ground_truth_idx = [class_index_dict[json.load(open(jf))['tag'][0]] for jf in json_files]

with torch.no_grad():
    ground_truth = torch.tensor(ground_truth_idx).view(-1, 1)

    # Get text features
    all_texts = ["This is a sound of " + t for t in class_index_dict.keys()]

    inputs = processor(text=all_texts, return_tensors="pt", padding=True)
    text_embed = model.get_text_features(**inputs)
    print("text_embed: ", text_embed.shape)

    audio_input = []
    for audio_file in audio_files:
        audio_waveform, _ = librosa.load(audio_file, sr=48000)
        audio_input.append(audio_waveform)

    inputs = processor(audios=audio_input, return_tensors="pt", padding=True, sampling_rate=48000)
    audio_embed = model.get_audio_features(**inputs)

    print("audio_embed: ", audio_embed.shape)

    # audio_embed = model.get_audio_embedding_from_filelist(x=audio_files)

    ranking = torch.argsort(audio_embed @ text_embed.t(), descending=True)
    preds = torch.where(ranking == ground_truth)[1]
    preds = preds.cpu().numpy()

    metrics = {}
    metrics[f"mean_rank"] = preds.mean() + 1
    metrics[f"median_rank"] = np.floor(np.median(preds)) + 1
    for k in [1, 5, 10]:
        metrics[f"R@{k}"] = np.mean(preds < k)
    # map@10
    metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))

    print(
        f"Zeroshot Classification Results: "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )
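The metric code above can be checked in isolation on toy data. A numpy sketch of the same computation, taking an audio-by-class similarity matrix and integer labels; here R@1 is the top-1 zero-shot accuracy. The function name is hypothetical.

```python
import numpy as np


def zero_shot_metrics(similarity: np.ndarray, labels: np.ndarray) -> dict:
    """similarity: [n_audio, n_classes]; labels: [n_audio] integer class ids."""
    ranking = np.argsort(-similarity, axis=1)               # classes sorted best-first
    preds = np.argmax(ranking == labels[:, None], axis=1)   # 0-based rank of the true class
    metrics = {
        "mean_rank": preds.mean() + 1,
        "median_rank": np.floor(np.median(preds)) + 1,
    }
    for k in (1, 5, 10):
        metrics[f"R@{k}"] = np.mean(preds < k)
    metrics["mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
    return metrics


similarity = np.array([[0.9, 0.1, 0.0], [0.2, 0.3, 0.5]])
labels = np.array([0, 1])
metrics = zero_shot_metrics(similarity, labels)
print(metrics["R@1"], metrics["mAP@10"])  # 0.5 0.75
```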

Evaluation results

As the results show, the HTSAT-base models suffer an accuracy drop after conversion to the Hugging Face format. Could you please help us track down this bug, and perhaps upload Hugging Face versions of the CLAP models for music_speech_epoch_15_esc_89.25.pt and music_speech_audioset_epoch_15_esc_89.98.pt? Thanks!

RetroCirce commented 1 year ago

Hi!

Thank you for your question. Unfortunately, the Hugging Face implementation is not maintained by us (the CLAP authors) but by Hugging Face researchers (I think they are Younes Belkada and Arthur Zucker).

It would be better to open this issue in the Hugging Face transformers repo. That said, I believe our pip library provides the same functionality, so if your code does not depend heavily on Hugging Face transformers, you are welcome to use our pip library (see the README for details).
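A minimal sketch of the pip-library route for the same zero-shot setup. It is based on our reading of the laion_clap README: the `CLAP_Module` class with its `amodel` argument, `load_ckpt`, `get_audio_embedding_from_filelist` and `get_text_embedding` should be treated as assumptions to verify against the installed version. The helper names are hypothetical, and `laion_clap` is imported inside the function so the classifier can be used on its own.

```python
import numpy as np


def laion_clap_embeddings(ckpt_path, audio_files, texts, amodel="HTSAT-base"):
    """Embed audio files and texts with the laion_clap pip package (assumed API)."""
    import laion_clap  # lazy import: only needed when this function is called

    model = laion_clap.CLAP_Module(enable_fusion=False, amodel=amodel)
    model.load_ckpt(ckpt_path)
    audio_embed = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=False)
    text_embed = model.get_text_embedding(texts, use_tensor=False)
    return np.asarray(audio_embed), np.asarray(text_embed)


def zero_shot_classify(audio_embed: np.ndarray, text_embed: np.ndarray) -> np.ndarray:
    """Index of the most similar text prompt (cosine similarity) for each clip."""
    a = audio_embed / np.linalg.norm(audio_embed, axis=1, keepdims=True)
    t = text_embed / np.linalg.norm(text_embed, axis=1, keepdims=True)
    return np.argmax(a @ t.T, axis=1)


audio = np.array([[1.0, 0.0], [0.0, 1.0]])
text = np.array([[0.0, 2.0], [3.0, 0.0]])
print(zero_shot_classify(audio, text))  # [1 0]
```

With a real checkpoint, `laion_clap_embeddings("music_speech_audioset_epoch_15_esc_89.98.pt", audio_files, all_texts)` would feed `zero_shot_classify` directly; normalizing inside the classifier avoids relying on whether the library already returns unit-length embeddings.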

happylittlecat2333 commented 1 year ago

Thanks for your reply! My work depends heavily on Hugging Face style models like laion/clap-htsat-unfused, so it would be convenient to switch to another CLAP model, such as the music CLAP, with the same code. I will open an issue in the Hugging Face transformers repo. Thanks a lot!

PS: Since Hugging Face models like laion/clap-htsat-unfused are already published under LAION, I think it would be very convenient for users if your team converted the other CLAP models to the Hugging Face format and uploaded them under LAION, just like the CLIP model collection under LAION. :)