yl4579 / StarGANv2-VC

StarGANv2-VC: A Diverse, Unsupervised, Non-parallel Framework for Natural-Sounding Voice Conversion
MIT License
486 stars 108 forks source link

Support needed for Multiple Discriminators Implementation #68

Open MuruganR96 opened 1 year ago

MuruganR96 commented 1 year ago
    Using multiple discriminators is effective: once the model converges, the sound quality on unseen speakers is better, and the similarity to the target speaker is higher than with the original model.

Originally posted by @980202006 in https://github.com/yl4579/StarGANv2-VC/issues/6#issuecomment-945602695

MuruganR96 commented 1 year ago

Hi @yl4579 @980202006, I read through https://github.com/yl4579/StarGANv2-VC/issues/6#issuecomment-945602695

Could you please guide me on where to modify the code to implement the multiple-discriminator feature?

Thanks

MuruganR96 commented 1 year ago

@yl4579 @980202006 Please validate my proposed approach to implementing multiple discriminators.

Example: with batch size 1, I have 60 speakers and am trying 3 discriminators, so the discriminator_ids are 1, 2, and 3.

  1. In meldataset.py — based on the target speaker_id, return a discriminator_id from MelDataset.__getitem__ (speakers 0–19 → 1, 20–39 → 2, 40–59 → 3).
  2. In models.py — build a set of discriminators, each of which only works on a subset of speakers (see https://github.com/yl4579/StarGANv2-VC/issues/6#issuecomment-912844250).
def build_model(args, F0_model, ASR_model):
    """Build the training networks and their EMA copies, with N sub-discriminators.

    The number of sub-discriminators is read from ``args.num_discriminators``.
    It defaults to 3, which reproduces the previous hard-coded
    ``discriminator1``..``discriminator3``. They are registered in ``nets``
    under the keys ``discriminator1`` ... ``discriminatorN``.

    Args:
        args: model hyper-parameters accessed as attributes (dim_in, style_dim,
            max_conv_dim, w_hpf, F0_channel, latent_dim, num_domains, n_repeat,
            and optionally num_discriminators).
        F0_model: F0 network, stored as ``nets.f0_model``.
        ASR_model: ASR network, stored as ``nets.asr_model``.

    Returns:
        ``(nets, nets_ema)``. ``nets`` holds the generator, mapping network,
        style encoder, all sub-discriminators and the F0/ASR models.
        ``nets_ema`` holds deep copies of the generator, mapping network and
        style encoder.

    Raises:
        ValueError: if ``num_discriminators`` is smaller than 1.
    """
    num_discriminators = int(getattr(args, 'num_discriminators', 3))
    if num_discriminators < 1:
        raise ValueError(f'num_discriminators must be >= 1, got {num_discriminators}')

    generator = Generator(args.dim_in, args.style_dim, args.max_conv_dim, w_hpf=args.w_hpf, F0_channel=args.F0_channel)
    mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains, hidden_dim=args.max_conv_dim)
    style_encoder = StyleEncoder(args.dim_in, args.style_dim, args.num_domains, args.max_conv_dim)

    # Every sub-discriminator still has all args.num_domains outputs.
    # Which samples each one actually sees is decided by the loss code, not here.
    # Construction order (1..N) matches the original, so seeded initialisation is unchanged.
    discriminators = {
        f'discriminator{i}': Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
        for i in range(1, num_discriminators + 1)
    }

    generator_ema = copy.deepcopy(generator)
    mapping_network_ema = copy.deepcopy(mapping_network)
    style_encoder_ema = copy.deepcopy(style_encoder)

    nets = Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder,
                 **discriminators,
                 f0_model=F0_model,
                 asr_model=ASR_model)

    nets_ema = Munch(generator=generator_ema,
                     mapping_network=mapping_network_ema,
                     style_encoder=style_encoder_ema)

    return nets, nets_ema
  3. In trainer.py — pass the discriminator id to compute_d_loss.
    def _train_epoch(self):
        """Run one training epoch (excerpt: only the discriminator updates are shown).

        Each batch now carries an extra ``discriminator_id`` element, which is
        forwarded to ``compute_d_loss`` for both discriminator passes.

        NOTE(review): this excerpt ends after the second discriminator backward
        pass. The optimizer step for that pass and the generator updates are
        not shown here.
        """
        self.epochs += 1

        train_losses = defaultdict(list)
        _ = [self.model[k].train() for k in self.model]
        # A GradScaler (mixed precision) is only created on CUDA devices with fp16_run enabled.
        scaler = torch.cuda.amp.GradScaler() if (('cuda' in str(self.device)) and self.fp16_run) else None

        # Consistency regularization and the adversarial classifier loss are
        # switched on only after their configured epochs.
        use_con_reg = (self.epochs >= self.args.con_reg_epoch)
        use_adv_cls = (self.epochs >= self.args.adv_cls_epoch)

        for train_steps_per_epoch, batch in enumerate(tqdm(self.train_dataloader, desc="[train]"), 1):

            ### load data
            batch = [b.to(self.device) for b in batch]
            # discriminator_id is a new eighth batch element. After .to() it is
            # presumably a tensor with one id per sample (confirm against the collate fn).
            # NOTE(review): a compute_d_loss that compares it with `== 1` in an
            # `if` only works while this tensor has exactly one element (batch size 1).
            x_real, y_org, x_ref, x_ref2, y_trg, z_trg, z_trg2, discriminator_id = batch

            # train the discriminator (by random reference)
            self.optimizer.zero_grad()
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                scaler.scale(d_loss).backward()
            else:
                d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                d_loss.backward()
            # NOTE(review): the model now holds discriminator1..3 rather than a
            # single 'discriminator'. If self.optimizer is keyed by model name,
            # confirm that a 'discriminator' entry still exists, or step each
            # discriminatorN explicitly.
            self.optimizer.step('discriminator', scaler=scaler)

            # train the discriminator (by target reference)
            self.optimizer.zero_grad()
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                scaler.scale(d_loss).backward()
            else:
                d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
                d_loss.backward()
  4. In losses.py — in compute_d_loss, select the discriminator by discriminator_id for every prediction.
