BUTSpeechFIT / VBx

Variational Bayes HMM over x-vectors diarization

Speaker loop elimination #28

videodanchik closed this issue 3 years ago

videodanchik commented 3 years ago

Hi guys, I have a small update for your main VB_diarization function. It removes almost all of the inner speaker loop and also speeds things up a little. I tried to make a pull request but could not find a way to push a new branch, so I suggest replacing the following lines:

for ii in range(maxIters):
    L = 0 # objective function (37) (i.e. VB lower-bound on the evidence)
    Ns = gamma.sum(0)                                     # bracket in eq. (34) for all 's'
    VtiEFs = gamma.T.dot(VtiEF)                           # eq. (35) except for \Lambda_s^{-1} for all 's'
    for sid in range(maxSpeakers):
        invL = np.linalg.inv(np.eye(R) + Ns[sid]*VtiEV*Fa/Fb) # eq. (34) inverse
        a = invL.dot(VtiEFs[sid])*Fa/Fb                                        # eq. (35)
        # eq. (29) except for the prior term \ln \pi_s. Our prior is given by HMM
        # transition probability matrix. Instead of eq. (30), we need to use
        # forward-backward algorithm to calculate per-frame speaker posteriors,
        # where 'lls' plays role of HMM output log-probabilities
        lls[:,sid] = Fa * (G + VtiEF.dot(a) - 0.5 * ((invL+np.outer(a,a)) * VtiEV).sum())
        L += Fb* 0.5 * (logdet(invL) - np.sum(np.diag(invL) + a**2, 0) + R)

with the following code:

for ii in range(maxIters):
    L = 0 # objective function (37) (i.e. VB lower-bound on the evidence)
    Ns = np.sum(gamma, axis=0)[:, np.newaxis, np.newaxis]               # bracket in eq. (34) for all 's'
    VtiEFs = gamma.T.dot(VtiEF)[:, :, np.newaxis]                       # eq. (35) except for \Lambda_s^{-1} for all 's'
    invLs = np.linalg.inv(np.eye(R)[np.newaxis, :, :] + Ns * VtiEV[np.newaxis, :, :] * Fa / Fb)  # eq. (34) inverse
    a = np.matmul(invLs, VtiEFs).squeeze(axis=-1) * Fa / Fb  # eq. (35)
    # eq. (29) except for the prior term \ln \pi_s. Our prior is given by HMM
    # transition probability matrix. Instead of eq. (30), we need to use
    # forward-backward algorithm to calculate per-frame speaker posteriors,
    # where 'lls' plays role of HMM output log-probabilities
    lls = Fa * (
            G[:, np.newaxis] + VtiEF.dot(a.T) - 0.5 *
            ((invLs + np.matmul(a[:, :, np.newaxis], a[:, np.newaxis, :])) * VtiEV[np.newaxis, :, :]).sum(axis=(1, 2)))
    for sid in range(maxSpeakers):
        L += Fb* 0.5 * (logdet(invLs[sid]) - np.sum(np.diag(invLs[sid]) + a[sid]**2, 0) + R)

As you can see, I haven't vectorized the last line in the speaker loop, because that would also require modifying the logdet function. To make it work on multiple matrices, spl.cholesky would have to be replaced with np.linalg.cholesky, and for some reason this slows processing down (probably because the numpy version performs some additional checks before doing the Cholesky decomposition). Finally, lls = np.zeros_like(gamma) can be removed, since lls is now assigned as a whole array in one step. What do you think?
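For illustration only, the remaining per-speaker line could be vectorized roughly as sketched below. This assumes the matrices are symmetric positive definite and relies on np.linalg.cholesky accepting a stack of matrices. The names batched_logdet and lower_bound_term are hypothetical and not part of the repository, the inputs are random stand-ins, and the sketch says nothing about the speed concern raised above.

```python
import numpy as np

def batched_logdet(Ms):
    """Log-determinants of a stack of SPD matrices: (S, R, R) -> (S,).

    Hypothetical helper, not the repository's logdet.
    """
    chol = np.linalg.cholesky(Ms)                  # one lower-triangular factor per matrix
    diag = np.diagonal(chol, axis1=-2, axis2=-1)   # (S, R) diagonals of the factors
    return 2.0 * np.sum(np.log(diag), axis=-1)

def lower_bound_term(invLs, a, Fb):
    """Sum over speakers of the per-speaker term added to L in the loop."""
    R = invLs.shape[-1]
    traces = np.trace(invLs, axis1=-2, axis2=-1)   # equals sum(diag(invL)) per speaker
    return Fb * 0.5 * np.sum(batched_logdet(invLs) - traces - np.sum(a**2, axis=-1) + R)

# Random stand-ins for invLs and a (assumed shapes, not real model quantities).
rng = np.random.default_rng(0)
S, R, Fb = 3, 4, 2.0
B = rng.standard_normal((S, R, R))
invLs = np.matmul(B, B.transpose(0, 2, 1)) + np.eye(R)  # SPD stack
a = rng.standard_normal((S, R))

L_vec = lower_bound_term(invLs, a, Fb)

# Reference: the per-speaker loop from the snippet, with slogdet standing in for logdet.
L_loop = sum(
    Fb * 0.5 * (np.linalg.slogdet(invLs[s])[1] - np.sum(np.diag(invLs[s]) + a[s]**2, 0) + R)
    for s in range(S)
)
print(np.isclose(L_vec, L_loop))
```

Whether this is faster than the loop with spl.cholesky would need to be measured for the actual number of speakers and subspace rank.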

fnlandini commented 3 years ago

Hello, sorry for the delay. I am currently busy with a different project that is taking all my time. Thank you for the suggestions. I will take a look at your proposed changes and try them out later this year. I apologize again for the delay. I'll keep the issue open until then.

fnlandini commented 3 years ago

Hi @videodanchik, after checking, I got the same results, so I have accepted the pull request. For some reason, I needed to update the numpy version to import fastcluster successfully. Perhaps it is not strictly necessary (it may be a problem in my environment), but it does no harm to require a newer numpy. Thanks a lot for the contribution! Federico
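A check of this kind can be sketched on random data as follows. This is not the test that was actually run: the function names loop_update and vectorized_update, the array shapes, and the values of Fa and Fb are all assumptions chosen for illustration, and only the lls computation from the two snippets is compared.

```python
import numpy as np

def loop_update(gamma, VtiEF, VtiEV, G, Fa, Fb):
    """lls computed with the per-speaker loop, following the original snippet."""
    R = VtiEV.shape[0]
    lls = np.zeros_like(gamma)
    Ns = gamma.sum(0)
    VtiEFs = gamma.T.dot(VtiEF)
    for sid in range(gamma.shape[1]):
        invL = np.linalg.inv(np.eye(R) + Ns[sid] * VtiEV * Fa / Fb)
        a = invL.dot(VtiEFs[sid]) * Fa / Fb
        lls[:, sid] = Fa * (G + VtiEF.dot(a) - 0.5 * ((invL + np.outer(a, a)) * VtiEV).sum())
    return lls

def vectorized_update(gamma, VtiEF, VtiEV, G, Fa, Fb):
    """lls computed for all speakers at once, following the proposed snippet."""
    R = VtiEV.shape[0]
    Ns = np.sum(gamma, axis=0)[:, np.newaxis, np.newaxis]        # (S, 1, 1)
    VtiEFs = gamma.T.dot(VtiEF)[:, :, np.newaxis]                # (S, R, 1)
    invLs = np.linalg.inv(np.eye(R)[np.newaxis] + Ns * VtiEV[np.newaxis] * Fa / Fb)
    a = np.matmul(invLs, VtiEFs).squeeze(axis=-1) * Fa / Fb      # (S, R)
    outer = np.matmul(a[:, :, np.newaxis], a[:, np.newaxis, :])  # (S, R, R)
    quad = ((invLs + outer) * VtiEV[np.newaxis]).sum(axis=(1, 2))  # (S,)
    return Fa * (G[:, np.newaxis] + VtiEF.dot(a.T) - 0.5 * quad)

# Random stand-in inputs (assumed shapes: T frames, S speakers, rank R).
rng = np.random.default_rng(0)
T, S, R = 50, 3, 4
gamma = rng.random((T, S))
gamma /= gamma.sum(axis=1, keepdims=True)   # rows behave like per-frame posteriors
VtiEF = rng.standard_normal((T, R))
B = rng.standard_normal((R, R))
VtiEV = B @ B.T + np.eye(R)                 # symmetric positive definite
G = rng.standard_normal(T)
Fa, Fb = 0.4, 11.0                          # arbitrary illustrative values

lls_loop = loop_update(gamma, VtiEF, VtiEV, G, Fa, Fb)
lls_vec = vectorized_update(gamma, VtiEF, VtiEV, G, Fa, Fb)
same = np.allclose(lls_loop, lls_vec)
print(same)
```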