juanmc2005 / diart

A python package to build AI-powered real-time audio applications
https://diart.readthedocs.io
MIT License
1.1k stars 90 forks

Add compatibility with pyannote 3.0 embedding wrappers #188

Closed sorgfresser closed 1 year ago

sorgfresser commented 1 year ago

Adds initial support for the embedding model used in pyannote/speaker-diarization-3.0

Usage:

from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference
from diart.models import EmbeddingModel
from diart.sinks import RTTMWriter
from diart.sources import MicrophoneAudioSource
from torch import device

embedding = EmbeddingModel.from_pyannote("hbredin/wespeaker-voxceleb-resnet34-LM")
embedding.to(device("cuda"))
config = SpeakerDiarizationConfig(embedding=embedding)
pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource()
inference = StreamingInference(pipeline, mic, do_plot=True)
inference.attach_observers(RTTMWriter(mic.uri, "output/file.rttm"))
prediction = inference()

I am still lacking support for pyannote/segmentation-3.0 as of now, and I'm not 100% sure why... I thought it would be a drop-in replacement for pyannote/segmentation, but it doesn't seem to work.

Any hints here would be greatly appreciated.

hbredin commented 1 year ago

Hint for pyannote/segmentation-3.0 support: use Powerset.to_multilabel conversion as illustrated here
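For context, segmentation-3.0 predicts powerset classes (which subset of speakers is active) rather than independent per-speaker activations, so its output needs converting before diart can use it. A simplified standalone sketch of the idea behind that conversion (illustrative only, not pyannote's actual Powerset implementation):

```python
import itertools
import torch

def powerset_to_multilabel(
    powerset_probs: torch.Tensor,
    num_speakers: int = 3,
    max_simultaneous: int = 2,
) -> torch.Tensor:
    """Convert (batch, frames, powerset_classes) scores to (batch, frames, speakers)."""
    # Enumerate powerset classes: empty set, then singletons, then pairs, etc.
    classes = [
        combo
        for size in range(max_simultaneous + 1)
        for combo in itertools.combinations(range(num_speakers), size)
    ]
    # Binary mapping matrix of shape (num_powerset_classes, num_speakers)
    mapping = torch.zeros(len(classes), num_speakers)
    for idx, combo in enumerate(classes):
        for spk in combo:
            mapping[idx, spk] = 1.0
    # Hard decision: pick the most likely powerset class per frame
    best = powerset_probs.argmax(dim=-1)  # (batch, frames)
    return mapping[best]  # (batch, frames, num_speakers)
```

With 3 speakers and at most 2 active at once this yields 7 powerset classes; picking class {0, 1} for a frame produces the multilabel row [1, 1, 0].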

juanmc2005 commented 1 year ago

@sorgfresser if that's ok with you, let's open a separate PR for segmentation-3.0 so we can merge this one first

sorgfresser commented 1 year ago

Hey @juanmc2005 Thanks for your recommendations, I hopefully added them all now. Could you review again? The usage has changed a bit, I've updated my original comment accordingly.

juanmc2005 commented 1 year ago

@sorgfresser I just released version 0.8. Please make sure to rebase your branch against develop so we're able to merge:

git checkout pyannote-3.0
git rebase <diart remote>/develop
# Once successful and without conflicts
git push --force origin pyannote-3.0
juanmc2005 commented 1 year ago

Ok I think the code is pretty solid! I want to test this but I would prefer to wait for the pyannote fix to be merged and hopefully released by @hbredin 🙏🏻

If the pyannote fix takes a long time to get into a release, I would prefer to make the required changes here anyway. In that case, WeSpeaker embeddings wouldn't work (on GPU) temporarily, but I prefer the code to be clean if it's going to be part of the next release

juanmc2005 commented 1 year ago

@sorgfresser could you please rebase on top of develop and force push again? I added some GitHub action checks that I would like to run on this PR to make sure nothing's broken

sorgfresser commented 1 year ago

Thanks for rebasing @juanmc2005, and sorry for the late reply. I added the docstrings and removed the move to CPU. Is there anything else I can add or modify?

juanmc2005 commented 1 year ago

Hey @sorgfresser thanks for the new changes. I think we're good here. Now that the code is looking good I'll pull the branch locally and do some tests to see if the feature is working correctly. In particular WeSpeaker and some other model like ECAPA-TDNN.

If my tests look good I'll go ahead and merge. I'll probably wait for the pyannote fix to release v0.9 though.

juanmc2005 commented 1 year ago

@sorgfresser quick update after some tests.

It looks like normalizing weights does affect the embeddings. DER on AMI goes from 27.3 to 29.8, which is pretty bad. When I remove the normalization it goes down to 27.5, so something else is also affecting performance negatively.

I think we should add a parameter somewhere to specify if weights should be normalized.

I'll keep investigating and get back

juanmc2005 commented 1 year ago

@hbredin looks like the difference between the 27.5 and 27.3 was because of pytorch 2.1.0. Downgrading to 2.0.1 solves this. Have you observed any performance changes in pyannote too with this new pytorch version?

Maybe it's related to the automatic conversion of the segmentation and embedding models. I keep getting these warnings:

Model was trained with pyannote.audio 0.0.1, yours is 3.0.0. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.0.1+cu117. Bad things might happen unless you revert torch to 1.x.

I don't think this is a deal-breaker, so I won't change the requirements to force torch<2.1.0, but it may be worth checking whether the model is being loaded incorrectly with torch 2.1.0

juanmc2005 commented 1 year ago

@sorgfresser could you move the weight normalization code to diart.blocks.OverlappedSpeechPenalty? The new OverlappedSpeechPenalty.__call__() method should look like this:

    def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
        weights = self.formatter.cast(segmentation)  # shape (batch, frames, speakers)
        with torch.no_grad():
            probs = torch.softmax(self.beta * weights, dim=-1)
            weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
            weights[weights < 1e-8] = 1e-8
            if self.normalize:
                min_values = weights.min(dim=1, keepdim=True).values
                max_values = weights.max(dim=1, keepdim=True).values
                weights = (weights - min_values) / (max_values - min_values)
                weights.nan_to_num_(1e-8)
        return self.formatter.restore_type(weights)

Where self.normalize is a new constructor argument that defaults to False. Also, we need to add a normalize_weights argument to OverlapAwareSpeakerEmbedding. Finally, let's add a normalize_embedding_weights to DiarizationConfig, and then pass this value to OverlapAwareSpeakerEmbedding inside the constructor of diart.blocks.diarization.SpeakerDiarization.

This way, users can decide whether they want to do this normalization or not directly in the pipeline config.
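As an aside on the nan_to_num_ call in the __call__() snippet above: when a chunk has constant weights across all frames, min equals max and the min-max division becomes 0/0, producing NaNs. A tiny standalone torch sketch of that edge case (illustrative only, not PR code):

```python
import torch

# Constant weights over frames: min == max, so min-max normalization divides 0 by 0
weights = torch.full((1, 3, 1), 0.2)  # (batch, frames, speakers)
min_values = weights.min(dim=1, keepdim=True).values
max_values = weights.max(dim=1, keepdim=True).values
normed = (weights - min_values) / (max_values - min_values)  # all NaN here
normed.nan_to_num_(1e-8)  # replace NaNs with the same floor used for the weights
```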

I could do these changes myself but I'm not sure I have write access to your fork.

I would also like to have this as a CLI argument --normalize-embedding-weights in diart.stream, diart.benchmark, diart.tune and diart.serve, but I would merge the PR without this anyway. I leave it for you to decide if you want to implement that feature here. In any case, pipeline configs are becoming pretty big so I'm thinking of converting them to a yaml file soon.
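For reference, a minimal sketch of how such a flag could be declared with argparse (the parser setup and help text here are hypothetical, not diart's actual CLI code):

```python
import argparse

# Hypothetical stand-in for the shared CLI argument setup in diart's entry points
parser = argparse.ArgumentParser(prog="diart.stream")
parser.add_argument(
    "--normalize-embedding-weights",
    action="store_true",
    help="Min-max normalize embedding weights between 0 and 1 per chunk",
)

args = parser.parse_args(["--normalize-embedding-weights"])
# args.normalize_embedding_weights is True here; it defaults to False when omitted
```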

juanmc2005 commented 1 year ago

Otherwise, I was able to run diart.stream with WeSpeaker embeddings! Amazing work @sorgfresser!

hbredin commented 1 year ago

@hbredin looks like the difference between the 27.5 and 27.3 was because of pytorch 2.1.0. Downgrading to 2.0.1 solves this. Have you observed any performance changes in pyannote too with this new pytorch version?

pyannote's CI is kind of non-existent so I don't actually know :)

sorgfresser commented 1 year ago

I added the boolean to the CLI and moved the normalization to OverlappedSpeechPenalty. Is that how you'd like the CLI to behave? It would be nice if you could test it again. What models did you use for your benchmark to get the 29.8 on AMI? By the way, I think you can edit the fork, so feel free to do so, but I'm also willing to implement any other changes if you prefer.

juanmc2005 commented 1 year ago

@sorgfresser thank you for the swift reply and commit! I ran the command from the reproducibility section, only changing --tau-active, --delta-new and --rho-update to the AMI values from the hyper-parameter table. So the models I used were pyannote/segmentation@Interspeech2021 and pyannote/embedding.

I'll re-run the tests as soon as I can and get back with updates

juanmc2005 commented 1 year ago

Update: the AMI benchmark with WeSpeaker embeddings and no weight normalization gives DER=30.8

juanmc2005 commented 1 year ago

@sorgfresser huge thanks for this feature! Stay tuned for v0.9! I hope we can get it released as soon as possible. If you liked contributing to diart I'd love to work with you on other issues in need of work 😃 Even better if it's on the list for the v0.9 milestone!