Jungjee / RawNet

Official repository for RawNet, RawNet2, and RawNet3

trained model for RawNet2_modified and RawNet2 #21

Closed hdubey closed 2 years ago

hdubey commented 2 years ago

Hi, can you share the trained models for RawNet2 and RawNet2_modified for quick testing? Do you have a script for extracting speaker embeddings from a wav file?

ac-alpha commented 2 years ago

@hdubey you can find the weights for RawNet2 here.

You can use this script for quick testing and getting the embeddings. Make sure that you have this model definition in the directory from which you run the script.

from tqdm import tqdm
from collections import OrderedDict

import os
import argparse
import json
import numpy as np
import glob
import pickle

import torch
import torch.nn as nn
from torch.utils import data

from dataloader import *
from model_RawNet2 import RawNet2
from parser import get_args
from trainer import *
from utils import *
from model_RawNet2_original_code import *
from pydub import AudioSegment
import soundfile as sf  # needed for sf.read() below

load_model_dir = "Pre-trained_model/rawnet2_best_weights.pt"
test_wav_path1 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10270/5r0dWxy17C8/00001.wav"
test_wav_path2 = "/root/host/ml-speaker-verification/data/vox1/vox1_test_wav/id10278/d6WJf6TOoIQ/00001.wav"

test_wav_path3 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/01dfn2spqyE/00001.m4a"
test_wav_path4 = "/root/host/ml-speaker-verification/data/vox2/vox2_test_m4a/id00017/8_a6O3vdlU0/00021.m4a"

def cos_sim(a, b):
    '''Cosine similarity between two embedding vectors.'''
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def read_wav_and_get_clip_tensor(test_wav_path, nb_samp, window_size, wav_file=True):
    '''
    Read an audio file and split it into (possibly overlapping) clips of
    nb_samp samples; returns a tensor of shape (n_clips, nb_samp).
    '''
    if not wav_file:
        X = AudioSegment.from_file(test_wav_path)
        X = X.get_array_of_samples()
        X = np.array(X)
    else:
        X, _ = sf.read(test_wav_path)
    X = X.astype(np.float64)
    X = _normalize_scale(X).astype(np.float32)
    X = X.reshape(1,-1)

    nb_time = X.shape[1]
    list_X = []
    if nb_time < nb_samp:
        nb_dup = int(nb_samp / nb_time) + 1
        list_X.append(np.tile(X, (1, nb_dup))[:, :nb_samp][0])
    elif nb_time > nb_samp:
        step = nb_samp - window_size
        iteration = int( (nb_time - window_size) / step ) + 1
        for i in range(iteration):
            if i == 0:
                list_X.append(X[:, :nb_samp][0])
            elif i < iteration - 1:
                list_X.append(X[:, i*step : i*step + nb_samp][0])
            else:
                list_X.append(X[:, -nb_samp:][0])
    else:
        list_X.append(X[0])
    return torch.from_numpy(np.asarray(list_X))

def get_embedding_from_clip_tensor(clip_tensor, model, device):
    '''
    Run each clip through the model and average the clip-level embeddings
    into a single utterance-level speaker embedding.
    '''
    model.eval()

    with torch.set_grad_enabled(False):
        #1st, extract speaker embeddings.
        l_embeddings = []
        l_code = []
        mbatch = clip_tensor
        mbatch = mbatch.unsqueeze(1)
#         print("Batch size = {}".format(mbatch.size()))
        for batch in mbatch:
            batch = batch.to(device)
            code = model(x = batch, is_test=True)
#             print("Code size = {}".format(code.size()))
            l_code.extend(code.cpu().numpy())
        embedding = np.mean(l_code, axis=0)
#         print("Embedding shape = {}".format(embedding.shape))
        return embedding

def _normalize_scale(x):
    '''
    Normalize the sample scale in the same way as SincNet.
    '''
    return x/np.max(np.abs(x))

def main_test():
    #parse arguments
    args = get_args()

    wav_path = args.wav_path
    save_path = args.sav_path
    direc_level = args.direc_level
    wav_file = True if args.wav_file==1 else False

    ## Number of speakers in the VoxCeleb2 training set.
    ## Not used when computing embeddings, but required to build the model.
    ## Do not comment this out.
    args.model['nb_classes'] = 6112

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    #device setting
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    print('Device: {}'.format(device))

    model = RawNet(args.model, device).to(device)
    model.load_state_dict(torch.load(load_model_dir, map_location=device))
    nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
    nb_samp = args.model["nb_samp"]
    window_size = args.window_size
    print('nb_params: {}'.format(nb_params))

    X1 = read_wav_and_get_clip_tensor(test_wav_path3, nb_samp, window_size, wav_file)
    emb_X1 = get_embedding_from_clip_tensor(X1, model, device)

    X2 = read_wav_and_get_clip_tensor(test_wav_path4, nb_samp, window_size, wav_file)
    emb_X2 = get_embedding_from_clip_tensor(X2, model, device)

    sim_score = cos_sim(emb_X1, emb_X2)
    print("Similarity = {}".format(sim_score))

if __name__ == '__main__':
    main_test()
Jungjee commented 2 years ago

@ac-alpha thanks for the reply :) I'll close this

hdubey commented 2 years ago

@ac-alpha thanks. Using the above script and the provided model leads to the following errors. Is it RawNet, RawNet2, or RawNet2_modified?

RuntimeError: Error(s) in loading state_dict for RawNet:
Unexpected key(s) in state_dict: "block2.0.conv_downsample.weight", "block2.0.conv_downsample.bias".
size mismatch for block2.0.bn1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.bn1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
size mismatch for block2.0.conv1.weight: copying a param with shape torch.Size([256, 128, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3]).
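Errors like this usually mean the checkpoint was saved from a different model variant than the one being instantiated. A minimal diagnostic sketch (assuming the same RawNet class, parser, and checkpoint path as in the script above) that compares the checkpoint keys and shapes against the model's own state_dict before calling load_state_dict:

import torch
from parser import get_args
from model_RawNet2_original_code import RawNet  # assumed source of the RawNet class used above

args = get_args()
args.model['nb_classes'] = 6112  # as in the script above

# Load only the checkpoint to inspect it; map to CPU so no GPU is needed.
ckpt = torch.load("Pre-trained_model/rawnet2_best_weights.pt", map_location="cpu")

model = RawNet(args.model, torch.device("cpu"))
model_sd = model.state_dict()

# Keys present in the checkpoint but not in the model, and vice versa.
print("Unexpected in checkpoint:", sorted(set(ckpt) - set(model_sd)))
print("Missing from checkpoint: ", sorted(set(model_sd) - set(ckpt)))

# Shape mismatches for keys that exist in both.
for k in sorted(set(ckpt) & set(model_sd)):
    if ckpt[k].shape != model_sd[k].shape:
        print(k, tuple(ckpt[k].shape), "vs", tuple(model_sd[k].shape))

The unexpected conv_downsample keys and the halved channel sizes reported above point at an architecture mismatch between the checkpoint and the model definition being loaded into, rather than at the loading code itself.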

Jungjee commented 2 years ago

Closing this now, as I have uploaded RawNet3 and a script to extract a speaker embedding from any 16 kHz, 16-bit mono utterance.
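Since the new script expects 16 kHz, 16-bit mono input, it can be worth checking a file's format before passing it in. A small illustrative pre-check with soundfile (not part of the repository; "utt.wav" is a hypothetical file name):

import soundfile as sf

info = sf.info("utt.wav")  # hypothetical input file
print(info.samplerate, info.channels, info.subtype)

# The RawNet3 extraction script expects 16 kHz, 16-bit ("PCM_16") mono audio.
if info.samplerate != 16000 or info.channels != 1 or info.subtype != "PCM_16":
    print("Convert first, e.g.: ffmpeg -i in.m4a -ar 16000 -ac 1 -c:a pcm_s16le out.wav")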