sh-lee-prml / HierSpeechpp

The official implementation of HierSpeech++
MIT License

discriminator checkpoint #4

Open iehppp2010 opened 7 months ago

iehppp2010 commented 7 months ago

Hi, could you share your discriminator checkpoint so that we can try to fine-tune the HAG model on other languages? I tried using your HAG checkpoint to resynthesize Mandarin Chinese audio, and the model handles it well most of the time when the input is the MMS wav2vec feature of lower-speech-rate audio.

Based on the design of your model, the HAG model should not be strongly tied to the language it was trained on, if I understand correctly.

Before you release your training code, I want to try training the TTV model on Mandarin and English data from scratch, and then fine-tune the HAG model on my data to get better pronunciation. Most importantly, thanks for sharing your code and checkpoints!

sh-lee-prml commented 7 months ago

Thanks for your interest!

As you said, our hierarchical speech synthesizer can synthesize speech in any language robustly because we utilize a continuous self-supervised speech representation rather than phonetic information such as a PPG.

I have attached checkpoints of the generator and discriminator (v1.1), including the optimizer states: link

You can load the model by

    _ = utils.load_checkpoint("ckpt_name", net_g, None)
    _ = utils.load_checkpoint("ckpt_name", net_d, None)

If you want to load the optimizers as well, use this:

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    _ = utils.load_checkpoint("ckpt_name", net_g, optim_g)
    _ = utils.load_checkpoint("ckpt_name", net_d, optim_d)

However, I have concerns about fine-tuning the model because we cannot share the training code at this stage...

Although we describe the training details, the training code is fairly elaborate, and it may be hard to reproduce some components such as speech perturbation, w2v feature extraction, and Mel preprocessing (https://github.com/sh-lee-prml/HierSpeechpp/blob/main/Mels_preprocess.py).

It is very important to remove speaker-related information by speech perturbation. We would appreciate your understanding. We will share this part after paper acceptance...!

In addition, we are still training the hierarchical speech synthesizer now.

[image]

The generator v1.1 is a checkpoint at 1,890k steps. The model has not converged yet! We will share an updated checkpoint after further training.

iehppp2010 commented 7 months ago

Thanks for your impressive response and generous sharing. Indeed, in my understanding, speech perturbation is the key to achieving the two-stage modeling of prosody style and voice style. When training the TTV model, I tried using this unofficial implementation of NANSY++'s speech perturbation.
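Roughly, that style of perturbation does something like the following. This is only a sketch (using praat-parselmouth and Praat's 'Change gender' command); the specific ratios, the pitch fallback, and the omission of the parametric-EQ step are my assumptions, not the paper's exact recipe:

    import numpy as np
    import parselmouth
    from parselmouth.praat import call

    def perturb_voice(wav, sr=16000, formant_shift=1.2, pitch_shift=1.1, pitch_range=1.5):
        # NANSY-style information perturbation via Praat's 'Change gender' command
        snd = parselmouth.Sound(wav.astype(np.float64), sampling_frequency=sr)
        median_f0 = call(snd.to_pitch(), "Get quantile", 0, 0, 0.5, "Hertz")
        if np.isnan(median_f0) or median_f0 == 0:
            median_f0 = 150.0  # fallback for mostly unvoiced clips
        perturbed = call(snd, "Change gender", 75, 600,
                         formant_shift,            # formant shift ratio
                         median_f0 * pitch_shift,  # new pitch median (Hz)
                         pitch_range,              # pitch range factor
                         1.0)                      # duration factor (unchanged)
        return perturbed.values.squeeze(0)

    # usage: perturb a placeholder waveform (in practice the ratios are sampled randomly)
    wav_perturbed = perturb_voice(np.random.randn(16000))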

Below are the loss curves from training on a 1,300-hour mixed Chinese and English dataset with four RTX A6000 GPUs. The loss_w2v (wav2vec loss) scale is 45 (I missed the 'c_mel' entry in your JSON config file; I guess it is the wav2vec reconstruction loss scale?). The other hyperparameters are the same as your 'ttv_libritts_v1' configuration, except the batch size and learning rate, due to limited GPU RAM.

[image: loss curves]

One validation prediction: the input text is '请播一个刘欢的专辑' (Please play a Liu Huan album). The waveform was synthesized by the 'hierspeechpp_v1.1_ckpt.pth' model, without fine-tuning on Chinese data.

During training, I found that loss_w2v stops converging to smaller values after the initial training stage. Do you apply any post-processing to the wav2vec features, e.g. normalization?

And an additional question about the posterior encoder input feature.

[image]

Is it the wav2vec feature of the perturbed waveform? Because the TTV model should be voice-style-agnostic and focus on prosody style modeling.

sh-lee-prml commented 7 months ago

Sorry for the confusion.

We do not utilize speech perturbation when training TTV, and the weight of the w2v reconstruction loss is 10!

 "c_mel": 10
This is the weight of the w2v reconstruction loss.

https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/config.json

Here is our loss curve.

[image: loss curve]

As you said, our w2v loss shows a similar trend because the w2v features are not normalized. (Some values of the w2v features are very large, above 500...!)

However, it is okay because we train our hierarchical speech synthesizer with perturbed w2v features, so the synthesizer generates speech robustly!!!

In addition, we are trying to modify the TTV architecture to further reduce this issue (including the hidden size and decoder hidden shapes).

Please grant access permission for the audio samples!

Thanks!

iehppp2010 commented 7 months ago

Sorry, I have fixed the permission issue for the audio samples. I think the model is still under-fitting, which is why there are mispronunciations.

The TTV model I am training now uses perturbed w2v features for the posterior, because some voice style information, e.g. timbre, may remain in the w2v features. So I would like to know why you chose non-perturbed w2v features...

We do not utilize the speech perturbation when training TTV

sh-lee-prml commented 7 months ago

As you said, the w2v representation contains some voice style information.

There are a few reasons why I do not use speech perturbation in this part:

  1. We observed that the KL divergence loss between text and w2v is relatively high, resulting in optimization issues.
  2. Because the hierarchical speech synthesizer is trained with perturbed speech, we thought that the voice style would be adapted anew from the voice prompt in the hierarchical speech synthesizer.
  3. Although the w2v representation contains some voice style, it carries relatively little of it, as we observed in HierSpeech.

However, I agree that using speech perturbation in this part could improve prosody style transfer by removing voice style, provided the model can still be trained robustly.

Also, I have a question about your TTV setup, which uses a Chinese and English dataset: do you use a language embedding in the text encoder? In my experience, adding a language embedding to the first block of the text encoder improves pronunciation!
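For illustration only, a minimal sketch of what I mean; the encoder below is a toy stand-in, not the ttv_v1 text encoder, and only the placement of the language embedding matters:

    import torch
    import torch.nn as nn

    class TextEncoderWithLang(nn.Module):
        # Toy encoder: a language embedding is added to the phoneme embeddings
        # before the first block, which is the placement described above.
        def __init__(self, n_phonemes, n_langs, hidden=192, n_layers=4):
            super().__init__()
            self.phoneme_emb = nn.Embedding(n_phonemes, hidden)
            self.lang_emb = nn.Embedding(n_langs, hidden)
            self.blocks = nn.ModuleList(
                nn.TransformerEncoderLayer(d_model=hidden, nhead=2, batch_first=True)
                for _ in range(n_layers))

        def forward(self, phoneme_ids, lang_id):
            # phoneme_ids: [B, T], lang_id: [B]
            x = self.phoneme_emb(phoneme_ids) + self.lang_emb(lang_id).unsqueeze(1)
            for block in self.blocks:
                x = block(x)
            return x

    # usage: two utterances, language ids 0 (zh) and 1 (en) are arbitrary here
    enc = TextEncoderWithLang(n_phonemes=200, n_langs=2)
    out = enc(torch.randint(0, 200, (2, 50)), torch.tensor([0, 1]))
    print(out.shape)  # torch.Size([2, 50, 192])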

iehppp2010 commented 7 months ago

Thanks again for answering my questions!

I don't use a language embedding yet. I am using a mixed phoneme set: the Chinese phonemes are pinyin initials and finals, and the English phonemes come from IPA.

sh-lee-prml commented 7 months ago

I think you don't need a language embedding in your case.

If you have other questions, feel free to ask me!

azraelkuan commented 7 months ago

Hi @sh-lee-prml, what about the CTC loss weight? Is it 45, the same as in HierSpeech?

sh-lee-prml commented 7 months ago

@azraelkuan Yes, it is the same as in HierSpeech.

 "c_pho": 45.0,

This is the CTC loss weight (https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/config.json).
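Just to spell out how the two weights relate (a rough sketch only; the remaining terms and their exact form are my assumption, since the training code is not released yet):

    import torch

    # Placeholder loss values; in real training these come from the TTV forward pass.
    loss_w2v = torch.tensor(0.5)    # w2v reconstruction loss
    loss_ctc = torch.tensor(0.2)    # CTC loss of the phoneme predictor
    loss_other = torch.tensor(1.0)  # stand-in for the remaining terms (e.g. KL, duration)

    c_mel = 10.0  # "c_mel" in ttv_v1/config.json -> weight of the w2v reconstruction loss
    c_pho = 45.0  # "c_pho" in ttv_v1/config.json -> weight of the CTC loss

    loss_total = loss_other + c_mel * loss_w2v + c_pho * loss_ctc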

azraelkuan commented 7 months ago

@sh-lee-prml Thanks! Sorry, I forgot to check the config file!!!

sh-lee-prml commented 7 months ago

@azraelkuan My pleasure! If you have other questions, feel free to ask me :)

azraelkuan commented 7 months ago

@sh-lee-prml I have found that YAAPT is not accurate for many samples in LibriTTS; it gives an all-zero result...

sh-lee-prml commented 7 months ago

In my case, about 10 samples from the LibriTTS train subsets were extracted with all-zero results, and we removed these files from the training filelists.

More specifically, we use the default YAAPT settings for f0_min and f0_max:

‘f0_min’ - minimum pitch searched (default: 60 Hz)
‘f0_max’ - maximum pitch searched (default: 400 Hz)
    import numpy as np
    import amfm_decompy.basic_tools as basic
    import amfm_decompy.pYAAPT as pYAAPT

    def get_yaapt_f0(audio, rate=16000, interp=False):
        frame_length = 20.0
        to_pad = int(frame_length / 1000 * rate) // 2  # pad half a frame on each side

        f0s = []
        for y in audio.astype(np.float64):
            y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0)
            signal = basic.SignalObj(y_pad, rate)
            pitch = pYAAPT.yaapt(signal, **{'frame_length': frame_length, 'frame_space': 5.0,
                                            'nccf_thresh1': 0.25, 'tda_frame_length': 25.0})
            if interp:
                f0s += [pitch.samp_interp[None, None, :]]   # interpolated over unvoiced frames
            else:
                f0s += [pitch.samp_values[None, None, :]]   # 0 for unvoiced frames

        f0 = np.vstack(f0s)
        return f0

Here is an example of how we pad the audio and cache the F0:

    import numpy as np
    import torch
    import torchaudio

    audio, sr = torchaudio.load(audio_path)  # audio_path: path to the source wav

    # pad so that the length is a multiple of 320 samples (20 ms at 16 kHz)
    p = (audio.shape[-1] // 320 + 1) * 320 - audio.shape[-1]
    audio = torch.nn.functional.pad(audio, (0, p), mode='constant').data

    try:
        f0 = get_yaapt_f0(audio.numpy())
    except Exception:
        # fall back to an all-zero track if YAAPT fails (80 samples = 5 ms hop at 16 kHz)
        f0 = np.zeros((1, 1, audio.shape[-1] // 80))

    f0 = f0.astype(np.float32)
    f0 = f0.squeeze(0)
    f0 = torch.FloatTensor(f0)
    torch.save(f0, f0_filename)  # f0_filename: output path for the cached F0 tensor

and before it is fed to the model, we normalize it to log scale:

    f0 = torch.log(f0 + 1)

For expressive datasets, we increase f0_max to 800 for the training data. (Additionally, we will use 1100 for singing datasets.)
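For example, widening the search range just means passing the bounds explicitly (assuming pYAAPT's f0_min/f0_max keyword arguments; everything else matches get_yaapt_f0 above):

    import numpy as np
    import amfm_decompy.basic_tools as basic
    import amfm_decompy.pYAAPT as pYAAPT

    # Placeholder waveform; replace with a real 16 kHz expressive-speech sample.
    y = np.random.randn(2 * 16000).astype(np.float64)
    signal = basic.SignalObj(y, 16000)

    # Same settings as get_yaapt_f0, but with a wider pitch search range
    # (f0_max = 800 Hz for expressive data; 1100 Hz would follow the same pattern).
    pitch = pYAAPT.yaapt(signal, **{'frame_length': 20.0, 'frame_space': 5.0,
                                    'nccf_thresh1': 0.25, 'tda_frame_length': 25.0,
                                    'f0_min': 60.0, 'f0_max': 800.0})
    f0 = pitch.samp_values  # unvoiced frames are 0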

Meanwhile, YAAPT is too slow to extract F0 for all samples as we scale up the dataset, so we are now testing other pitch extraction methods such as Praat.
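Nothing is settled yet, but a Praat-based extractor could look roughly like this (a sketch using praat-parselmouth with the same 5 ms hop and 60-400 Hz range as above; not what we actually use):

    import numpy as np
    import parselmouth

    def get_praat_f0(wav, sr=16000, hop_s=0.005, fmin=60.0, fmax=400.0):
        # Praat (parselmouth) F0 extraction with YAAPT-like settings
        snd = parselmouth.Sound(wav.astype(np.float64), sampling_frequency=sr)
        pitch = snd.to_pitch(time_step=hop_s, pitch_floor=fmin, pitch_ceiling=fmax)
        f0 = pitch.selected_array['frequency']  # 0.0 for unvoiced frames, like samp_values
        return f0.astype(np.float32)

    # usage with a placeholder waveform
    f0 = get_praat_f0(np.random.randn(16000))
    print(f0.shape)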

Thanks!

iehppp2010 commented 7 months ago

I strongly recommend RMVPE for pitch extraction; it is more noise-robust and accurate.

azraelkuan commented 6 months ago

@sh-lee-prml hi, i found that you missed a layer norm in wav2vec processor, will it have any influence for result? https://github.com/huggingface/transformers/blob/371fb0b7dc1b533917e2f85b464a3ec9c74f28b9/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L227

sh-lee-prml commented 6 months ago

Hi @sh-lee-prml, I found that you omit a layer norm in the wav2vec processor. Will this have any influence on the results? https://github.com/huggingface/transformers/blob/371fb0b7dc1b533917e2f85b464a3ec9c74f28b9/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L227

Thanks for your suggestion.

Actually, we directly use the hidden representation from the 7th layer of MMS (wav2vec 2.0) without layer norm, and we have not experienced any problems so far.

However, this may improve performance on some downstream tasks. I have also observed that the hidden representations of some data have very large values, so normalization could improve the robustness of our model.
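For reference, the kind of extraction I mean looks roughly like this (a sketch only; 'facebook/mms-300m' as the MMS wav2vec 2.0 checkpoint and the exact layer indexing are assumptions here):

    import torch
    from transformers import Wav2Vec2Model

    # Sketch of 7th-layer extraction; hidden_states[0] is assumed to be the
    # CNN/embedding output, so index 7 is taken as the 7th transformer layer.
    model = Wav2Vec2Model.from_pretrained("facebook/mms-300m").eval()

    wav = torch.randn(1, 16000)  # placeholder: 1 s of 16 kHz audio, shape [batch, samples]

    with torch.no_grad():
        out = model(wav, output_hidden_states=True)

    w2v = out.hidden_states[7]           # no layer norm applied, as described above
    print(w2v.shape, w2v.abs().max())    # some utterances show very large magnitudes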

Thanks!

meriamOu commented 4 months ago

Hey, thank you for sharing the training curve. I am training the model on LibriTTS-960 using exactly your preprocessing and training files. However, my training curve doesn't show the same losses as yours: while you reached a w2v loss of 2.74 at 100k steps, mine is still at 6. The model also produces slurred speech.

[screenshots: training loss curves]

Do you know what the reason could be? Thank you.