LAION-AI / CLAP

Contrastive Language-Audio Pretraining
https://arxiv.org/abs/2211.06687
Creative Commons Zero v1.0 Universal

Unable to reproduce zeroshot classification results #133

Open cyrusvahidi opened 12 months ago

cyrusvahidi commented 12 months ago

Overview

I have attempted to reproduce the zero-shot classification results for ESC-50 reported in the publication Large-scale contrastive language-audio pretraining with feature fusion and keyword-to-caption augmentation.

In the paper, zero-shot classification accuracy (top-1) for the best model (K2C aug) is reported as 91.0%. I assume this corresponds to the 630k-audioset-best.pt checkpoint.

Reproduce

I use the set of 50 unique captions in the test dataset, which are found in the text attribute of each example's JSON file, e.g. "The sound of the crow".
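
For reference, the loader below relies only on the tag and text fields of each JSON file, both lists. A minimal sketch of the assumed shape of one metadata file (the field values here are illustrative):

example_meta = {
    "tag": ["crow"],                    # class label(s)
    "text": ["The sound of the crow"],  # caption(s) used for zero-shot matching
}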

Here's the loader for ESC-50:

import glob
import json
import os.path as osp
from pathlib import Path

from torch.utils.data import Dataset

# load_audio_torch and random_slice are local helpers (not shown in this issue)
# for loading audio at a target sample rate and taking a fixed-length slice.

class ESC50Dataset(Dataset):
    def __init__(
        self,
        path_to_esc50="./data/ESC50",
        split="test",
        audio_len=480000,  # 480,000 samples = 10 s at 48 kHz
    ):
        super().__init__()
        self.data_path = Path(path_to_esc50)
        self.audio_len = audio_len

        # Pair audio files with metadata files by their sorted basenames
        self.audio_files = sorted(glob.glob(str(self.data_path / split / "*.flac")))
        self.meta_files = sorted(glob.glob(str(self.data_path / split / "*.json")))
        assert len(self.audio_files) == len(self.meta_files), "Number of audio files and meta files must match"
        assert [osp.splitext(osp.basename(x))[0] for x in self.audio_files] == [osp.splitext(osp.basename(x))[0] for x in self.meta_files], "Audio files and meta files must have the same names"

        # Collect each example's class tag and caption from its JSON metadata
        self.tags = []
        self.texts = []
        for f in self.meta_files:
            with open(f, "r") as json_file:
                data = json.load(json_file)
                self.tags.append(data["tag"][0])
                self.texts.append(data["text"][0])

    def __getitem__(self, idx):
        x, _ = load_audio_torch(self.audio_files[idx], target_sr=48000, mono=True)
        x = random_slice(x, self.audio_len)
        return x, self.texts[idx]

    def __len__(self):
        return len(self.audio_files)
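
As a quick sanity check, one item from the dataset should be a fixed-length waveform paired with its caption (assuming the ESC-50 test split has been preprocessed into the flac/json layout above):

dataset = ESC50Dataset()
x, text = dataset[0]
print(len(dataset), x.shape, text)  # dataset size, waveform shape, and its caption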

And the zero-shot classification script:

import os

import torch
import laion_clap

from data.loaders import ESC50Dataset

# Non-fusion HTSAT-tiny model with the 630k-audioset-best checkpoint
ckpt_path = "CLAP_checkpoints/laion_clap/"
model_params = {"ckpt": "630k-audioset-best.pt", "amodel": "HTSAT-tiny"}
model = laion_clap.CLAP_Module(enable_fusion=False, amodel=model_params["amodel"])
model.load_ckpt(os.path.join(ckpt_path, model_params["ckpt"]))

dataset = ESC50Dataset()
texts = list(set(dataset.texts))  # the 50 unique captions, e.g. "The sound of the crow"

# Embed each caption; the caption is passed twice and only the first row kept,
# presumably to sidestep single-item batch issues in get_text_embedding
z_text = torch.cat([torch.tensor(model.get_text_embedding([t, t])[0:1]) for t in texts])

# Embed each audio clip and record the index of its ground-truth caption
z_audio = []
text_idxs = []
for x, text in dataset:
    text_idxs.append(texts.index(text))
    z_audio.append(torch.tensor(model.get_audio_embedding_from_data(x.numpy())))
z_audio = torch.cat(z_audio)

# Pairwise audio-text similarities, scaled by the model's audio logit scale
sim = model.model.logit_scale_a.cpu() * z_audio @ z_text.T

# Top-1 accuracy: the most similar caption must match the ground truth
acc = float(torch.sum(torch.argmax(sim, dim=1) == torch.tensor(text_idxs)) / len(sim))
print(f"Accuracy: {acc}")

Accuracy: 0.6025000214576721

Hopefully I am missing something significant?

cyrusvahidi commented 11 months ago

OK, I managed to reproduce it:

Zeroshot Classification Results: mean_rank: 2.7344      median_rank: 1.0000     R@1: 0.5925     R@5: 0.8981     R@10: 0.9525    mAP@10: 0.7200
Accuracy: 0.5925 over 1600 samples

It seems that top-5 accuracy was reported in the paper. I was confused, as Section 4.3 of the paper states "We use top-1 accuracy as the metric."
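
For anyone checking the numbers: given the sim matrix and text_idxs from the script above, top-5 accuracy (R@5) takes only a few extra lines (a minimal sketch):

# Is the ground-truth caption among the 5 most similar captions?
top5 = torch.topk(sim, k=5, dim=1).indices      # (N, 5) caption indices per clip
targets = torch.tensor(text_idxs).unsqueeze(1)  # (N, 1) ground-truth indices
r5 = float((top5 == targets).any(dim=1).float().mean())
print(f"R@5: {r5:.4f}")  # ~0.8981 on the run above, close to the reported 91.0%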