My parameter settings remain the same as those provided by you, but the training results obtained are very different from those in the paper. The FAD during training never drops below 1. I would like to ask where the problem might be?
Another question: the FAD is calculated using HuBERT features. Why not use mel features, or features from other models like VGGish? Maybe this makes no difference, right?
And one more question: the results using HuBERT are better than using mel, but using mel in HiFi-GAN generates speech that is closer to the source. So I want to use mel as ASGAN's feature. Then, back to the first question: I can't get good training results with either HuBERT or mel. I'm looking forward to your reply! Thanks!
Hi @paopaoyaya , thank you for the questions. I'll try my best to answer each one. The first two are related, so I'll answer them together:
My parameter settings remain the same as those provided by you, but the training results obtained are very different from those in the paper. The FAD during training never drops below 1. I would like to ask where the problem might be? Another question: the FAD is calculated using HuBERT features. Why not use mel features, or features from other models like VGGish? Maybe this makes no difference, right?
We do not compute the FAD using HuBERT features, but with features from the ResNeXT classifier linked here. We tried to indicate this in the paper with the phrase "Using either the classification output or 1024-dimensional features extracted from the penultimate layer in the classifier, we [compute the FAD]". We opt to use the ResNeXT features for FID/FAD because it is the same method used by the baselines we compare to (SaShiMi / DiffWave), so that our results are the most comparable to existing work. I think this problem of getting different numbers may be resolved if you use the ResNeXT classifier to extract the features for the FAD and the other latent-feature metrics, since you are quite right that using different feature sets for the FAD computation will result in very different numbers.
And another question: the results using HuBERT are better than using mel, but using mel in HiFi-GAN generates speech that is closer to the source. So I want to use mel as ASGAN's feature. Then, back to the first question: I can't get good training results with either HuBERT or mel. I'm looking forward to your reply! Thanks!
I'm a bit unsure what is meant by the question here. HiFi-GAN can generate good speech from mel features, but in my experiments I found that using HiFi-GAN to vocode HuBERT features also has sufficiently good reconstruction ability. The main reason for using HuBERT features, however, is not optimal vocoding performance but rather disentanglement: the HuBERT feature space is already decently disentangled (read: linearly predictive of many aspects of speech, as shown in the SUPERB benchmark), much more so than mel filterbank features. This means the GAN has an easier job learning to map the latent vector w to the output HuBERT features, since the HuBERT features are more disentangled than raw mel spectrograms.
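As a side note, to illustrate what "linearly predictive" means here, below is a toy linear-probe sketch in the spirit of SUPERB. The arrays are random stand-ins for frozen HuBERT features and attribute labels, and none of the names come from this repo:

import numpy as np
from sklearn.linear_model import LogisticRegression

# Stand-ins: (n_utterances, feat_dim) frozen features and per-utterance labels
feats = np.random.randn(500, 768)            # e.g. mean-pooled HuBERT features
labels = np.random.randint(0, 10, size=500)  # e.g. word-class labels

# If a plain linear classifier on frozen features predicts an attribute well,
# that attribute is (roughly) linearly encoded in the feature space.
probe = LogisticRegression(max_iter=1000).fit(feats, labels)
print("probe accuracy:", probe.score(feats, labels))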
I'm also a bit unsure -- are the results you're obtaining bad as in they sound poor, or do they sound fine and only the metric numbers are worse than you expected? If it is the latter, then using the ResNeXT classifier when computing metrics should solve the problem.
Thanks again for your interest in our work, I hope these answers help a bit! -- RF5
Thanks for your reply. In train_asgan.py, the fad_features are calculated as shown in the image below. If data_type is hubert, the first parameter c_fake passed into the fad function is the HuBERT features. And if data_type is melspec, the first parameter c_fake passed into the fad function is obtained by converting the mel features generated by the generator into HuBERT features. I can't find where you extract features from the ResNeXT classifier.
And about the result of HiFi-GAN: I mean that if I convert the source speech to HuBERT features and to mel features, and use HiFi-GAN to generate fake speech from each, the fake speech generated from the mel features is closer to the source speech than the one generated from the HuBERT features. Of course, no matter which feature is used it sounds great; there is just a different degree of similarity to the source speech.
I see, I see -- yep, the mel-spectrogram (since it is closer to the actual audio than HuBERT features) will sound closer to the original. Similarly, if you use features from earlier layers in HuBERT, it will sound closer to the original audio than later HuBERT features. The tradeoff, however, is that the earlier / closer-to-audio the features you use, the less disentangled they are, so the harder it is for the GAN to learn a good disentangled representation.
For the training script, the FAD there is a proxy for the true, full FAD metric, since we don't want to spend additional time running ResNeXT inference over the entire set of validation output samples every validation cycle. I.e. the FAD computed using HuBERT in the training loop is just an approximate indication, not the actual FAD and other metrics computed for the final test evaluation. For example, here is a snippet of the final eval script I used to compute test-set FAD scores:
# getting model output and classifier features for generated output
out = model.unconditional_generate(cfg.bs)
gen_wavs[running_n:running_n+out.shape[0]] = out.cpu()
outlen = torch.ones((out.shape[0]), dtype=torch.long, device=cfg.device)*cfg.seq_len
out = out.to(cfg.device)
logit, feat = classifier(out, outlen, return_features=True) # (bs, n_classes)
logits[running_n:running_n+logit.shape[0]] = logit.cpu()
feats_gen[running_n:running_n+feat.shape[0]] = feat.cpu()
running_n += out.shape[0]
...
# computing FID
gen_fid = fid(feats_gen, feats_train)
logging.info(f"FID [{cfg.model}]: {gen_fid:6.5f}")
metrics['FID'] = gen_fid
test_fid = fid(feats_test, feats_train)
logging.info(f"FID [test set]: {test_fid:6.5f}")
metrics['FID_test'] = test_fid
train_fid = fid(feats_train, feats_train)
logging.info(f"FID [train set]: {train_fid:6.5f}")
metrics['FID_train'] = train_fid
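(The fid helper above isn't shown in the snippet; a standard Fréchet distance between Gaussians fitted to the two feature sets looks roughly like the sketch below -- my own minimal version, not necessarily the exact implementation used in the eval script.)

import numpy as np
from scipy import linalg

def frechet_distance(feats_a, feats_b):
    # feats_a, feats_b: (n_samples, feat_dim) arrays of classifier features
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean, _ = linalg.sqrtm(cov_a @ cov_b, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(covmean))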
TL;DR: the metrics computed in the train script are rough, fast approximations of the final test-set metrics, and a separate eval script should be used to compute the test-set metrics using the method described in the paper. Hope that helps explain the discrepancies. I'll add a note about this in the readme, since it is not so clear.
Thanks a lot! I would like to ask how many epochs and steps the pre-trained model you provided was trained for. And what is the approximate FAD of the best model during training? I ask because I believe you decide whether to update the best-model .pt file based on the approximate FAD.
Hi, sure thing. As in the paper, I trained for 520k steps. I am not sure how many epochs that corresponds to, but the full training config with batch size / grad clip / all other details is in the cfg_yaml key of the provided checkpoint, so the batch size and the total steps (in the checkpoint name) can be converted to a number of epochs. The approximate FAD given by the training script at this point is around 1.5 for the mel model and around 10 for the HuBERT model. Hope that helps a bit!
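In case it helps, something along these lines should print that embedded config (a sketch -- the checkpoint path is a placeholder, and I'm assuming the checkpoint loads as a plain dict with a cfg_yaml key as described above):

import torch

ckpt = torch.load("path/to/asgan_checkpoint.pt", map_location="cpu")
print(ckpt["cfg_yaml"])  # full training config: batch size, lr, grad clip, ...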
Okay, thank you very much!
Sorry to bother you again. I also want to know the hyperparameters used for training the mel model, like batch size, betas, lr and so on. Are they the same as for training the HuBERT model? I ask because when I train the mel model the best FAD I obtain is 7, and I don't know if it's a problem with the parameter settings. I also noticed that the Adam optimizer is used but the lr is never changed -- is that right?
I think the LR and other settings are unchanged, but there might be some tiny differences. I think the training config is inside the mel checkpoint, also in a config key I suspect?
I changed layer_specs, seq_len and c_dim to fit the mel model, and left the others unchanged. During training, the learning rate always remains unchanged, right?
I didn't find the mel checkpoint, have you published mel pre-trained model?
Yep, you can find it on the releases page; specifically, the g_02365000_package.pth file.
When I tried to load the g_02365000_package.pth file, the following error was encountered.
Ahh, yeah -- it is a torch package, not a torch state dict (hence the _package suffix). To see how to load torch packages, see the torch docs: you need to use the torch.package.PackageImporter class to read it. Hope that helps!
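Roughly along these lines (the package and resource names below are guesses on my part, since I don't remember exactly how that checkpoint was exported -- inspect the package contents to find the right ones):

from torch.package import PackageImporter

importer = PackageImporter("g_02365000_package.pth")
print(importer.file_structure())  # inspect which modules/resources the package contains
# Then load the pickled object with the package/resource names you find, e.g.:
# generator = importer.load_pickle("hifigan", "generator.pkl")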
I found that the content in g_02365000_package.pth relates to HiFi-GAN, not to the mel pre-trained ASGAN model. I would like to ask if you have a pre-trained model for ASGAN trained on mel features. Thanks!
Hi, I managed to load it fine after some fiddling, but you're right that it does not have the training hyperparameters. Looking at the original code again, it is the density/config.py file which has all the training parameters for both the hubert and mel runs. By default the hubert lines are uncommented, and there are commented-out lines for the items which change for the mel-spectrogram based model. The learning rate is the same for both variants. I hope that helps!
Well, I switched the relevant parameters in the density/config.py file from the hubert lines to the mel lines, but I couldn't get good training results; the generated speech couldn't be recognized. Maybe I could try training again. Thanks!
Ok, I checked the SLURM logs for exact parameters for the mel training run. Here they are:
python train_rp_w.py model=rp_w train_root=/scratch-small-local/250362-hpc1-hpc/datasets/sc09/ n_valid=400 data_type=melspec checkpoint_path=./density/runs/conv2-sc09-mel/ z_dim=512 rp_w_cfg.z_dim=512 rp_w_cfg.w_layers=3 batch_size=32 lr=2e-3 grad_clip=10 aug_init_p=0.1 stdout_interval=100 validation_interval=2500 n_epochs=800 c_dim=128 rp_w_cfg.c_dim=128 d_lr_mult=0.1 fp16=True preload=False num_workers=12 betas=[0,0.99] rp_w_cfg.equalized_lr=True rp_w_cfg.use_sg3_ff=True rp_w_cfg.D_kernel_size=5 rp_w_cfg.D_block_repeats=[3,3,3,3,2] use_sc09_splits=True sc09_train_csv=splits/sc09-train.csv sc09_valid_csv=splits/sc09-valid.csv
Note these are slightly different from the defaults in the config file here; they are instead the values specified in the paper. Perhaps using these values from the paper may help?
Ok! Thanks!
Sorry to bother you again. I would like to ask which classifier is used to produce the logprobs_gen used to calculate IS and mIS in the metrics.py file in the main directory.
Hi, it is the one linked in the paper, i.e. https://github.com/RF5/simple-speech-commands . Hope that helps!
I calculated the IS on the test set to be just over 2, but I saw in the paper that the IS on the test set reaches 9. Here is how I calculated it: I think it must be different from your calculation method; could you please tell me what the difference is?
Hi @paopaoyaya , what function are you using for inception_score? If it is from the metrics.py file in this repo, please see the docstring of the inception_score function:
Calculate inception score from `logprobs_gen` (bs, n_classes) of
log probabilities.
So, your code is calling the function with raw probabilities, not log probabilities as required. Please use the log probabilities with this function to obtain the correct value.
I used the inception_score function from the metrics.py file in the main directory. Does that mean the probabilities in the picture above need to have a log applied and then be passed into the inception_score function?
Yep yep, they must be log probabilities, not raw probability scores.
Okay, thanks!
I got the right inception_score value, but when I passed the same log probabilities into the modified_inception_score2 function, I got a value of 3, whereas the correct value for mIS in the paper is around 240. I'm very grateful for your reply!
The modified inception score function must be called with raw probabilities, not log probabilities. The docstring for it was copied over badly, but you can see that it uses scipy.stats.entropy, which requires raw probabilities, not log probabilities. Please try computing it again with raw probabilities for the modified inception score. I have also updated the docstring for it.
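For intuition, the modified inception score is built from pairwise KL divergences between the class distributions of different samples. A stripped-down sketch of that idea (my own simplified version, not the exact code in metrics.py) is:

import numpy as np
from scipy.stats import entropy

def modified_is_sketch(probs):
    # probs: (n_samples, n_classes) raw probabilities, each row summing to 1
    n = probs.shape[0]
    kls = [entropy(probs[i], probs[j])  # KL(p_i || p_j) -- entropy() expects raw probs
           for i in range(n) for j in range(n) if i != j]
    return float(np.exp(np.mean(kls)))

Feeding log probabilities into entropy here silently gives nonsense values, which is why the distinction matters.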
I used raw probabilities to calculate the modified inception score as shown in the picture below and got a value of around 100000. Did I go wrong somewhere?
It is very likely that an overflow error is occurring somewhere there. Can I recommend you call the function with double-precision inputs? This is how I call it:
mis_macro_score = modified_inception_score2(logprobs_gen.double().exp().numpy())
I called the function with double-precision inputs, but I still get a value of around 100000. It doesn't seem to be due to the precision of the input, but I can't find where the problem is.
I'm pretty sure it is a precision problem somewhere in your input, as otherwise that number does not make sense. The mIS code in that function is fairly simple, and the only way to achieve a score that high is if, on average, the relative entropy between two subsets probs_gen_1 and probs_gen_2 is ~11.6 nats, which seems unreasonable for a 10-class probability vector. You would need some truly bizarre values to get a KL / relative entropy that high with a trained classifier model.
Can you perhaps try this: work in the log domain from the beginning, using double precision throughout. E.g.
logit, feat = classifier(out, outlen, return_features=True) # (bs, n_classes)
logits[running_n:running_n+logit.shape[0]] = logit.cpu()
logprobs_gen = F.log_softmax(logits, dim=-1)
is_score = inception_score(logprobs_gen)
mis_macro_score = modified_inception_score2(logprobs_gen.double().exp().numpy())
Performing the softmax in the log domain is quite a lot more stable, and should work a bit better hopefully.
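A toy example of the difference (made-up logits, not from the eval script):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0]], dtype=torch.float64)
print(torch.log(F.softmax(logits, dim=-1)))  # roughly [0., -inf]: exp underflows, then log blows up
print(F.log_softmax(logits, dim=-1))         # roughly [0., -1000.]: stays finite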
I calculated according to the method you provided above, and the result is still 107781.
with torch.no_grad():
logit, feat = classifier(x, x_lens, return_features=True) # (bs, n_classes)
logits.append(logit.cpu())
logits = torch.cat(logits, dim=0)
print("==logits==",logits.shape)
logprobs_gen = F.log_softmax(logits, dim=-1)
is_score = inception_score(logprobs_gen)
raw_prob = logprobs_gen.double().exp().numpy()
mis_macro_score = modified_inception_score2(logprobs_gen.double().exp().numpy())
print(is_score) #tensor(9.1868)
print(mis_macro_score) #107781.92276396013
And I printed kl; many values are greater than 11. I also printed probs_gen_1 and probs_gen_2 whenever kl = scipy.stats.entropy(probs_gen_1, probs_gen_2) > 11; part of the results are shown in the picture below.
Hi @paopaoyaya , your probabilities look very strange; a probability of 1.7e-8 is extremely small, which seems odd. What model are you using to generate the output audio? And all of your predictions are extremely confident (over 0.99 for every top class). It looks like something is quite wrong.
For context, here are the probabilities from my classifier for a few samples from my evaluation with the mel-spectrogram model:
tensor([[1.0732e-02, 3.4288e-03, 8.9758e-01, 4.7757e-02, 9.8861e-03, 6.5016e-03,
4.8615e-03, 5.0950e-03, 9.9611e-03, 4.1961e-03],
[5.0862e-03, 9.3507e-01, 2.9997e-03, 4.2973e-03, 1.9337e-02, 6.9634e-03,
3.7869e-03, 6.1457e-03, 2.9259e-03, 1.3385e-02],
[1.9372e-03, 2.1120e-03, 1.0556e-03, 1.3536e-03, 1.4029e-03, 3.0512e-03,
2.0439e-03, 9.8194e-01, 7.6831e-04, 4.3399e-03],
[1.3464e-03, 1.5348e-03, 8.9753e-04, 9.3934e-04, 8.4789e-04, 1.2287e-03,
2.1305e-03, 9.8898e-01, 7.1297e-04, 1.3781e-03],
[9.5318e-01, 2.0061e-03, 2.3010e-02, 4.0712e-03, 4.5172e-03, 1.9784e-03,
3.2767e-03, 3.4033e-03, 1.4486e-03, 3.1067e-03],
[9.7737e-01, 1.5422e-03, 6.7897e-03, 2.3158e-03, 3.2684e-03, 1.3478e-03,
2.3156e-03, 2.0588e-03, 8.8952e-04, 2.1045e-03],
[3.1472e-03, 2.5284e-03, 1.8745e-03, 1.8022e-03, 1.7576e-03, 2.1806e-03,
8.7526e-03, 9.7446e-01, 1.3400e-03, 2.1575e-03],
[7.0372e-04, 1.0072e-03, 5.5194e-03, 3.0350e-03, 1.1002e-03, 1.4809e-03,
2.5923e-03, 8.8620e-04, 9.8216e-01, 1.5172e-03],
[3.7401e-03, 2.3566e-03, 2.2813e-03, 3.1349e-03, 3.0807e-03, 2.1974e-03,
9.5789e-01, 1.9459e-02, 4.2009e-03, 1.6601e-03],
[8.1010e-04, 5.0763e-04, 5.7781e-04, 8.0943e-04, 8.2904e-04, 5.2154e-04,
9.9301e-01, 1.4006e-03, 1.1955e-03, 3.3634e-04],
[1.1460e-03, 1.5012e-03, 1.2414e-03, 1.0011e-03, 9.9062e-01, 1.3833e-03,
9.0572e-04, 7.3270e-04, 7.0613e-04, 7.6705e-04],
[9.7529e-04, 1.3751e-03, 9.8703e-04, 8.1800e-04, 9.9205e-01, 1.3300e-03,
6.7907e-04, 6.1197e-04, 5.1458e-04, 6.5406e-04],
[3.8424e-03, 9.3650e-04, 9.7958e-01, 4.8371e-03, 3.1261e-03, 9.2117e-04,
1.4698e-03, 1.5261e-03, 2.5619e-03, 1.2018e-03],
[1.1211e-03, 1.1069e-03, 6.7681e-04, 7.1772e-04, 6.8634e-04, 9.2723e-04,
1.6964e-03, 9.9159e-01, 5.1073e-04, 9.6363e-04],
[8.7381e-04, 1.0742e-03, 3.1626e-03, 9.8778e-01, 1.0133e-03, 8.6323e-04,
9.2877e-04, 7.3508e-04, 2.1113e-03, 1.4527e-03],
[4.8954e-03, 8.2846e-01, 3.0910e-03, 7.9279e-03, 6.6682e-03, 2.1677e-02,
3.5908e-03, 5.8631e-03, 4.6628e-03, 1.1317e-01],
[9.8981e-01, 9.2747e-04, 1.5020e-03, 1.0977e-03, 1.1930e-03, 7.9200e-04,
1.4560e-03, 1.3753e-03, 4.4361e-04, 1.4046e-03],
[1.3225e-03, 7.3399e-04, 7.9497e-04, 1.3114e-03, 1.1802e-03, 7.5868e-04,
9.8950e-01, 1.7511e-03, 2.1102e-03, 5.3278e-04],
[3.3357e-03, 5.8322e-03, 4.6035e-03, 7.6027e-03, 2.7157e-02, 8.8986e-01,
3.1367e-03, 5.5506e-03, 4.4160e-03, 4.8507e-02],
[4.2567e-04, 1.1502e-03, 3.9536e-04, 6.0441e-04, 9.7480e-04, 9.9315e-01,
7.2432e-04, 6.1858e-04, 8.0289e-04, 1.1573e-03]], dtype=torch.float64)
With my numbers, the model is never so certain, and the certainty changes between samples (sometimes the best class probability is 0.99, sometimes it is 0.82), but your numbers are always over 0.99. It almost looks like you're getting the classifier probabilities with a heavy temperature scaling applied. Do you have any idea why your distributions have such extremely low entropy?
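(By "heavy temperature" I mean something like the following: dividing logits by a small temperature before the softmax makes the distribution extremely peaked. Made-up logits, just for illustration:)

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.0])
print(F.softmax(logits, dim=-1))        # roughly [0.58, 0.21, 0.13, 0.08]
print(F.softmax(logits / 0.1, dim=-1))  # top class ~1.0, the rest vanishingly small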
I used the model you provided: model = torch.hub.load('RF5/simple-speech-commands', 'convgru_classifier', type='best'). The above are the raw probabilities I got. Could it be that the model I loaded is inconsistent with the one you used? So I wanted to try the ResNeXT model. But when I tried to load it like classifier = torch.hub.load('/home/ps/wp/audio_stylegan/simple-asgan/classifier', 'resnext_classifier_sc09', source='local') to classify the audio files in sc09-test.csv, using the code you provided directly without any changes, I encountered the following error:
Looking forward to your reply! Thanks!
Hi @paopaoyaya , as mentioned in the paper, like the other baselines we use the ResNeXT classifier for evaluations. I just tested the code again myself, and it works fine. Running this in a fresh Google Colab returns the correct result:
import torch
classifier = torch.hub.load('RF5/simple-speech-commands', 'resnext_classifier_sc09', device='cpu', force_reload=True)
with torch.no_grad():
out = torch.randn(1, 16000) # (batch size, sequence length = 1 second @ 16kHz), the audio you want to classify
outlen = torch.tensor([16000,])
logit, feat = classifier(out, outlen, return_features=True) # (bs, n_classes)
logprobs_gen = logit.log_softmax(dim=-1)
print(logprobs_gen.exp())
# tensor([[0.0839, 0.1124, 0.0913, 0.0959, 0.0837, 0.1012, 0.0961, 0.0980, 0.1033, 0.1342]])
I recommend perhaps trying a clean install, or trying to find out why your code is giving strange errors -- I can't replicate that error when using the above code and following the steps in the paper. Using the convgru classifier will definitely give you very different numbers.
I used the ResNeXT classifier and got the right value for mIS. The earlier error occurred because some audio files in the test set are shorter than 1 s in duration. When I removed those, I got a value of about 200 for mIS. Thank you very very much!
I would like to ask if the version of the dataset you used is 0.0.3?
I used v0.0.2 in my experiments :)
Hi, let me try to answer those questions:
1. Yes, several of the utterances are just shorter than 1 s. For those, I pad them to 1 s (see the sketch below); I do not remove them.
2. As mentioned in the paper, 5000.
3. For the mel-spectrogram ones, the 2e-3 and 2e-4 sound accurate. They did not change except as specified by the equalized learning rate technique used in training (and mentioned in the paper).
Hope that helps
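A minimal sketch of the padding mentioned in point 1 (the function name is mine; the actual data loading code may do it slightly differently):

import torch
import torch.nn.functional as F

def pad_to_one_second(wav: torch.Tensor, sr: int = 16000) -> torch.Tensor:
    # Zero-pad a (..., n_samples) waveform on the right so it is exactly 1 s long.
    if wav.shape[-1] < sr:
        wav = F.pad(wav, (0, sr - wav.shape[-1]))
    return wav[..., :sr]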