microsoft / CLAP

Learning audio concepts from natural language supervision
MIT License
455 stars, 35 forks

Inconsistent results, Mean Square Error is high between runs. #24

Closed: hykilpikonna closed this issue 11 months ago

hykilpikonna commented 11 months ago

This is based on the code in the master branch of this repository. Two runs of get_audio_embeddings on the same audio file do not produce identical results; the MSE between the two embeddings is 0.1367.


My testing code:

# Load model (Choose between versions '2022' or '2023')
from CLAPWrapper import CLAPWrapper as CLAP 
import torch

with torch.no_grad():

    clap_model = CLAP("/Users/azalea/Downloads/CLAP_weights_2023.pth", version = '2023', use_cuda=False)

    # Extract text embeddings
    # text_embeddings = clap_model.get_text_embeddings(class_labels: List[str])

    audio_embeddings_1 = clap_model.get_audio_embeddings(["/Users/azalea/Downloads/test.wav"])
    audio_embeddings_2 = clap_model.get_audio_embeddings(["/Users/azalea/Downloads/test.wav"])

    print(audio_embeddings_1)
    print(audio_embeddings_2)

    # Compute mean square error
    mse = torch.mean((audio_embeddings_1 - audio_embeddings_2)**2)
    print(mse)

OS: macOS Sonoma 14.0
Python: Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:39:40) [Clang 15.0.7 ]

The audio file metadata:

> ffprobe test.wav -hide_banner
Input #0, wav, from 'test.wav':
  Metadata:
    artist          : ハンバート ハンバート
    date            : 2006
    title           : 日が落ちるまで
    album           : 道はつづく
    track           : 7
    encoder         : Lavf60.3.100
  Duration: 00:04:56.49, bitrate: 1411 kb/s
  Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 44100 Hz, 2 channels, s16, 1411 kb/s

soham97 commented 11 months ago

Hi @hykilpikonna, the code supports audio files up to 7 seconds long. If you pass a longer file (yours is about 5 minutes), it randomly samples a 7-second segment and computes the embeddings/predictions on that segment only. I think this is the cause of the high variance between runs.
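
To illustrate why two calls on the same file disagree, here is a minimal sketch of random cropping. It assumes a mono waveform already loaded into a NumPy array and an arbitrary 16 kHz sample rate; `random_crop` is a hypothetical helper that mimics the behavior described above and is not the actual code in CLAPWrapper.py. Each call draws a new start offset, so unseeded calls almost always return different windows, while reusing the same seed reproduces the same window.

```python
import numpy as np


def random_crop(waveform: np.ndarray, sample_rate: int, duration_s: float,
                rng: np.random.Generator) -> np.ndarray:
    """Return a random `duration_s`-second window of a 1-D waveform.

    Hypothetical helper that mimics the cropping described in this thread;
    it is not the code in CLAPWrapper.py.
    """
    window = int(duration_s * sample_rate)
    if waveform.shape[0] <= window:
        # Clips no longer than the window are returned unchanged (no padding here).
        return waveform
    # Uniform start offset in [0, len - window] so the window fits in the clip.
    start = int(rng.integers(0, waveform.shape[0] - window + 1))
    return waveform[start:start + window]


sample_rate = 16_000  # assumed rate for this sketch
long_clip = np.random.default_rng(0).standard_normal(60 * sample_rate)  # 60 s of noise

# Two unseeded generators: the windows almost always differ.
crop_a = random_crop(long_clip, sample_rate, 7, np.random.default_rng())
crop_b = random_crop(long_clip, sample_rate, 7, np.random.default_rng())
print(np.array_equal(crop_a, crop_b))  # almost always False

# Same seed for both calls: the windows are identical.
crop_c = random_crop(long_clip, sample_rate, 7, np.random.default_rng(42))
crop_d = random_crop(long_clip, sample_rate, 7, np.random.default_rng(42))
print(np.array_equal(crop_c, crop_d))  # True
```

Fixing the seed only makes the output repeatable; the embedding still describes one arbitrary 7-second window of a roughly 5-minute track rather than the whole recording.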

I would recommend either splitting your file into clips of 7 seconds or shorter, or updating CLAPWrapper.py to chunk the audio and accumulate the predictions.
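
A minimal sketch of the chunk-and-accumulate idea, assuming the audio has already been downmixed to mono and loaded as a NumPy array (the file in this issue is 44.1 kHz stereo; the 16 kHz rate below is an arbitrary choice for the sketch). `chunk_waveform`, `embed_long_audio` and `toy_embed_chunk` are hypothetical names: `toy_embed_chunk` is a stand-in statistic, not a CLAP API, and would be replaced by a call into the model. Averaging the per-chunk embeddings (mean pooling) is one assumed way to accumulate them.

```python
from typing import Callable

import numpy as np


def chunk_waveform(waveform: np.ndarray, sample_rate: int,
                   chunk_s: float = 7.0) -> np.ndarray:
    """Split a 1-D waveform into consecutive chunks of `chunk_s` seconds.

    The final partial chunk is zero-padded. Returns shape (n_chunks, chunk_len).
    """
    chunk_len = int(chunk_s * sample_rate)
    n_chunks = max(1, int(np.ceil(waveform.shape[0] / chunk_len)))
    padded = np.zeros(n_chunks * chunk_len, dtype=waveform.dtype)
    padded[:waveform.shape[0]] = waveform
    return padded.reshape(n_chunks, chunk_len)


def embed_long_audio(waveform: np.ndarray, sample_rate: int,
                     embed_chunk: Callable[[np.ndarray], np.ndarray],
                     chunk_s: float = 7.0) -> np.ndarray:
    """Embed every chunk and mean-pool the results into one vector."""
    chunks = chunk_waveform(waveform, sample_rate, chunk_s)
    per_chunk = np.stack([embed_chunk(chunk) for chunk in chunks])
    return per_chunk.mean(axis=0)


def toy_embed_chunk(chunk: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a real audio encoder.
    return np.array([chunk.mean(), chunk.std(), np.abs(chunk).max()])


sample_rate = 16_000  # assumed rate for this sketch
duration_s = 296.49   # 00:04:56.49, the duration reported by ffprobe
waveform = np.random.default_rng(0).standard_normal(
    int(duration_s * sample_rate), dtype=np.float32)

chunks = chunk_waveform(waveform, sample_rate)
print(chunks.shape)  # (43, 112000)

# No random cropping, so repeated calls give identical embeddings.
emb_1 = embed_long_audio(waveform, sample_rate, toy_embed_chunk)
emb_2 = embed_long_audio(waveform, sample_rate, toy_embed_chunk)
print(np.array_equal(emb_1, emb_2))  # True
```

Zero padding is a choice made for this sketch: here the last chunk holds only about 2.49 s of audio (296.49 s minus 42 full chunks of 7 s), so repeating or dropping that chunk are reasonable alternatives if the padding skews the average.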

hykilpikonna commented 11 months ago

Thank you for the clarification.