auspicious3000 / SpeechSplit

Unsupervised Speech Decomposition Via Triple Information Bottleneck
http://arxiv.org/abs/2004.11284
MIT License

Training G and P at different sample rates #48

Closed Merlin-721 closed 3 years ago

Merlin-721 commented 3 years ago

I am attempting to retrain at 22050 Hz. At this sample rate, the validation losses for G and P do not decrease (P's actually steadily increases). I am validating with test samples from every speaker in the train set. Both training loss_ids decrease as expected.

I train G according to this code in solver.py:

self.G = self.G.train()
# G identity-mapping loss: concatenate spectrogram and F0 along the feature axis
x_f0 = torch.cat((x_real_org, f0_org), dim=-1)
# Random resampling along the time axis
x_f0_intrp = self.Interp(x_f0, len_org)

# One-hot quantize the resampled F0 contour and re-attach it
f0_org_intrp = quantize_f0_torch(x_f0_intrp[:,:,-1])[0]
x_f0_intrp_org = torch.cat((x_f0_intrp[:,:,:-1], f0_org_intrp), dim=-1)

# G forward: reconstruct the original spectrogram
x_pred = self.G(x_f0_intrp_org, x_real_org, emb_org)
g_loss_id = F.mse_loss(x_pred, x_real_org, reduction='mean')

# Backward and optimize
self.g_optimizer.zero_grad()
g_loss_id.backward()
self.g_optimizer.step()

loss['G/loss_id'] = g_loss_id.item()

and train P according to this code:

self.P = self.P.train()
# Preprocess f0_trg for P: concatenate spectrogram and F0, then randomly resample
x_f0_trg = torch.cat((x_real_org, f0_org), dim=-1)
x_f0_intrp_trg = self.Interp(x_f0_trg, len_org)
# Target for P: one-hot quantized F0, plus class indices for cross-entropy
f0_trg_intrp = quantize_f0_torch(x_f0_intrp_trg[:,:,-1])[0]
f0_trg_intrp_indx = f0_trg_intrp.argmax(2)

# P forward: per-frame logits over quantized F0 bins
f0_pred = self.P(x_real_org, f0_trg_intrp)
# cross_entropy expects logits (N, C, T) and index targets (N, T)
p_loss_id = F.cross_entropy(f0_pred.transpose(1,2), f0_trg_intrp_indx, reduction='mean')

# Backward and optimize
self.p_optimizer.zero_grad()
p_loss_id.backward()
self.p_optimizer.step()
loss['P/loss_id'] = p_loss_id.item()
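
As a side note, F.cross_entropy expects logits of shape (N, C, T) and integer class targets of shape (N, T), which is what the transpose and argmax are for. A self-contained shape check (the 257-class count is my assumption, from quantize_f0_torch's 256 bins plus an unvoiced bin):

import torch
import torch.nn.functional as F

B, T, C = 2, 128, 257                  # batch, frames, F0 classes (assumed 256 bins + unvoiced)
f0_pred = torch.randn(B, T, C)         # per-frame logits, as P outputs them
idx = torch.randint(C, (B, T, 1))
f0_trg = torch.zeros(B, T, C).scatter_(2, idx, 1.0)   # random one-hot targets

# logits -> (N, C, T); one-hot targets -> class indices (N, T)
loss = F.cross_entropy(f0_pred.transpose(1, 2), f0_trg.argmax(2))
print(loss.item())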

I suspect this may be due to the LSTMs in the encoders and decoders, since at a different sample rate the vocal features appear over a different time scale; however, any other suggestions would be appreciated.
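
As a rough check of that time-scale shift (a sketch assuming the original 16 kHz / hop-256 preprocessing, i.e. 16 ms frames; adjust to your own setup):

orig_sr, orig_hop = 16000, 256                # assumed original preprocessing
new_sr = 22050
frame_ms = 1000 * orig_hop / orig_sr          # 16.0 ms per frame at 16 kHz
new_hop = round(new_sr * frame_ms / 1000)     # ~353 samples preserves the frame rate
print(frame_ms, new_hop)

Keeping the frame period fixed (a hop of roughly 353 at 22050 Hz) would keep the sequence rate seen by the LSTMs roughly unchanged.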

Merlin-721 commented 3 years ago

[Screenshots: training and validation loss curves]

I am training with a batch size of 64 (hence the seemingly low step count) and learning rates of 0.0005.

auspicious3000 commented 3 years ago

This looks like overfitting. You could check the training data size, the validation code, etc.
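
If it helps, a minimal eval-mode sketch of what a G validation step could look like (x_real_val, f0_val, len_val, emb_val are hypothetical tensors from a held-out loader, not the repo's exact code):

self.G = self.G.eval()
with torch.no_grad():
    # Same pipeline as the training step, without gradient tracking
    x_f0_val = torch.cat((x_real_val, f0_val), dim=-1)
    x_f0_val_intrp = self.Interp(x_f0_val, len_val)
    f0_val_intrp = quantize_f0_torch(x_f0_val_intrp[:,:,-1])[0]
    x_f0_val_in = torch.cat((x_f0_val_intrp[:,:,:-1], f0_val_intrp), dim=-1)
    x_pred = self.G(x_f0_val_in, x_real_val, emb_val)
    val_g_loss = F.mse_loss(x_pred, x_real_val, reduction='mean').item()
self.G = self.G.train()

Note that self.Interp resamples randomly, so the validation metric itself is noisy if you keep it in the eval path.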

Merlin-721 commented 3 years ago

Line 117 in data_loader.py:

a = np.clip(a, 0, 1)

zeroes out all of my spectrograms, since my preprocessing produces mostly negative values... that took so long to find!
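
In case anyone hits the same issue: the clip assumes features already normalized to [0, 1]. A minimal sketch of one way to rescale before that line, assuming roughly dB-scaled features with a floor near -100 dB (as the original preprocessing appears to do); adjust floor_db to your own feature range:

import numpy as np

def normalize_db(a, floor_db=-100.0):
    # Map [floor_db, 0] onto [0, 1] so np.clip(a, 0, 1) no longer zeroes everything
    return np.clip((a - floor_db) / -floor_db, 0, 1)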