Learned attention seems to be almost monotonic. Not working for incremental forward path yet, though.
encoder = Encoder(
    n_vocab, embed_dim, padding_idx=padding_idx,
    n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim,
    dropout=dropout,
    convolutions=((64, 5),) * 10)
decoder = Decoder(
    embed_dim, in_dim=mel_dim, r=r, padding_idx=padding_idx,
    n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim,
    dropout=dropout,
    convolutions=((128, 5),) * 5,
    attention=[False, False, False, False, True])
converter = Converter(
    in_dim=mel_dim, out_dim=linear_dim, dropout=dropout,
    convolutions=((256, 5),) * 5)
model = DeepVoice3(
    encoder, decoder, converter, padding_idx=padding_idx,
    mel_dim=mel_dim, linear_dim=linear_dim,
    n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim)
Progress: still cannot get correct alignment with incremental forward path. Results are below
Ground truth (mel-spectrogram)
Predicted mel-spectrogram (off-line, ground truth fed at every time step)
Predicted alignment (off-line)
Predicted mel-spectrogram (on-line, feed ground truth at the first time step)
Predicted alignment (on-line)
Feeding zeros to the first time step doesn't work either :(
EDIT: There was a serious bug, fixed https://github.com/r9y9/deepvoice3_pytorch/commit/a0b36a485bfe99ebf258c23057e1d8a602e9e7b2
Still not good, but it seems to be starting to work now.
Predicted mel-spectrogram (off-line, ground truth fed at every time step)
Predicted mel-spectrogram (on-line, start from zero decoder state, forced monotonic attention)
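For context, "forced monotonic attention" here just means constraining the attention at inference time to stay within a small window around the previously attended encoder position. A minimal sketch of that idea, with illustrative names and window sizes rather than the repository's actual API:
import torch
import torch.nn.functional as F

def force_monotonic(scores, last_attended, window_backward=1, window_ahead=3):
    # scores: (B, T_enc) unnormalized attention scores for the current decoder step
    # last_attended: (B,) encoder position attended at the previous step
    B, T_enc = scores.shape
    positions = torch.arange(T_enc).unsqueeze(0).expand(B, T_enc)
    lo = (last_attended - window_backward).unsqueeze(1)
    hi = (last_attended + window_ahead).unsqueeze(1)
    mask = (positions < lo) | (positions > hi)
    attn = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
    return attn, attn.argmax(dim=-1)  # attention weights and new last_attended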
Okay, now my implementation can generate sounds like speech.
def build_deepvoice3(n_vocab, embed_dim=256, mel_dim=80, linear_dim=4096, r=5,
                     n_speakers=1, speaker_embed_dim=16, padding_idx=None,
                     dropout=(1 - 0.95)):
    encoder = Encoder(
        n_vocab, embed_dim, padding_idx=padding_idx,
        n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim,
        dropout=dropout,
        convolutions=((128, 7),) * 7)
    decoder_hidden_dim = 128
    decoder = Decoder(
        embed_dim, in_dim=mel_dim, r=r, padding_idx=padding_idx,
        n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim,
        dropout=dropout,
        convolutions=((decoder_hidden_dim, 7),) * 5,
        attention=[True, False, False, False, True],
        force_monotonic_attention=[True, False, False, False, False])
    converter = Converter(
        in_dim=decoder_hidden_dim // r, out_dim=linear_dim, dropout=dropout,
        convolutions=((256, 7),) * 7)
    model = DeepVoice3(
        encoder, decoder, converter, padding_idx=padding_idx,
        mel_dim=mel_dim, linear_dim=linear_dim,
        n_speakers=n_speakers, speaker_embed_dim=speaker_embed_dim)
    return model
Added dilated convolution support. Seems effective as reported in https://arxiv.org/pdf/1710.08969.pdf.
The best quality speech sample I have gotten so far. Still not as good as Tacotron :(
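For anyone curious, a dilated causal convolution of the kind meant here can be sketched with plain PyTorch modules; this is purely illustrative, not the module actually used in the repository:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalDilatedConv1d(nn.Module):
    # Left-pad so the output at time t only depends on inputs up to time t.
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super(CausalDilatedConv1d, self).__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)

    def forward(self, x):  # x: (B, channels, T)
        return self.conv(F.pad(x, (self.pad, 0)))

# Stacking layers with growing dilation (1, 3, 9, 27) makes the receptive field
# grow exponentially with depth, which should help with long time dependencies.
stack = nn.Sequential(*[CausalDilatedConv1d(128, 128, kernel_size=3, dilation=3 ** i)
                        for i in range(4)])
out = stack(torch.randn(2, 128, 100))  # -> (2, 128, 100)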
Some notes:
I've written a short README on how to train a model and how to synthesize audio signals. I would appreciate it if anybody could try it and give me feedback.
https://www.dropbox.com/sh/uq4tsfptxt0y17l/AADBL4LsPJRP2PjAAJRSH5eta?dl=0 Added audio samples. Will update constantly if I can get better samples.
Tried again to visualize the mel-spectrogram generated by the model. Compared to https://github.com/r9y9/deepvoice3_pytorch/issues/1#issuecomment-341908167, it's slowly getting better.
From top to bottom, ground truth, predicted mel-spectrogram and predicted alignment.
It seems the model has trouble learning long-term dependencies. I'm going to stack more layers, use a larger kernel_size and a larger dilation factor, and see if it works.
@r9y9 In their paper, did they provide the number of layers, kernel size and dilation factor? If not, it would be useful to e-mail the authors!
Yes, they provided hyperparameters: number of layers, kernel sizes, etc. I think I tried almost the same hyperparameters (unless I misunderstood something), but in my experience it didn't work for the LJSpeech dataset. I suspect the reason is that the speech samples in the LJSpeech dataset have reverberation, which makes it difficult to train a high quality model. That's why I'm trying richer models, e.g., increasing the number of layers.
After looking at the mel-spectrogram, I feel like the model has a hard time learning the shifts that happen in the spectra. I.e., it seems to generate straight "lines" instead of the more curvy lines in the ground truth.
Perhaps the quality would improve if the shifts could be extracted and provided as an input feature? (I'm thinking it would basically just be an amplitude vector)
However, having to provide this would kinda destroy the TTS side of things. So it is probably not practical, I'm just interested if it would manage to improve the quality at all.
Sorry if this is off topic.
@DarkDefender Thank you for your comment! The straight lines are what I want to improve. Note that if I give the ground truth to the decoder at every time frame, I can get curvy lines. So the auto-regressive process during decoding is causing the artifacts.
Hi @r9y9, nice job! Could you upload the training curve? I'm also working on implementing DeepVoice3, but with no luck yet. I think I need to compare yours and mine. Any tips?
@Kyubyong Sure, I will upload logs when I finish my current experiment. Some random tips I have are:
Thanks @r9y9. I'm trying with another dataset that is much shorter than LJ. And strangely, when I applied positional encoding, it didn't work. So I replaced it with positional embedding, and the network started to learn, but not perfectly.
@Kyubyong Here is a screenshot of my tensorboard (during working on #3, so not exactly of deepvoice3).
1. linear_l1_loss: L1 loss for the linear-domain log magnitude spectrogram
2. mel_l1_loss: L1 loss for the mel-spectrogram
3. done_loss: binary cross entropy loss for the done flag
4. attn_loss, mel_binary_div_loss and linear_binary_div_loss come from https://arxiv.org/abs/1710.08969. However, without those losses I can get similar loss curves.
After training 20 hours, I think I get good speech samples: https://www.dropbox.com/sh/jkkjqh6pawkg6sd/AAD5--NRm4rRgHo91sHYMAvGa?dl=0.
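For reference, the guided attention idea behind attn_loss penalizes attention mass far from the diagonal so that alignments become near-monotonic. A sketch of the formulation in that paper, not necessarily the exact code used here:
import torch

def guided_attention_loss(attention, g=0.2):
    # attention: (B, T_dec, T_enc) attention weights
    # Soft diagonal mask: W[n, t] = 1 - exp(-(n / N - t / T)^2 / (2 g^2))
    B, T_dec, T_enc = attention.shape
    n = torch.arange(T_dec, dtype=torch.float32).unsqueeze(1) / T_dec  # (T_dec, 1)
    t = torch.arange(T_enc, dtype=torch.float32).unsqueeze(0) / T_enc  # (1, T_enc)
    W = 1.0 - torch.exp(-((n - t) ** 2) / (2.0 * g ** 2))
    return (attention * W.unsqueeze(0)).mean()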
Amazing. So from the beginning, you could get monotonic alignments, as can be seen above on this page, right? Is it thanks to the shared initial weights of the key and query projections? If that's the case, could you point out its implementation in your code?
The paper says,
"We initialize the fully-connected layer weights used to compute hidden attention vectors to the same values for the query projection and the key projection. "
@Kyubyong,
Amazing. So from the beginning, you could get monotonic alignments as can be seen above in this page, right?
Yes.
Is it thanks to the shared initial weights of the key and query projections? If that's the case, could you point out its implementation in your code?
I haven't tried the same weight initialization for the attention because I thought it wasn't that important. Attention works without it. Will try shared initial weights next, thanks!
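For completeness, the shared initialization mentioned in the quote can be done by copying one projection's initial parameters into the other before training starts. A minimal sketch with illustrative layer names:
import torch
import torch.nn as nn

embed_dim = 256
query_projection = nn.Linear(embed_dim, embed_dim)
key_projection = nn.Linear(embed_dim, embed_dim)

# Give the key projection the same initial values as the query projection.
with torch.no_grad():
    key_projection.weight.copy_(query_projection.weight)
    key_projection.bias.copy_(query_projection.bias)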
I think that your new samples sound quite good! To me, they sound more "clear" (not as muffled) than the Tacotron samples posted on keithito's Tacotron page. However, while yours are clearer, they do have more of the "heavy data compression"/vibration going on in them.
BTW, are you using a post-processing network like the one in Tacotron? I'm asking because your samples remind me a bit of the "Tajima Airport serves Toyooka." sample on: https://google.github.io/tacotron/publications/tacotron/index.html
If you do not, then perhaps the vibration effect could be eliminated with a post-processing network?
@DarkDefender Thank you for your feedback! As for the "heavy data compression", it might be because I used half the dimension for the linear spectrogram, i.e., FFT size 1024 (hparams.py#L73) instead of 2048.
BTW, are you using a post processing network as the one in tacotron?
Yes.
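For anyone following along: the number of linear spectrogram bins is tied to the FFT size as n_fft // 2 + 1, so FFT size 1024 gives 513 bins versus 1025 bins for 2048. A quick illustration, using librosa purely for demonstration:
import numpy as np
import librosa

y = np.random.randn(22050).astype(np.float32)  # one second of dummy audio at 22.05 kHz
S_1024 = np.abs(librosa.stft(y, n_fft=1024))   # shape (513, n_frames)
S_2048 = np.abs(librosa.stft(y, n_fft=2048))   # shape (1025, n_frames), double the frequency resolution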
Yes, the FFT size might be it! I've recently taken an audio compression course and, to me at least, it sounds like some of the examples where we transformed the audio signal with the FFT to the frequency domain and then quantized the frequency coefficients a bit too much.
But I also guess that it might be that the network didn't learn how to put the audio signal together without producing these sound artifacts.
Sadly I do not have an NVIDIA GPU, so I can't test upping the FFT size and see if the quality goes up...
Edit: BTW, thank you for answering my silly questions @r9y9
Edit2: Actually, if the FFT size in this case only refers to the sample chunk length, then it might not improve that much by increasing it. You will only get more coefficients to work with in the frequency domain... If the neural network sound artifacts indeed come from disconnects that appear between sample chunks, then upping the chunk size will only space them out more (not eliminate them).
@DarkDefender I really appreciate your comments! Yeah, I hope the network can produce good speech samples even with a small FFT size. I will do more experiments with larger FFT sizes.
One thing I found today is that the L1 loss for the spectrogram decreases more quickly when using the decoder internal states as postnet inputs, rather than the mel-spectrogram. Hopefully this improves speech quality a bit.
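To make the difference concrete: instead of feeding the predicted mel frames to the converter, each decoder step's hidden state is unfolded across its r output frames, so the converter's input dimension becomes decoder_hidden_dim // r, as in the build_deepvoice3 snippet above. A rough sketch with made-up shapes, assuming decoder_hidden_dim is divisible by r:
import torch

B, r = 2, 4
T_dec, decoder_hidden_dim, mel_dim = 25, 256, 80
T = T_dec * r  # number of mel frames (r frames per decoder step)

mel_outputs = torch.randn(B, T, mel_dim)
decoder_states = torch.randn(B, T_dec, decoder_hidden_dim)

# Option A: converter input is the predicted mel-spectrogram (in_dim = mel_dim)
converter_in = mel_outputs

# Option B: converter input is the decoder internal state, unfolded so that each
# of the r output frames per step gets decoder_hidden_dim // r features
converter_in = decoder_states.view(B, T, decoder_hidden_dim // r)  # (2, 100, 64)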
@r9y9 I'm guessing that the new 380,000-step checkpoint samples are with the new postnet inputs? I feel like the speech samples have improved a bit in quality regardless.
The only thing that got worse was the 3_checkpoint_step000380000.wav sample. I'm guessing that the pause between "repeal" and "replace" perhaps might go away with more training?
Did changing FFT size do anything noticeable BTW?
@r9y9 you mentioned using decoder internal states as postnet inputs. Is this something described in the paper?
@rafaelvalle Yes, it's mentioned in DeepVoice3, not in https://arxiv.org/abs/1710.08969 though.
@DarkDefender New samples are from the model of https://arxiv.org/abs/1710.08969, using decoder internal states as postnet inputs (https://github.com/r9y9/deepvoice3_pytorch/commit/22a674803f2994af2b818635a0501e4417834936). L1 loss for the spectrogram decreases more quickly.
I'm guessing that the pause between repeal and replace perhaps might go away with more training?
I'm guessing that too.
Did changing FFT size do anything noticeable BTW?
Haven't tried it yet:(
https://github.com/r9y9/deepvoice3_pytorch#pretrained-models
Pre-trained models are now up and ready.
Finally, I think I get reasonable (though not very good) speech quality with a multi-speaker model trained on VCTK (108 speakers). Speech samples for three different speakers are attached:
p225
p225
p236
WIP at #10
Thanks for the update! I really appreciate that you give us small updates every now and then.
As you said, it seems like they are able to read the text correctly but the quality of the sound is not that good. But it is nice results none the less.
Yeah, the result is not very good, but since I had a very hard time getting the multi-speaker model to actually work on VCTK, this is big progress for me :)
Amazing work!
I've tried training with Nyanko on the current commit (8afb609) with the LJSpeech data set and default Hparams, but get significantly worse performance than the examples you have posted (for an earlier commit).
4357976 was a huge commit, so I'm going through to see if anything might have negatively affected single voice performance. I thought I'd also check to see if you have any suggestions.
Here are eval examples from 325000 and 665000 training steps. https://www.dropbox.com/sh/rarwoxl3u0f5qkn/AAALH_XayWwEuBoN5bz1P3BIa?dl=0
@patHutchings
I've tried training with Nyanko on the current commit (8afb609) with the LJSpeech data set and default Hparams, but get significantly worse performance than the examples you have posted (for an earlier commit).
Sorry about that. I might have introduced a bug in #6 (maybe https://github.com/r9y9/deepvoice3_pytorch/pull/6/commits/4d9bc6f6631569e81ecce4c37dcac6111975b869 doesn't work for Nyanko architecture). I will take a look soon.
I think I have reverted the (possibly affected) changes. It should give the same results as I posted previously. Will try to train new models as soon as possible.
OMG, I introduced a very stupid bug! Should be fixed by https://github.com/r9y9/deepvoice3_pytorch/commit/a934acd870bd09ad7e9240953f78e86bf6378451
@r9y9 Yep that will do it. Sorry I should have noticed that myself. I'll train again (and also with another, custom dataset) and share models.
OK, trained for 150k steps; it seems to work as before. Dropout was essential to make it work.
@r9y9, I downloaded the pretrained model, but I get the following error when I run it; doesn't it support CPU inference?
python synthesis.py pretrainedmodel.pth test_list.txt output/
Command line args: {'--checkpoint-postnet': None, '--checkpoint-seq2seq': None, '--file-name-suffix': '', '--help': False, '--hparams': '', '--max-decoder-steps': '500', '--output-html': False, '--replace_pronunciation_prob': '0.0', '--speaker_id': None, '<checkpoint>': 'pretrainedmodel.pth', '<dst_dir>': 'output/', '<text_list_file>': 'test_list.txt'}
Traceback (most recent call last):
  File "synthesis.py", line 124, in <module>
    checkpoint = torch.load(checkpoint_path)
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 231, in load
    return _load(f, map_location, pickle_module)
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 379, in _load
    result = unpickler.load()
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 350, in persistent_load
    data_type(size), location)
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 85, in default_restore_location
    result = fn(storage, location)
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 67, in _cuda_deserialize
    return obj.cuda(device_id)
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/_utils.py", line 58, in _cuda
    with torch.cuda.device(device):
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/cuda/__init__.py", line 125, in __enter__
    _lazy_init()
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/cuda/__init__.py", line 84, in _lazy_init
    _check_driver()
  File "/home/saurabh/anaconda3/lib/python3.6/site-packages/torch/cuda/__init__.py", line 51, in _check_driver
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
@saurabhvyas Sorry, CPU inference is not currently supported yet.
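One common workaround for loading a GPU-saved checkpoint on a CPU-only machine is to remap storages at load time, although the rest of the code may still assume CUDA elsewhere:
import torch

# Remap CUDA tensors in the checkpoint to CPU while loading.
checkpoint = torch.load("pretrainedmodel.pth", map_location=lambda storage, loc: storage)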
@r9y9 Okay, no problem, but good project :)
Added a brief guide for speaker adaptation and training a multi-speaker model. See https://github.com/r9y9/deepvoice3_pytorch#advanced-usage if interested.
Audio samples are now available at the github page: https://r9y9.github.io/deepvoice3_pytorch/
I finished everything I wanted to do initially. I will close this issue and create separate ones for specific issues.
Hi @r9y9! I just want to know how you generate the GIF time lapse of the alignment?
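Not speaking for the author, but one simple way to build such a GIF is to save the alignment matrix at regular checkpoints during training and stitch the saved plots together, e.g. with matplotlib and imageio. A sketch, assuming you already have a list of 2-D alignment arrays:
import io
import imageio
import matplotlib
matplotlib.use("Agg")  # render without a display
import matplotlib.pyplot as plt

def save_alignment_gif(alignments, path="alignment.gif"):
    # alignments: list of 2-D arrays (decoder steps x encoder steps), one per checkpoint
    frames = []
    for step, A in enumerate(alignments):
        fig, ax = plt.subplots()
        ax.imshow(A.T, aspect="auto", origin="lower", interpolation="none")
        ax.set_title("checkpoint %d" % step)
        ax.set_xlabel("Decoder timestep")
        ax.set_ylabel("Encoder timestep")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        plt.close(fig)
        buf.seek(0)
        frames.append(imageio.imread(buf))
    imageio.mimsave(path, frames)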
Single speaker model
Data: https://keithito.com/LJ-Speech-Dataset/
Audio samples (jp)
Multi-speaker model
Data: VCTK
Misc
[x] Char and phoneme mixed inputs
[x] Japanese text-processing frontend
[x] Try Japanese TTS using https://sites.google.com/site/shinnosuketakamichi/publication/jsut
[x] Implement dilated convolution
[x] preprocessor for jsut
[x] Integrate https://github.com/lanpa/tensorboard-pytorch and log images and audio samples
[x] Add instructions on how to train models (en/jp)
[x] Rewrite audio module for better spectrogram representation. Replace griffin lim with https://github.com/Jonathan-LeRoux/lws.
[x] Create github pages with speech samples
From https://arxiv.org/abs/1710.08969
[x] Guided attention
[x] Downsample mel-spectrogram / upsample converter
[x] Binary divergence
[x] ~Separate training for encoder+decoder and converter~
Notes (to be moved to README.md)