def _discriminator_by_id(nets, disc_id):
    """Return the sub-discriminator ``nets.discriminator<disc_id>``.

    Raises:
        ValueError: if ``nets`` has no such attribute. The earlier if/elif
            chains silently sent every id other than 1 and 2 to
            ``discriminator3``, which hid mistakes such as 0-based ids.
    """
    disc = getattr(nets, f'discriminator{disc_id}', None)
    if disc is None:
        raise ValueError(f'nets has no sub-discriminator "discriminator{disc_id}"')
    return disc


def _route_discriminators(nets, discriminator_id, x, y=None):
    """Score each sample of ``x`` with the sub-discriminator selected by its id.

    Args:
        nets: container exposing ``discriminator1``, ``discriminator2``, ...
        discriminator_id: Python int, or tensor holding either one id (used
            for the whole batch) or one id per sample of ``x``.
        x: input batch; dimension 0 is the batch dimension.
        y: per-sample domain labels. If given, ``D(x, y)`` is called. If
            ``None``, ``D.classifier(x)`` is called.

    Returns:
        The discriminator output, whose row ``i`` belongs to sample ``i`` of ``x``.

    Raises:
        ValueError: on an empty batch, a mis-sized id tensor, or an unknown id.
    """
    batch_size = x.size(0)
    if batch_size == 0:
        raise ValueError('cannot run the discriminators on an empty batch')
    ids = torch.as_tensor(discriminator_id, device=x.device).reshape(-1).long()
    if ids.numel() == 1:
        ids = ids.expand(batch_size)
    elif ids.numel() != batch_size:
        raise ValueError(f'expected 1 or {batch_size} discriminator ids, got {ids.numel()}')

    groups = ids.unique().tolist()
    if len(groups) == 1:
        # The whole batch belongs to one group, so call that discriminator on
        # the full batch exactly as the single-id code did.
        disc = _discriminator_by_id(nets, groups[0])
        return disc(x, y) if y is not None else disc.classifier(x)

    # Mixed batch: run each group through its own discriminator, then scatter
    # the outputs back into batch order. Masked indexing and index_put are
    # differentiable, so R1 regularization w.r.t. x still works.
    out = None
    for disc_id in groups:
        mask = ids == disc_id
        disc = _discriminator_by_id(nets, disc_id)
        part = disc(x[mask], y[mask]) if y is not None else disc.classifier(x[mask])
        if out is None:
            out = part.new_empty((batch_size,) + tuple(part.shape[1:]))
        out[mask] = part
    return out


