pyannote / pyannote-audio

Neural building blocks for speaker diarization: speech activity detection, speaker change detection, overlapped speech detection, speaker embedding
http://pyannote.github.io
MIT License

Training PyanNet/SSeRiouSS on multiple GPUs not working #1666

Closed · Jamiroquai88 closed this issue 5 months ago

Jamiroquai88 commented 6 months ago

Tested versions

3.1.1

System information

ubuntu 20.04, 2xGPU A100

Issue description

Hello Hervé,

I am having issues with multi-GPU training that I am not sure how to solve. I would appreciate some feedback.

This is how I run the script on two GPUs:

CUDA_VISIBLE_DEVICES="3,4" python train3.0_powerset.py ...

To the Trainer part of my script, I added:

strategy='ddp_find_unused_parameters_true'

and, following the PyTorch Lightning docs, removed

model = model.cuda()
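
So the relevant part of the Trainer construction now looks roughly like this:

from pytorch_lightning import Trainer

trainer = Trainer(
    accelerator="gpu",
    # GPUs are selected via CUDA_VISIBLE_DEVICES; Lightning handles device
    # placement itself, so no manual .cuda() calls
    strategy="ddp_find_unused_parameters_true",
)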

But I am getting:

Traceback (most recent call last):
  File "/shared/jprofant/Github/pyannote-audio/train3.0_powerset.py", line 141, in <module>
    trainer.fit(model, ckpt_path=args.init_model)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 948, in _run
    call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 94, in _call_setup_hook
    _call_lightning_module_hook(trainer, "setup", stage=fn)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/shared/jprofant/Github/pyannote-audio/pyannote/audio/core/model.py", line 274, in setup
    _ = self.example_output
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'PyanNet' object has no attribute 'example_output'

The traceback makes that look true, even though the base Model class does define example_output.
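
If I understand Python's attribute lookup correctly, the message may be misleading: the second traceback below shows that example_output is a functools.cached_property, and when the body of a cached_property raises an AttributeError internally (e.g. because some task-dependent state has not been set up on a DDP worker yet), nn.Module.__getattr__ swallows it and re-raises under the property's own name. A minimal sketch of that masking, not pyannote code:

from functools import cached_property
import torch.nn as nn

class Demo(nn.Module):
    @cached_property
    def example_output(self):
        # any AttributeError raised in here (the missing state is hypothetical) ...
        return self.not_set_up_yet

try:
    Demo().example_output
except AttributeError as e:
    # ... resurfaces under the property's own name:
    print(e)  # 'Demo' object has no attribute 'example_output'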

A very similar thing happens with the SSeRiouSS model, but with that one I am getting:

Traceback (most recent call last):
  File "/shared/jprofant/Github/pyannote-audio/train3.0_powerset.py", line 141, in <module>
    trainer.fit(model, ckpt_path=args.init_model)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 948, in _run
    call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 94, in _call_setup_hook
    _call_lightning_module_hook(trainer, "setup", stage=fn)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/shared/jprofant/Github/pyannote-audio/pyannote/audio/core/model.py", line 274, in setup
    _ = self.example_output
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/functools.py", line 993, in __get__
    val = self.func(instance)
  File "/shared/jprofant/Github/pyannote-audio/pyannote/audio/core/model.py", line 195, in example_output
    example_output = self(example_input_array)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/shared/jprofant/Github/pyannote-audio/pyannote/audio/models/segmentation/SSeRiouSS.py", line 300, in forward
    outputs, _ = self.wav2vec.extract_features(
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torchaudio/models/wav2vec2/model.py", line 83, in extract_features
    x, lengths = self.feature_extractor(waveforms, lengths)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 141, in forward
    x, length = layer(x, length)  # (batch, feature, frame)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 90, in forward
    x = self.conv(x)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 313, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ubuntu/miniconda3/envs/pyannote/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 309, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

I tried moving some tensors to the GPU via .to(self.device), but that is discouraged in PyTorch Lightning. The docs recommend .type_as(x) instead, which unfortunately always leads back to the same example_output error.
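
For context, the device-agnostic pattern the Lightning docs describe looks something like this (illustrative only, make_mask is a made-up helper):

import torch

def make_mask(waveforms: torch.Tensor) -> torch.Tensor:
    # create the new tensor on the same device/dtype as an existing one,
    # instead of hard-coding .cuda() or .to(self.device)
    mask = torch.ones(waveforms.shape[0])
    return mask.type_as(waveforms)

x = torch.randn(4, 16000)  # stand-in for a batch of waveforms
print(make_mask(x).device)  # follows x's device: CPU here, GPU under DDP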

When using a single GPU

CUDA_VISIBLE_DEVICES="3" python train3.0_powerset.py ...

and keeping model = model.cuda(), it works fine for SSeRiouSS but not for PyanNet.

I am willing to contribute a fix; I would just need some pointers first. Thank you!

Minimal reproduction example (MRE)

I can't share my data, sorry.

hbredin commented 6 months ago

Would you mind sharing train3.0_powerset.py?

Jamiroquai88 commented 6 months ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2022
# Author: Jan Profant <jan.profant@rev.com>
# All Rights Reserved
import argparse
import sys

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

from pyannote.audio import Model
from pyannote.database import FileFinder, registry

from pyannote.audio.models.segmentation import PyanNet, SSeRiouSS
from pyannote.audio.tasks import Segmentation
from types import MethodType
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    RichProgressBar,
)
from pytorch_lightning.loggers import WandbLogger

torch.set_float32_matmul_precision('high')

def configure_optimizers(self):
    # note: `lr` is a module-level global, set from args.lr in __main__ below
    # before this function is bound to the model
    optimizer = Adam(self.parameters(), lr=lr)
    # exponential learning-rate decay (base 0.7 before epoch 5, 0.75 afterwards)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.7 ** epoch if epoch < 5 else 0.75 ** epoch, verbose=True)
    # scheduler = StepLR(optimizer, step_size=, gamma=0.5)
    return [optimizer], [scheduler]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--database', required=True,
                        help='path to database.yaml file')
    parser.add_argument('--wandb-project', type=str, required=True,
                        help='wandb project name for the logging')
    parser.add_argument('--use-pretrained', default=False, action='store_true',
                        help='load pretrained model from huggingface hub')
    parser.add_argument('--init-model', required=False, help='path to model to initialize the NN')
    parser.add_argument('--workers', type=int, default=8)

    # model options
    parser.add_argument('--chunk-dur', default=10.0, type=float,
                        help='LSTM chunk duration')
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='Adam learning rate')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--wavlm-type', type=str, default='WAVLM_BASE',
                        choices=['WAVLM_BASE', 'WAVLM_BASE_PLUS', 'WAVLM_LARGE'])

    # LSTM options
    parser.add_argument('--lstm-type', choices=['pyannet', 'sseriouss'], default='pyannet')
    parser.add_argument('--lstm-hidden-size', type=int, default=256)
    parser.add_argument('--lstm-dropout', type=float, default=0.0)
    parser.add_argument('--lstm-num-layers', type=int, default=2)

    args = parser.parse_args()

    registry.load_database(args.database)
    dataset = registry.get_protocol('audiodb.SpeakerDiarization.train_protocol',
                                    preprocessors={'audio': FileFinder()})

    task = Segmentation(
        dataset,
        duration=args.chunk_dur,
        max_speakers_per_chunk=3,
        max_speakers_per_frame=2,
        batch_size=args.batch_size,
        num_workers=args.workers)

    if args.use_pretrained and args.init_model:
        print("Can't load a pretrained model and initialize from a checkpoint at the same time, fix your arguments.")
        sys.exit(1)
    elif args.use_pretrained:
        model = Model.from_pretrained(
            "pyannote/segmentation-3.0",
            use_auth_token="XXXXX")
    elif args.init_model:
        model = Model.from_pretrained(args.init_model)
    else:
        LSTM = {
            "hidden_size": args.lstm_hidden_size,
            "num_layers": args.lstm_num_layers,
            "bidirectional": True,
            "monolithic": True,
            "dropout": args.lstm_dropout,
        }
        if args.lstm_type == 'pyannet':
            model = PyanNet(task=task, lstm=LSTM)
        elif args.lstm_type == 'sseriouss':
            model = SSeRiouSS(wav2vec=args.wavlm_type, task=task, lstm=LSTM)
        else:
            raise NotImplementedError("Unsupported LSTM type")

    lr = args.lr
    model.configure_optimizers = MethodType(configure_optimizers, model)
    model.task = task

    checkpoint = ModelCheckpoint(
        every_n_epochs=1,
        save_last=True,
        save_weights_only=False,
        filename="{epoch}",
        verbose=True
    )
    callbacks = [RichProgressBar(), checkpoint]

    # we train for at most args.epochs (default: 20) epochs

    wandb_logger = WandbLogger(log_model="all", project=args.wandb_project)

    trainer = Trainer(accelerator="gpu",
                      callbacks=callbacks,
                      max_epochs=args.epochs,
                      gradient_clip_val=0.5,
                      num_sanity_val_steps=10,
                      logger=wandb_logger,
                      strategy='ddp_find_unused_parameters_true',
                      )

    trainer.fit(model, ckpt_path=args.init_model)

hbredin commented 6 months ago

I think this problem has been fixed in the develop branch since https://github.com/pyannote/pyannote-audio/commit/c0b9e79aa8063c7ddc78e7213799d0aeae9d3d10

Can you try with the latest commit?
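
For reference, installing straight from the develop branch should be something like:

pip install git+https://github.com/pyannote/pyannote-audio.git@develop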

Jamiroquai88 commented 5 months ago

That was it, thank you. Closing now.