bshall / knn-vc

Voice Conversion With Just Nearest Neighbors
https://bshall.github.io/knn-vc/

Training HiFiGAN on higher quality data #11

Closed: gaurangbharti1 closed this issue 1 year ago

gaurangbharti1 commented 1 year ago

Hey, I was wondering what changes to the training script it would take to train HiFiGAN on higher-quality data such as LibriTTS or LibriTTS-R. These datasets use wav files instead of flac files and have a 24kHz sampling rate. I can preprocess the dataset to 16kHz and change the files in data_splits to work with wav files, but I wanted to know the best way to work with this kind of data. If there are other ways to improve the general quality of the outputs, I'd be happy to explore those too. Any help would be great, thanks!

RF5 commented 1 year ago

Hi @gaurangbharti1 , thanks for your interest!

You can definitely retrain the HiFiGAN on LibriTTS, and that might give a small quality improvement. I suspect, however, that the main limitation on quality is not the speech data that the HiFiGAN was trained on, but rather the mismatch between adjacent frames. i.e. HiFiGAN can vocode pure WavLM features very well, but once we map them using k-nearest neighbors, there can be artifacts between adjacent frames, making the vocoding task harder. This is why training using prematched data helps quite a bit, but the effect is still there.

In other words, quality is limited not by LibriSpeech's train-clean-100 subset but by the act of mixing together features from disparate points in time during matching. Also, I think WavLM was trained on Libri-Light, which is of similar quality to LibriSpeech, so if you truly want to retrain everything on high-quality 24kHz audio, you might also need to retrain the WavLM encoder.

But if you just want to retrain the HiFiGAN on LibriTTS or similar, there are two ways you could do it: (1) as you say, resample the dataset to 16kHz and change the file extensions to .wav; or (2) change the HiFiGAN architecture to vocode 24kHz utterances instead of 16kHz ones and train it directly on 24kHz audio. However, option (2) still requires you to resample the input to 16kHz to compute the WavLM features, since WavLM is trained on 16kHz audio.
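For option (2), the vocoder's total upsampling has to match WavLM's frame rate. A minimal sketch of that arithmetic, assuming WavLM's usual 320-sample stride at 16 kHz (one feature frame per 20 ms, i.e. 50 frames per second). The `upsample_rates` lists in the checks are illustrative only, not copied from `hifigan/config_v1_wavlm.json`. The same helper also gives the up/down factors for option (1) if you resample with a polyphase resampler.

```python
from fractions import Fraction
from math import prod

WAVLM_SR = 16_000  # WavLM input sample rate
WAVLM_HOP = 320    # assumed stride: one feature frame per 20 ms at 16 kHz


def resample_factors(orig_sr: int, new_sr: int) -> tuple[int, int]:
    """(up, down) factors for polyphase resampling from orig_sr to new_sr."""
    ratio = Fraction(new_sr, orig_sr)
    return ratio.numerator, ratio.denominator


def required_upsampling(target_sr: int) -> int:
    """Output samples the vocoder must produce per WavLM frame at target_sr."""
    frame_rate = Fraction(WAVLM_SR, WAVLM_HOP)  # 50 frames per second
    per_frame = Fraction(target_sr) / frame_rate
    if per_frame.denominator != 1:
        raise ValueError(f"{target_sr} Hz is not a whole number of samples per frame")
    return int(per_frame)


def matches_target(upsample_rates: list[int], target_sr: int) -> bool:
    """True if a HiFi-GAN-style upsample_rates list yields one frame's worth of samples."""
    return prod(upsample_rates) == required_upsampling(target_sr)
```

Under these assumptions a 24kHz vocoder must emit 480 samples per WavLM frame instead of 320, and 44.1kHz would need 882 (2 * 3**2 * 7**2), which factorizes into upsampling stages less conveniently.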

I hope that helps!

gaurangbharti1 commented 1 year ago

Hey @RF5, thanks for your response and explanation! I understand the bottlenecks a lot more.

I still wanted to give retraining HiFiGAN a shot just to see how it would impact the outputs. I'm planning on using the LibriTTS-R dataset, which was made public just about a month ago and seems to have great performance improvements compared to LibriTTS and other similar datasets. It would also be a good learning experience. I think the 1st method you mentioned would require the least effort of the two so I figured I'd start there and see how things go.

I resampled the audio files from the dataset to 16kHz and wrote a script to remake the files in data_splits to match the data in the LibriTTS-R dataset, and changed flac to wav in prematch_dataset.py. However, I ran into an error when I ran this: python prematch_dataset.py --librispeech_path LibriTTS_R --out_path training_outputs --topk 4 --matching_layer 6 --synthesis_layer 6 --prematch:

Traceback (most recent call last):
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 171, in <module>
    main(args)
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 49, in main
    extract(ls_df, wavlm, args.device, Path(args.librispeech_path), Path(args.out_path), SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS)
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 110, in extract
    if args.fast_l2: pb = progress_bar(df.iterrows(), total=len(df))
AttributeError: 'Namespace' object has no attribute 'fast_l2'

I changed line 110 to just be pb = progress_bar(df.iterrows(), total=len(df)) and commented out the else statement to bypass it, although I'm sure there's a better way to deal with it. After bypassing it, these are the output logs:

        0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
Synthesis weightings: tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
[LIBRISPEECH] Computing folders ['train-clean-100', 'dev-clean']
Loading wavlm.
WavLM-Large loaded with 315,453,120 parameters.
Feature has shape:  torch.Size([300, 1024])-----------------------------------------------------------| 0.00% [0/38968 00:00<?]
Done 0/38,968
Feature has shape:  torch.Size([105, 1024])---------------------------| 0.00% [1/38968 00:08<93:05:25 train-clean-100/8629/261139/8629_261139_000027_000000.wav]
Feature has shape:  torch.Size([660, 1024])---------------------------| 0.01% [2/38968 00:13<72:14:33 train-clean-100/8629/261139/8629_261139_000038_000000.wav]
Traceback (most recent call last):-----------------------------------| 0.71% [277/38968 00:47<1:49:44 train-clean-100/4267/287369/4267_287369_000008_000000.wav]
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 171, in <module>
    main(args)
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 49, in main
    extract(ls_df, wavlm, args.device, Path(args.librispeech_path), Path(args.out_path), SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS)
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 127, in extract
    matching_pool, synth_pool = path2pools(row.path, wavlm, match_weights, synth_weights, device)
  File "/home/ubuntu/knn-vc/knn-vc/prematch_dataset.py", line 73, in path2pools
    matching_pool = torch.concat(matching_pool, dim=0)
RuntimeError: torch.cat(): expected a non-empty list of Tensors

I'm not very sure what would be causing this - this error didn't occur when using the recommended LibriSpeech dataset. Any help with this would be appreciated. Thanks!
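Instead of deleting the else branch, a less invasive workaround for the first error, assuming the only problem is that the script's argument parser never defines `fast_l2`, is to read the flag with `getattr` and a default. What `fast_l2` is meant to control is not stated in this thread, so the default below is chosen only to reproduce the manual edit to line 110 described above; `fast_l2_flag` is a hypothetical helper name.

```python
import argparse


def fast_l2_flag(args: argparse.Namespace, default: bool = True) -> bool:
    """Read the optional fast_l2 flag without an AttributeError if it was never defined.

    default=True mirrors the manual edit above (always take the progress_bar branch).
    """
    return bool(getattr(args, "fast_l2", default))
```

Line 110 could then read `if fast_l2_flag(args): pb = progress_bar(df.iterrows(), total=len(df))`, leaving the original else branch untouched.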

RF5 commented 1 year ago

Hi @gaurangbharti1, this is likely a problem with your dataset: it looks like the 277th file has a bad speaker ID, or is the only utterance from that speaker. Without knowing the details of how your dataset files are set up and how many utterances are present for each speaker, I can't know precisely what is wrong. But if I had to guess, some speakers only have 1 utterance associated with them, which will cause prematching to fail (since it assumes multiple utterances per speaker).
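This guess can be checked before launching prematching. A minimal sketch that counts utterances per folder and reports folders with fewer than two. It assumes prematching draws each utterance's matching pool from the other files in the same folder; since this thread does not say whether that folder is the speaker or the chapter directory, `key_depth` lets you check both. `sparse_groups` is a hypothetical helper.

```python
from collections import Counter
from pathlib import PurePath


def sparse_groups(rel_paths, key_depth=3, min_utts=2):
    """Return {group: count} for groups holding fewer than min_utts utterances.

    Paths are assumed to look like 'train-clean-100/<speaker>/<chapter>/<utt>.wav'.
    key_depth=3 groups by chapter folder, key_depth=2 by speaker folder.
    """
    counts = Counter("/".join(PurePath(p).parts[:key_depth]) for p in rel_paths)
    return {group: n for group, n in counts.items() if n < min_utts}
```

For example, with `root = Path("LibriTTS_R")`, call `sparse_groups(p.relative_to(root).as_posix() for p in root.rglob("*.wav"))` once with `key_depth=3` and once with `key_depth=2`.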

Hope that helps!

fangg2000 commented 1 year ago

When I run hifigan.train, I get the same exception. My fix: create a new conda environment with PyTorch >= 2.0 and install all packages with the conda command (it must be the conda command) ^_^

RF5 commented 1 year ago

> when I run hifigan.train, it has the same exception,create new conda pro, pytorch version >= 2.0, all packeges install with conda command, must by conda command^_^

I unfortunately do not quite understand what you are saying here, but it sounds like you came right?

gaurangbharti1 commented 1 year ago

Hey @RF5, thanks a lot for your help! I was able to work around the error from earlier by checking the size of matching_pool and skipping the instances where it was 0. This caused about 20-30 files to be skipped, but as there are >38,000 files, I don't expect performance to take much of a hit.

Through this, I was able to complete the first step of precomputing the WavLM features for the LibriTTS-R dataset and was able to start training. However, I ran into another error after training started.

After running this command: python -m hifigan.train --audio_root_path LibriTTS_R --feature_root_path training_outputs --input_training_file data_splits/wavlm-hifigan-train.csv --input_validation_file data_splits/wavlm-hifigan-valid.csv --checkpoint_path checkpoints --fp16 False --config hifigan/config_v1_wavlm.json --stdout_interval 25 --training_epochs 1800 --fine_tuning, training started and 1 epoch was successfully trained, but there was an error while computing the evaluation loss:

checkpoints directory :  checkpoints
/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torchaudio/transforms/_transforms.py:611: UserWarning: Argument 'onesided' has been deprecated and has no influence on the behavior of this module.
  warnings.warn(
Epoch: 1
Steps : 0, Gen Loss Total : 108.426, Mel-Spec. Error : 2.217, sec/batch : 9.485, peak mem: 12.33GB------| 0.00% [0/1660 00:00<?]
validation run complete at 0 steps. validation mel spec error: 2.1536                                                                    
Traceback (most recent call last):                                                                                                                  
  File "/opt/conda/envs/knn-vc/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/knn-vc/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/knn-vc/knn-vc/hifigan/train.py", line 333, in <module>
    main()
  File "/home/ubuntu/knn-vc/knn-vc/hifigan/train.py", line 329, in main
    train(0, a, h)
  File "/home/ubuntu/knn-vc/knn-vc/hifigan/train.py", line 136, in train
    for i, batch in pb:
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/fastprogress/fastprogress.py", line 50, in __iter__
    raise e
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/fastprogress/fastprogress.py", line 41, in __iter__
    for i,o in enumerate(self.gen):
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/envs/knn-vc/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ubuntu/knn-vc/knn-vc/hifigan/meldataset.py", line 191, in __getitem__
    mel_start = random.randint(0, mel.size(1) - frames_per_seg - 1)
  File "/opt/conda/envs/knn-vc/lib/python3.10/random.py", line 370, in randint
    return self.randrange(a, b+1)
  File "/opt/conda/envs/knn-vc/lib/python3.10/random.py", line 353, in randrange
    raise ValueError("empty range for randrange() (%d, %d, %d)" % (istart, istop, width))
ValueError: empty range for randrange() (0, 0, 0)

I didn't make any changes to config_v1_wavlm.json except increasing the batch_size to 20. Any help with this would be great, thanks!
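The traceback pins down which item failed. `random.randint(a, b)` calls `randrange(a, b + 1)`, and the message `(0, 0, 0)` means `b + 1 == 0`, so `mel.size(1) - frames_per_seg - 1 == -1`: the failing feature had exactly `frames_per_seg` frames. A minimal reproduction in plain Python, where `pick_segment_start` is a hypothetical stand-in for line 191 of meldataset.py and `frames_per_seg` is whatever your config implies. Shorter items are presumably handled by a separate branch in the real loader; the traceback alone does not show this.

```python
import random


def pick_segment_start(n_frames: int, frames_per_seg: int) -> int:
    """Hypothetical stand-in for meldataset.py line 191: pick a random crop start."""
    # randint(0, b) is randrange(0, b + 1); it raises ValueError whenever b < 0
    return random.randint(0, n_frames - frames_per_seg - 1)


def min_frames_needed(frames_per_seg: int) -> int:
    """Smallest feature length for which the crop above has a valid start."""
    return frames_per_seg + 1
```

So any item whose feature length equals the segment length crashes this line, while one extra frame is enough for it to succeed.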

RF5 commented 1 year ago

Hi @gaurangbharti1, this looks like your data is corrupted again. It is likely that one of your features did not save correctly and does not have a sequence length of at least 1, so splitting it into segments fails. I recommend double-checking that all the saved .pt feature files were saved correctly and without corruption. A good way to test this is to look for saved files with unreasonably small sizes (less than 2 KB) and delete them. Also remember to make sure your data splits (wavlm-hifigan-train.csv and wavlm-hifigan-valid.csv) are updated with the latest paths for your dataset and don't include those bad speakers you skipped during preprocessing.
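The size check suggested here can be scripted. A minimal sketch that lists .pt files under the feature directory below a size threshold (2 KB, the figure suggested above), smallest first, so you can inspect them before deleting anything. The recursive `*.pt` pattern is an assumption about how the features are laid out under the `--out_path` directory, and `small_feature_files` is a hypothetical helper.

```python
from pathlib import Path


def small_feature_files(feature_root, min_bytes=2 * 1024, pattern="*.pt"):
    """Return feature files under feature_root smaller than min_bytes, smallest first."""
    sized = [(p.stat().st_size, p) for p in Path(feature_root).rglob(pattern) if p.is_file()]
    return [p for size, p in sorted(sized) if size < min_bytes]
```

For example, `small_feature_files("training_outputs")`. As the next comment shows, a size check alone may not catch every bad item: very short clips still produced files of about 20 KB.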

Hope that helps!

gaurangbharti1 commented 1 year ago

Hi @RF5, thanks for your help! I was able to successfully clean up the dataset and start training! Interestingly enough, all the .pt files were larger than 2 KB (the smallest were about 20 KB), but some of the audio samples were very short (less than 0.75 seconds) and not very clear. My assumption is that WavLM was able to extract some features from them, but not nearly enough for any further processing. I wrote a separate script to remove very short files from wavlm-hifigan-train.csv and wavlm-hifigan-valid.csv, which seemed to do the trick. There were only about 650-700 files to remove.
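The short-clip filtering described here can be sketched with the standard library. This is not the poster's script: it assumes each split CSV has a column (hypothetically called `path`) holding the utterance's path relative to the audio root, with or without the extension, and it reads durations with Python's `wave` module, which only handles integer-PCM WAV. If your resampled files were written as 32-bit float, read the duration with a library such as soundfile (`soundfile.info(path).duration`) instead.

```python
import csv
import wave
from pathlib import Path


def wav_duration(path) -> float:
    """Duration in seconds, read from the header of an integer-PCM WAV file."""
    with wave.open(str(path), "rb") as w:
        return w.getnframes() / w.getframerate()


def filter_split(csv_in, csv_out, audio_root, min_seconds=0.75,
                 path_col="path", suffix=".wav"):
    """Copy a data-split CSV, dropping rows whose audio is shorter than min_seconds.

    Returns the number of rows dropped.
    """
    with open(csv_in, newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    kept = []
    for row in rows:
        audio = Path(audio_root) / row[path_col]
        if not audio.suffix:  # add the extension if the split omits it
            audio = audio.with_suffix(suffix)
        if wav_duration(audio) >= min_seconds:
            kept.append(row)

    with open(csv_out, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(kept)
    return len(rows) - len(kept)
```

Writing to a new file rather than overwriting the split keeps the original around for comparison; a missing audio file raises rather than being silently dropped.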

It's currently on ~50k steps and the loss is steadily decreasing! This will probably take a few days to fully train so I'll let you know how the quality is after it's done or in case I run into any more issues. Thanks again!

RF5 commented 1 year ago

Great, good luck! Hope it sounds good soon :)

youssefabdelm commented 1 year ago

> This is why training using prematched data helps quite a bit, but the effect is still there.

@RF5 Thanks so much for your work, quite impressive. I'm curious: when you say "prematched data", are you referring to training/fine-tuning one or more of the models on the specific speaker we'd like to convert the voice to? E.g. if I wanted to sound like person X, would I need to fine-tune on some of X's voice data?

If so, which models specifically would you fine-tune, HiFiGAN + WavLM or one of them?

youssefabdelm commented 1 year ago

I'm also considering training both models on 44.1/48kHz data so thanks @gaurangbharti1 for opening this issue!

RF5 commented 1 year ago

> @RF5 Thanks so much for your work, quite impressive. I'm curious: when you say "prematched data", are you referring to training/fine-tuning one or more of the models on the specific speaker we'd like to convert the voice to? E.g. if I wanted to sound like person X, would I need to fine-tune on some of X's voice data?

Hi @youssefabdelm , by 'prematched data' I am referring to Section 3.4 of the arxiv paper: training HiFi-GAN on WavLM features which have already had kNN applied to them. i.e. we are not talking about fine-tuning HiFi-GAN to data from a single target speaker, just data which has had kNN applied to them (see Sec 3.4 of the paper for more info). This helps reduce artefacts caused by the kNN operation.
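As a toy illustration of what "prematched" features are, here is a pure-Python sketch of kNN regression over a per-speaker pool: each frame of a training utterance is replaced by the average of its k nearest frames taken from other utterances of the same speaker (k=4 matches the `--topk 4` used in the prematching command earlier in this thread). It assumes cosine similarity for ranking and separate matching and synthesis features, as suggested by the `--matching_layer`/`--synthesis_layer` flags; the real prematch_dataset.py operates on WavLM tensors and may differ in details.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / (norm + 1e-8)


def knn_prematch(query_match, pool_match, pool_synth, k=4):
    """Replace every query frame with the mean synthesis feature of its k nearest
    pool frames, ranked by cosine similarity on the matching features."""
    if not pool_match:
        # No other utterances from this speaker: nothing to match against
        raise ValueError("empty matching pool")
    k = min(k, len(pool_match))
    dim = len(pool_synth[0])
    out = []
    for q in query_match:
        sims = [cosine(q, m) for m in pool_match]
        nearest = sorted(range(len(sims)), key=sims.__getitem__, reverse=True)[:k]
        out.append([sum(pool_synth[i][d] for i in nearest) / k for d in range(dim)])
    return out
```

An empty pool raises here, which is the situation behind the torch.cat() error earlier in the thread. Training the vocoder on outputs like these, rather than on the raw features, is what exposes it to the frame-mixing artefacts it will see at inference time.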

The single HiFi-GAN should work for all speakers, but naturally fine-tuning it on a single speaker will likely improve performance on that speaker. But in short: no, if you want it to sound like person X, you do not need to fine-tune on specific data from speaker X.

> If so, which models specifically would you fine-tune, HiFiGAN + WavLM or one of them?

WavLM is massive and trained on tons and tons of data; HiFi-GAN is likely the limiting factor in performance. If you fine-tune anything, it should probably be the HiFi-GAN checkpoint. Good luck training/fine-tuning :)

youssefabdelm commented 1 year ago

@RF5 Thanks so much! Yeah, I had a feeling you might be talking about the paper. I read that part and it's a lot clearer now.

> WavLM is massive and trained on tons and tons of data; HiFi-GAN is likely the limiting factor in performance. If you fine-tune anything, it should probably be the HiFi-GAN checkpoint. Good luck training/fine-tuning :)

Good to note, thank you! Will give 44.1kHz a shot and will share (non-speaker-specific) models here if I'm successful

lenzo-ka commented 1 year ago

Hi -- @youssefabdelm can you share your process? I'm interested in doing this as well

gaurangbharti1 commented 1 year ago

Hey @RF5, sorry for the delay on this; I got a little caught up with a few other things. HiFiGAN has completed training for 1800 epochs on the newer dataset, but unfortunately the audio quality did not sound any better. There's a fork on my profile with the checkpoints in the Releases. To try running knn-vc with my trained vocoder, you can swap knn_vc = torch.hub.load('bshall/knn-vc', 'knn_vc', prematched=True, trust_repo=True, pretrained=True) with knn_vc = torch.hub.load('gaurangbharti1/knn-vc', 'knn_vc', prematched=True, trust_repo=True, pretrained=True) and it should download and set up the weights. Like I mentioned though, I didn't see any major improvement in quality.

For anyone who would like to explore higher-quality voice conversion, one avenue could be to use an upsampling model such as NU-Wave 2 or similar models to increase the sampling rate and general quality of the output. I haven't tried it myself, but it could have some potential. Thanks again for your help!

RF5 commented 1 year ago

Hi @gaurangbharti1 , thanks for the results! Sorry that the quality didn't improve much, but good to know that it might not be the dataset that is the current limiting factor in synthesis quality.

Glad you managed to reproduce it with your own dataset to a similar quality as our efforts, it is good to know the method is robust across different datasets. Thanks again for the effort!