MuruganR96 opened this issue:
Hi @yl4579 @980202006, I read through https://github.com/yl4579/StarGANv2-VC/issues/6#issuecomment-945602695.
Could you please guide me on where to modify the code to implement the multiple-discriminator feature?
Thanks
@yl4579 @980202006 Please validate my hypothesis about the multiple-discriminator implementation.
For example: batch size 1, 60 speakers, and 3 discriminators, so discriminator_id takes the values 1, 2, 3.
The idea is to build a set of discriminators, each of which only works on a subset of speakers.
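The mapping from speakers to discriminator ids is not spelled out above; here is a minimal sketch of one possible assignment, assuming the 60 speakers are labelled 0-59 and split into three contiguous groups of 20 (both the grouping and the helper name are assumptions, not from the repo):

```python
# Hypothetical helper: map a speaker label to a 1-based discriminator id.
# Assumes 60 speakers labelled 0..59, split evenly across 3 discriminators.
def speaker_to_discriminator_id(speaker_label: int,
                                num_speakers: int = 60,
                                num_discriminators: int = 3) -> int:
    group_size = num_speakers // num_discriminators  # 20 speakers per discriminator
    return min(speaker_label // group_size, num_discriminators - 1) + 1

# speaker 0 -> 1, speaker 25 -> 2, speaker 59 -> 3
```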
Modified `build_model`, following https://github.com/yl4579/StarGANv2-VC/issues/6#issuecomment-912844250:

```python
def build_model(args, F0_model, ASR_model):
    generator = Generator(args.dim_in, args.style_dim, args.max_conv_dim, w_hpf=args.w_hpf, F0_channel=args.F0_channel)
    mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains, hidden_dim=args.max_conv_dim)
    style_encoder = StyleEncoder(args.dim_in, args.style_dim, args.num_domains, args.max_conv_dim)
    discriminator1 = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
    discriminator2 = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
    discriminator3 = Discriminator(args.dim_in, args.num_domains, args.max_conv_dim, args.n_repeat)
    generator_ema = copy.deepcopy(generator)
    mapping_network_ema = copy.deepcopy(mapping_network)
    style_encoder_ema = copy.deepcopy(style_encoder)

    nets = Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder,
                 discriminator1=discriminator1,
                 discriminator2=discriminator2,
                 discriminator3=discriminator3,
                 f0_model=F0_model,
                 asr_model=ASR_model)

    nets_ema = Munch(generator=generator_ema,
                     mapping_network=mapping_network_ema,
                     style_encoder=style_encoder_ema)

    return nets, nets_ema
```
Modified `_train_epoch` (which passes `discriminator_id` into `compute_d_loss`):
```python
def _train_epoch(self):
    self.epochs += 1
    train_losses = defaultdict(list)
    _ = [self.model[k].train() for k in self.model]
    scaler = torch.cuda.amp.GradScaler() if (('cuda' in str(self.device)) and self.fp16_run) else None

    use_con_reg = (self.epochs >= self.args.con_reg_epoch)
    use_adv_cls = (self.epochs >= self.args.adv_cls_epoch)

    for train_steps_per_epoch, batch in enumerate(tqdm(self.train_dataloader, desc="[train]"), 1):
        ### load data
        batch = [b.to(self.device) for b in batch]
        x_real, y_org, x_ref, x_ref2, y_trg, z_trg, z_trg2, discriminator_id = batch

        # train the discriminator (by random reference)
        self.optimizer.zero_grad()
        if scaler is not None:
            with torch.cuda.amp.autocast():
                d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
            scaler.scale(d_loss).backward()
        else:
            d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
            d_loss.backward()
        self.optimizer.step('discriminator', scaler=scaler)

        # train the discriminator (by target reference)
        self.optimizer.zero_grad()
        if scaler is not None:
            with torch.cuda.amp.autocast():
                d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
            scaler.scale(d_loss).backward()
        else:
            d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, discriminator_id, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg)
            d_loss.backward()
```
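One detail that may need checking alongside this: if the optimizer is keyed by the names in the model Munch (as the original `build_optimizer` setup suggests), then `self.optimizer.step('discriminator', scaler=scaler)` would no longer match any key once the single `discriminator` is replaced by `discriminator1/2/3`. A minimal sketch of stepping only the discriminator used for the current batch, assuming the optimizer keys mirror the `nets` keys (this is an assumption, not verified against the repo):

```python
# Step only the discriminator that was selected for this batch.
# Assumes optimizer keys mirror the keys in `nets` (discriminator1/2/3)
# and discriminator_id is a 1-element tensor when batch size is 1.
self.optimizer.step('discriminator%d' % int(discriminator_id), scaler=scaler)
```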
Modified `compute_d_loss`:

```python
def compute_d_loss(nets, args, x_real, y_org, y_trg, discriminator_id, z_trg=None, x_ref=None, use_r1_reg=True, use_adv_cls=False, use_con_reg=False):
    args = Munch(args)

    assert (z_trg is None) != (x_ref is None)

    # with real audios
    x_real.requires_grad_()
    if discriminator_id == 1:
        out = nets.discriminator1(x_real, y_org)
    elif discriminator_id == 2:
        out = nets.discriminator2(x_real, y_org)
    else:
        out = nets.discriminator3(x_real, y_org)
    loss_real = adv_loss(out, 1)

    # R1 regularization (https://arxiv.org/abs/1801.04406v4)
    if use_r1_reg:
        loss_reg = r1_reg(out, x_real)
    else:
        loss_reg = torch.FloatTensor([0]).to(x_real.device)

    # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
    loss_con_reg = torch.FloatTensor([0]).to(x_real.device)
    if use_con_reg:
        t = build_transforms()
        if discriminator_id == 1:
            out_aug = nets.discriminator1(t(x_real).detach(), y_org)
        elif discriminator_id == 2:
            out_aug = nets.discriminator2(t(x_real).detach(), y_org)
        else:
            out_aug = nets.discriminator3(t(x_real).detach(), y_org)
        loss_con_reg += F.smooth_l1_loss(out, out_aug)

    # with fake audios
    with torch.no_grad():
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)
        else:  # x_ref is not None
            s_trg = nets.style_encoder(x_ref, y_trg)

        F0 = nets.f0_model.get_feature_GAN(x_real)
        x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)

    if discriminator_id == 1:
        out = nets.discriminator1(x_fake, y_trg)
    elif discriminator_id == 2:
        out = nets.discriminator2(x_fake, y_trg)
    else:
        out = nets.discriminator3(x_fake, y_trg)
    loss_fake = adv_loss(out, 0)
    if use_con_reg:
        if discriminator_id == 1:
            out_aug = nets.discriminator1(t(x_fake).detach(), y_trg)
        elif discriminator_id == 2:
            out_aug = nets.discriminator2(t(x_fake).detach(), y_trg)
        else:
            out_aug = nets.discriminator3(t(x_fake).detach(), y_trg)
        loss_con_reg += F.smooth_l1_loss(out, out_aug)

    # adversarial classifier loss
    if use_adv_cls:
        if discriminator_id == 1:
            out_de = nets.discriminator1.classifier(x_fake)
        elif discriminator_id == 2:
            out_de = nets.discriminator2.classifier(x_fake)
        else:
            out_de = nets.discriminator3.classifier(x_fake)
        loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg], y_org[y_org != y_trg])

        if use_con_reg:
            if discriminator_id == 1:
                out_de_aug = nets.discriminator1.classifier(t(x_fake).detach())
            elif discriminator_id == 2:
                out_de_aug = nets.discriminator2.classifier(t(x_fake).detach())
            else:
                out_de_aug = nets.discriminator3.classifier(t(x_fake).detach())
            loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
    else:
        loss_real_adv_cls = torch.zeros(1).mean()

    loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
           args.lambda_adv_cls * loss_real_adv_cls + \
           args.lambda_con_reg * loss_con_reg

    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item(),
                       real_adv_cls=loss_real_adv_cls.item(),
                       con_reg=loss_con_reg.item())
```
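The repeated `if discriminator_id == …` branches could also be collapsed by keeping the three discriminators in an indexable container and picking one per batch. A minimal sketch under that assumption (the helper name is illustrative, not from the repo):

```python
# Hypothetical helper: pick the discriminator for the current batch once, then reuse it.
def select_discriminator(nets, discriminator_id):
    discriminators = [nets.discriminator1, nets.discriminator2, nets.discriminator3]
    return discriminators[int(discriminator_id) - 1]  # 1-based id -> 0-based index

# Usage inside compute_d_loss:
#   disc = select_discriminator(nets, discriminator_id)
#   out = disc(x_real, y_org)
#   ...
#   out = disc(x_fake, y_trg)
#   out_de = disc.classifier(x_fake)
```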
@yl4579 @980202006
Thanks
Originally posted by @980202006 in https://github.com/yl4579/StarGANv2-VC/issues/6#issuecomment-945602695