def compute_d_loss(nets, args, x_real, y_org, y_trg, discriminator_id, z_trg=None, x_ref=None, use_r1_reg=True, use_adv_cls=False, use_con_reg=False):
    """Discriminator loss for one step, using several sub-discriminators.

    Every prediction is made by ``nets.discriminator<id>``, where the id comes
    from ``discriminator_id``. That argument may be a single id for the whole
    batch or one id per sample (mixed batches are routed per sample).

    Args:
        nets: networks (mapping_network, style_encoder, f0_model, generator,
            discriminator1..N).
        args: dict-like loss weights; ``lambda_reg``, ``lambda_adv_cls`` and
            ``lambda_con_reg`` are read.
        x_real: real input batch. ``requires_grad`` is switched on in place for R1.
        y_org: source domain labels of ``x_real``.
        y_trg: target domain labels for the generated samples.
        discriminator_id: int or tensor selecting the sub-discriminator(s).
        z_trg: latent code for the mapping network (give exactly one of z_trg / x_ref).
        x_ref: reference batch for the style encoder.
        use_r1_reg: add the R1 gradient penalty on real samples.
        use_adv_cls: add the adversarial source-classifier loss.
        use_con_reg: add bCR consistency regularization.

    Returns:
        ``(loss, Munch(real, fake, reg, real_adv_cls, con_reg))``, where the
        Munch holds the individual terms as Python floats.
    """
    args = Munch(args)

    assert (z_trg is None) != (x_ref is None)

    def score(x, y=None):
        # With y: D(x, y). Without y: D.classifier(x). Routed per sample either way.
        return _route_discriminators(nets, discriminator_id, x, y)

    # with real audios
    x_real.requires_grad_()
    # NOTE(review): per the proposal, the id is derived from the *target*
    # speaker, so each sub-discriminator also judges real audio of speakers
    # outside its group. Confirm this is intended; the group of y_org may be
    # the better choice for the real branch.
    out = score(x_real, y_org)
    loss_real = adv_loss(out, 1)

    # R1 regularization (https://arxiv.org/abs/1801.04406v4)
    if use_r1_reg:
        loss_reg = r1_reg(out, x_real)
    else:
        loss_reg = torch.FloatTensor([0]).to(x_real.device)

    # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
    loss_con_reg = torch.FloatTensor([0]).to(x_real.device)
    if use_con_reg:
        t = build_transforms()
        out_aug = score(t(x_real).detach(), y_org)
        loss_con_reg += F.smooth_l1_loss(out, out_aug)

    # with fake audios
    with torch.no_grad():
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)
        else:  # x_ref is not None
            s_trg = nets.style_encoder(x_ref, y_trg)

        F0 = nets.f0_model.get_feature_GAN(x_real)
        x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)

    out = score(x_fake, y_trg)
    loss_fake = adv_loss(out, 0)
    if use_con_reg:
        out_aug = score(t(x_fake).detach(), y_trg)
        loss_con_reg += F.smooth_l1_loss(out, out_aug)

    # adversarial classifier loss
    if use_adv_cls:
        out_de = score(x_fake)
        converted = y_org != y_trg
        if converted.any():
            loss_real_adv_cls = F.cross_entropy(out_de[converted], y_org[converted])
        else:
            # cross_entropy over an empty selection is NaN (a mean over zero
            # samples), which would make d_loss NaN. This happens whenever every
            # sample has y_trg == y_org, e.g. batch size 1 with target == source.
            loss_real_adv_cls = out_de.new_zeros(())

        if use_con_reg:
            out_de_aug = score(t(x_fake).detach())
            loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
    else:
        loss_real_adv_cls = torch.zeros(1).mean()

    loss = (loss_real + loss_fake + args.lambda_reg * loss_reg
            + args.lambda_adv_cls * loss_real_adv_cls
            + args.lambda_con_reg * loss_con_reg)

    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item(),
                       real_adv_cls=loss_real_adv_cls.item(),
                       con_reg=loss_con_reg.item())

@yl4579 @980202006

Thanks