yizhilll / MERT

Official implementation of the paper "Acoustic Music Understanding Model with Large-Scale Self-supervised Training".
Apache License 2.0

Which embedding to use for the downstream task of genre classification? #5

Closed: elloza closed this issue 11 months ago

elloza commented 1 year ago

Hello,

First of all, congratulations on your work!

I was putting together a simple example to extract song embeddings, and I was wondering which final embedding works best for a genre classification task.

```python
from transformers import Wav2Vec2FeatureExtractor, AutoModel
import torch
import torchaudio
from torch import nn
import torchaudio.transforms as T
from datasets import load_dataset

# Loading the model and processor
model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)

# Load demo audio dataset (not used below; replaced by the local file)
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")

# Replace with the path to your own audio file
audio_path = "232683.mp3"
resample_rate = processor.sampling_rate

# Load the audio and resample it to the model's sampling rate if needed
waveform, sampling_rate = torchaudio.load(audio_path, normalize=True)
if resample_rate != sampling_rate:
    resampler = T.Resample(sampling_rate, resample_rate)
    waveform = resampler(waveform)

# Crop the audio to 30 seconds
target_duration = 30  # seconds
target_num_frames = int(target_duration * resample_rate)
if waveform.size(1) > target_num_frames:
    waveform = waveform[:, :target_num_frames]

# Downmix to mono by averaging the channels
waveform = waveform.mean(dim=0, keepdim=True)

# Flatten the mono waveform to a single dimension
waveform = waveform.view(-1)

# Extract features using the Wav2Vec2 processor
inputs = processor(waveform.numpy(), sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# take a look at the output shape, there are 25 layers of representation
# each layer performs differently in different downstream tasks, you should choose empirically
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
print(all_layer_hidden_states.shape) # [25 layer, Time steps, 1024 feature_dim]

# for utterance level classification tasks, you can simply reduce the representation in time
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
print(time_reduced_hidden_states.shape) # [25, 1024]

# you can even use a learnable weighted average representation
aggregator = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
print(weighted_avg_hidden_states.shape) # [1024]
```
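The comment above notes that each of the 25 layers performs differently across downstream tasks and that the choice should be made empirically. One way to do that, sketched below as an assumption rather than anything from the MERT codebase, is to fit a lightweight probe on each time-reduced layer embedding and keep the layer that scores best; `layer_feats` and `labels` are hypothetical placeholders standing in for embeddings and genre labels precomputed over a labelled dataset with the snippet above.

```python
# Hedged sketch: choose the best MERT layer for genre classification by
# probing each layer with a simple logistic-regression classifier.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Placeholders: in practice, stack the time-reduced hidden states from the
# snippet above for every clip in a labelled genre dataset.
num_clips, num_layers, feat_dim = 200, 25, 1024
layer_feats = np.random.randn(num_clips, num_layers, feat_dim)  # [clips, 25, 1024]
labels = np.random.randint(0, 10, size=num_clips)               # genre ids

scores = []
for layer in range(num_layers):
    probe = LogisticRegression(max_iter=1000)
    # 5-fold cross-validated accuracy of a linear probe on this layer's features
    acc = cross_val_score(probe, layer_feats[:, layer, :], labels, cv=5).mean()
    scores.append(acc)

best_layer = int(np.argmax(scores))
print(f"best layer: {best_layer}, cv accuracy: {scores[best_layer]:.3f}")
```

With random placeholders the accuracies are meaningless, but run on real embeddings this gives a quick per-layer comparison before committing to the learnable weighted average.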

I had a couple of questions:

  1. Should the audio be converted to mono before it is passed to the model, as done above?
  2. How many intermediate MLP layers would you recommend between the transformer output and the final classifier?

Thank you very much in advance!

yizhilll commented 1 year ago

@elloza thanks for your questions.

Since a similar issue was raised on our Hugging Face Hub, I recommend referring to the answers in that issue as well.

For your two questions specifically:

  1. Yes, we process the data into mono audio.
  2. The number of intermediate MLP layers between the transformer output and the classifier is a hyper-parameter that depends on your task. For classification we usually search between 0 and 3 layers; see the sketch below.
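To make point 2 concrete, here is a minimal sketch of such a probe head. It is not part of the MERT codebase: `GenreProbe`, `hidden_dim=512`, `num_genres=10`, and the dropout rate are illustrative assumptions, and `num_mlp_layers` is the hyper-parameter you would search over (0 gives a plain linear probe).

```python
import torch
from torch import nn

class GenreProbe(nn.Module):
    """Classification head on top of a (frozen) 1024-dim MERT embedding."""
    def __init__(self, embed_dim=1024, hidden_dim=512, num_mlp_layers=1, num_genres=10):
        super().__init__()
        layers, in_dim = [], embed_dim
        for _ in range(num_mlp_layers):  # 0 layers -> plain linear classifier
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.1)]
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, num_genres))
        self.net = nn.Sequential(*layers)

    def forward(self, x):   # x: [batch, 1024]
        return self.net(x)  # logits: [batch, num_genres]

# Usage: feed it the 1024-dim aggregated embedding from the snippet above
# (a random tensor stands in for it here).
probe = GenreProbe(num_mlp_layers=2)
logits = probe(torch.randn(1, 1024))
print(logits.shape)  # torch.Size([1, 10])
```

Searching `num_mlp_layers` over 0-3 (and the hidden width) on a validation split is then an ordinary hyper-parameter sweep.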