google / uis-rnn

This is the library for the Unbounded Interleaved-State Recurrent Neural Network (UIS-RNN) algorithm, corresponding to the paper Fully Supervised Speaker Diarization.
https://arxiv.org/abs/1810.04719
Apache License 2.0
1.55k stars 320 forks

Model predicts new cluster for each input after calling load() #13

Closed dalonlobo closed 5 years ago

dalonlobo commented 5 years ago

Hi,

I've loaded a saved model (trained on a custom dataset) using model.load. When I predict on the test set with model.predict, instead of meaningful labels I get a sequence of consecutive numbers, which appears to start at the length of the sequence passed to predict. Below is a screenshot for reference.

[screenshot of predicted labels]

Thank you in advance.

wq2012 commented 5 years ago

UIS-RNN is a supervised sequence clustering algorithm, not classification.

Different numbers indicate different clusters. The only purpose of the labels is to distinguish the clusters from each other.
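Since the predicted labels are arbitrary cluster IDs, two predictions that differ only by a relabeling are the same clustering. A minimal sketch of a permutation-invariant comparison against ground truth (`best_match_accuracy` is an illustrative helper, not part of the library; the library's demo uses `uisrnn.compute_sequence_match_accuracy` for the same purpose):

```python
from itertools import permutations

def best_match_accuracy(truth, pred):
    """Fraction of positions that agree under the best relabeling of pred."""
    t_labels = sorted(set(truth))
    p_labels = sorted(set(pred))
    # pad with dummies so every predicted cluster can map somewhere
    targets = t_labels + [None] * max(0, len(p_labels) - len(t_labels))
    best = 0.0
    for perm in permutations(targets, len(p_labels)):
        mapping = dict(zip(p_labels, perm))
        hits = sum(mapping[p] == t for p, t in zip(pred, truth))
        best = max(best, hits / len(truth))
    return best

print(best_match_accuracy([0, 0, 1, 1], [7, 7, 3, 3]))  # 1.0: same clustering
print(best_match_accuracy([0, 0, 1, 1], [0, 1, 2, 3]))  # 0.5: one cluster per input
```

The second call illustrates the failure mode reported above: when every input gets its own cluster, accuracy collapses even under the best relabeling.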

wq2012 commented 5 years ago

BTW, your predictions look incorrect. It looks like the model is treating each input as a separate cluster.

Did your training converge? Are you training/testing on the toy dataset under data/?

dalonlobo commented 5 years ago

I've trained it on the VCTK dataset (186834 embeddings in total); the loss curve is shown below:

[training loss plot]

I understand that it's clustering, but if you look at the sequence of predicted outputs, they are consecutive numbers starting at the length of the input sequence.

I've labeled 5 minutes of 2 interview videos from YouTube and used that as the test dataset.

dalonlobo commented 5 years ago

The predicted values look acceptable when the model enters the testing phase immediately after training, using the same demo.py code (figure below):

[screenshot of predictions after training]

But if I load the saved model and run inference, it produces those consecutive sequences, like below:

[screenshot of predictions after load]

The model does not look like it's converging very well. Any suggestions on tuning the hyperparameters, or does it simply need more data?

77281900000 commented 5 years ago

@dalonlobo I have the same problem. I tried running the toy data used in the demo again, and I got a different result. I describe this problem in detail in the latest issues. Do you have the same problem? And could you share your solution if you have any ideas?

wq2012 commented 5 years ago

Thanks @dalonlobo and @gaodihe for reporting this. Seems like a bug. We will look at it.

wq2012 commented 5 years ago

@dalonlobo About the training: supervised diarization is a hard problem, so you may need more data. Also, you were training with VCTK, but I don't think VCTK includes speaker turns (aren't they single-speaker utterances?).

The purpose of UIS-RNN is to learn the patterns of speaker turns from examples. But if you are creating those speaker turns purely by manually concatenating utterances at random, I guess there's not much the model can learn.

Please try to train the model on some real multi-speaker utterances that include natural speaker turns.
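For reference, the training input (as in demo.py) is a single `(N, d)` observation array with a parallel list of per-frame cluster IDs. A minimal sketch of assembling such a pair from segments that alternate speakers; the random embeddings and the 256-dim size stand in for real d-vectors and are assumptions for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 256  # embedding dimension (the demo's toy data uses 256)

# simulate alternating speaker turns: A, B, A, B (segment lengths in frames)
turns = [("A", 30), ("B", 20), ("A", 25), ("B", 35)]
segments, labels = [], []
for speaker, length in turns:
    segments.append(rng.normal(size=(length, d)))  # fake embeddings for this turn
    labels.extend([speaker] * length)              # one cluster ID per frame

train_sequence = np.concatenate(segments, axis=0)  # shape (110, 256)
train_cluster_id = labels                          # list of 110 string IDs
print(train_sequence.shape, len(train_cluster_id))  # (110, 256) 110
```

The key property the model learns from is the turn structure (where labels change), so data with natural turn boundaries should work better than randomly concatenated single-speaker clips.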

Also, you may want to tune the learning rate and network size.
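Those two knobs map to arguments exposed by `uisrnn.parse_arguments()`; a config sketch (the values below are illustrative placeholders, not recommendations):

```python
import uisrnn

model_args, training_args, inference_args = uisrnn.parse_arguments()
model_args.rnn_hidden_size = 512       # network size
training_args.learning_rate = 1e-4     # learning rate
training_args.train_iteration = 20000  # consider more iterations with more data
```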

wq2012 commented 5 years ago

Hi @dalonlobo , could you let us know your versions of python, numpy, and pytorch? That could help us identify the problem.

dalonlobo commented 5 years ago

Thank you @wq2012 Sorry for the delayed response!

VCTK has 109 native speakers. I stitched the utterances together artificially to include speaker turns, since I didn't have another dataset source. But you are correct about the randomness; it's too random.

Following are the versions:

Python: Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 17:14:51)
Numpy: 1.15.4
Torch: 0.4.1

wq2012 commented 5 years ago

@dalonlobo Thanks for the information. The versions of your libraries are similar to what we had been using, so we do not expect any significant difference (other than things like random seeds).

The issue from gaodihe (https://github.com/google/uis-rnn/issues/14) is different from yours, and I have responded in that thread.

About your issue, first, constructing sequences by purely stitching single-speaker utterances won't let you learn much, as I said before.

However, even if that's true, we still shouldn't be seeing your results: a new cluster ID for each input observation. In integration_test.py, we also stitch fake data, and we observed accurate outputs.

So my guesses of what could have caused your issue:

77281900000 commented 5 years ago

@wq2012 Actually I have the same problem, but I think you may have misunderstood what the problem is. The problem is that model.save() does not seem to save the model parameters correctly. When I use model.fit and then test, the model works well. But when I use model.load to load the model I just trained, and test on the same data, the result becomes strange, like @dalonlobo's (a random sequence).

wq2012 commented 5 years ago

@gaodihe

OK, thanks for the explanation. Unfortunately I could not replicate this either...

Do you have steps to replicate the issue? Or can you confirm this is what you did?

  1. Run demo.py without changing it.

  2. Run demo.py again by changing:

  model = uisrnn.UISRNN(model_args)

  # training
  model.fit(train_sequence, train_cluster_id, training_args)
  model.save(SAVED_MODEL_NAME)
  # we can also skip training by calling:
  # model.load(SAVED_MODEL_NAME)

to

  model = uisrnn.UISRNN(model_args)

  # training
  # model.fit(train_sequence, train_cluster_id, training_args)
  # model.save(SAVED_MODEL_NAME)
  # we can also skip training by calling:
  model.load(SAVED_MODEL_NAME)
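A library-agnostic way to narrow down a save/load bug like this is to diff the parameter dicts before and after the round trip. The sketch below treats model state as a plain dict of numpy arrays; the helper and parameter names are hypothetical, for illustration only (with PyTorch you could build such dicts from `state_dict()`):

```python
import numpy as np

def diff_state_dicts(before, after, atol=1e-8):
    """Return names of parameters that changed or went missing across save/load."""
    mismatched = [k for k in before
                  if k not in after or not np.allclose(before[k], after[k], atol=atol)]
    extra = [k for k in after if k not in before]
    return mismatched + extra

# example: one tensor silently dropped by a buggy save()
before = {"rnn.weight": np.ones((2, 2)), "rnn_init_hidden": np.zeros(4)}
after = {"rnn.weight": np.ones((2, 2))}
print(diff_state_dicts(before, after))  # ['rnn_init_hidden']
```

An empty result after the round trip would rule out the serialization path and point the search elsewhere (e.g. non-tensor attributes that save() does not persist).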
wq2012 commented 5 years ago

@dalonlobo @gaodihe

I think I know what problem it is.

I just committed a fix: https://github.com/google/uis-rnn/commit/a619126ce64b3209f6d3d22cd1ca4619a537b5db

Could you verify whether this fixes the issue?

Also, thanks so much for helping us catch this bug!

dalonlobo commented 5 years ago

Thanks, @wq2012, for fixing it, and @gaodihe for highlighting the problem clearly. Unfortunately, I won't be able to test it until next week, as I'm on vacation. You can close this issue; I will comment once I've tested it.

77281900000 commented 5 years ago

@wq2012 I have tested it and the bug has been fixed. Thanks for your efforts.