ibab / tensorflow-wavenet

A TensorFlow implementation of DeepMind's WaveNet paper
MIT License

Model fails to learn sine wave #141

Open atgambardella opened 7 years ago

atgambardella commented 7 years ago

With the default parameters (simply cloning the repo and training), the model could not fit a sine wave (I was intentionally trying to overfit on one of the simplest examples as a sanity check).

The model seems to converge at a loss of 0.001~0.002. The input and output are here: https://soundcloud.com/andrew-gambardella-769364795/sets/sine

Notice that the input sine wave sounds a lot cleaner than the output.

I think it would be useful to get this working before moving on to more complicated examples such as voice.

Is this difference because of the quantization? And is it significant enough to affect the real audio samples?

atgambardella commented 7 years ago

Here are the two waveforms plotted (after appropriate rescaling etc.): http://imgur.com/a/sej9i

Notice that around the peaks, the generated waveform isn't smooth, which I think is what's causing the unclean sound.

Zeta36 commented 7 years ago

@atgambardella, try replacing this line in generate.py (line 182):

sample = np.random.choice(np.arange(quantization_channels), p=prediction)

with this:

sample = np.argmax(prediction)

Then let me know how it goes.

If that doesn't work, also replace this line (line 148):

waveform = np.random.randint(quantization_channels, size=(1,)).tolist()

with the first float value of the sine wave. For example:

waveform = [.5]

I used this method with an overfitted model and was able to replicate the original sample. I did it, for example, in #125 and #129 (there I used raw image input, but it's the same idea).

Regards.

atgambardella commented 7 years ago

@Zeta36 I tried your first suggestion - replacing the sample with argmax in the generator - and I'm still getting the same non-smooth behavior at the peaks.

Your second suggestion would align them, but seeding the generator wouldn't change that behavior at the peaks, correct?

My input sine wave is sampled at 44.1kHz -- could this be part of the issue?

Zeta36 commented 7 years ago

In my tests, the only way I could get a perfect generation from an overfitted model was to use argmax and seed the first value of the waveform at the same time.

Try to find the first value of that sine wave and seed the waveform with it as I described. I think you will get a smooth output that way.

lemonzi commented 7 years ago

The default settings now use a larger net and train for longer. The former ones were for a toy network. Can you test again to see if it performs better?

atgambardella commented 7 years ago

@lemonzi I tried again with the bigger net and got exactly the same results (but with a lower loss). It might be an issue with quantization plus encoding the input as just a one-hot vector. I'm going to run a few experiments with potential fixes and let you know how they go.

Zeta36 commented 7 years ago

@atgambardella, please try one thing: use the whole file at once during training. I mean, set no sample_size for cutting the raw audio file. I suspect the way the audio reader cuts the raw data and appends it to the buffer may be broken. There is more information in one of my comments in #112. You could also try not using fast generation, or disabling the silence-trimming function.
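For reference, a training command along those lines might look like this (just a sketch: the flag names are assumed from the repo's train.py of that era, where a sample_size of 0 keeps each file whole and a silence_threshold of 0 skips trimming, and ./sine_corpus is a hypothetical data directory):

python train.py --data_dir=./sine_corpus --sample_size=0 --silence_threshold=0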

I'm doing some tests too.

jyegerlehner commented 7 years ago

@atgambardella Another thing you might try in your experiments is "scalar_input": true in the wavenet_params.json. Then there won't be quantization of the input. You might try rmsprop instead of adam, too.
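For example, the relevant fragment of wavenet_params.json might look like this (a sketch; scalar input also relies on the initial_filter_width field, and the other fields are elided here):

{
    ...
    "scalar_input": true,
    "initial_filter_width": 32
}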

@Zeta36 So does the problem you identified there apply when we are not doing global conditioning? I had understood you to say that the global id and audio chunk get out of sync. But atgambardella isn't doing gc.

Zeta36 commented 7 years ago

@jyegerlehner, but since @atgambardella is using just one wav file, we still have the problem that the last piece of raw data in the chunk is ignored, and that may be the cause.

Zeta36 commented 7 years ago

I tried all day with this and I couldn't generate a smooth, clean output from the sine wave :(. It's strange, because I didn't have this problem with the text or image versions of the WaveNet model. With the text version (#117), I could easily overfit a small text dataset and later generate an exact (perfect) copy of that text. With the image version (#129), I could also generate a perfect image after training on a small dataset of images.

So I guess the problem may be in the librosa library when reading the file, or in the silence function, or maybe in the mu-law function (the image and text versions don't use it), or something like that. I can't think of another reason why it would be impossible to overfit and then perfectly regenerate one small wav file.

I'll keep trying.

Regards.

atgambardella commented 7 years ago

I was actually using the whole file, but it shouldn't matter anyway, because I had two periods of the sine wave (the model should have all the information it needs after seeing just half of the file).

What we (at my company) were thinking of doing is using a two-hot vector encoding instead of one-hot. So instead of encoding a sample as [0, ..., 1, 0, ..., 0], it would be something like [0, ..., a, (1-a), ..., 0]: when the true value of the wav file falls between two quantization levels, we give the appropriate amount of weight to each of the two surrounding levels instead of pushing it all into one. We would, of course, also predict these two-hot targets (the cross-entropy loss should take care of this during training; for generation we might need to change something). I haven't had the opportunity to try this yet because I've been busy.
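A minimal sketch of the idea (illustrative code only; it works directly on values in [-1, 1] and skips the mu-law companding for brevity):

import numpy as np

def two_hot(value, channels=256):
    # Fractional position of the value among the quantization levels.
    pos = (value + 1) / 2 * (channels - 1)
    lo = int(np.floor(pos))
    hi = min(lo + 1, channels - 1)
    a = pos - lo
    target = np.zeros(channels)
    target[lo] += 1 - a  # weight for the lower level
    target[hi] += a      # remaining weight for the upper level
    return target

print(two_hot(0.0006).nonzero())  # -> (array([127, 128]),)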

I haven't tried scalar input yet.

Zeta36 commented 7 years ago

@atgambardella, I got it!!

The problem is in the mu-law encode/decode method. I ran this test:

I printed the raw float data in [-1, 1] and the mu-law encoded data in [0, 255] for sine.wav, then trained the model until I got a loss of ~0.001. Then I generated the output and printed the [0, 255] data predicted by the model (with np.argmax) and the mu-law decoded float data in [-1, 1].

The model's predicted [0, 255] data exactly matched the original [0, 255] encoded data (the model learned to predict perfectly in the [0, 255] range), but the float data after the mu-law decode was not the same as the original raw float data.

So the problem isn't in the model or the training process, but in converting the [0, 255] predicted output back to [-1, 1] raw audio data.

In fact, the output is quite similar (with a difference of around ~0.001 between the real and the converted float values) but not perfect, and that is the cause of the noise when we save the output wav file.
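For reference, this roundtrip can be checked without the model at all; here is a numpy sketch of the repo's mu-law formulas (the 440 Hz test tone is an assumption):

import numpy as np

mu = 255
audio = np.sin(2 * np.pi * 440 * np.arange(44100) / 44100)  # hypothetical test tone

# Encode [-1, 1] -> [0, 255], as in mu_law_encode.
magnitude = np.log1p(mu * np.abs(audio)) / np.log1p(mu)
encoded = ((np.sign(audio) * magnitude + 1) / 2 * mu + 0.5).astype(np.int64)

# Decode [0, 255] -> [-1, 1], as in mu_law_decode.
signal = 2.0 * (encoded / mu) - 1.0
decoded = np.sign(signal) * ((1.0 + mu) ** np.abs(signal) - 1.0) / mu

# Nonzero even for a perfectly trained model: this is pure quantization error.
print(np.max(np.abs(audio - decoded)))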

Regards.

P.D. To get the ~0.001 loss (and overfit well) I had to set EPSILON = 0.0, and when I generate I use np.argmax(prediction) and initialize the waveform with the first quantized real value of the wav I'm using; in my case, the integer 139.

Zeta36 commented 7 years ago

OK, there is no error in the mu-law encode/decode functions. We just need to change the parameter:

"quantization_channels": 256,

to:

"quantization_channels": 512,

and the output wav will be almost perfect.

The problem is that with 256 quantization levels we lose too much information in the encoding process to be able to later replicate the exact original source wave.

@atgambardella, if you set quantization_channels to 512 and overfit to a loss of ~0.001, you will generate (almost) exactly the original wav sound (but you must use np.argmax() and seed the waveform with the first real value of the original raw data, in the quantized range).

@jyegerlehner, would it be a good idea to set "quantization_channels": 512 in the master wavenet_params.json? I know the DeepMind paper talks about 256, but I think our results would be much better with 512, and the computational cost doesn't seem that high.

Regards.

P.D. Take this into account too: #155

jyegerlehner commented 7 years ago

@Zeta36 I'd suggest submitting a PR so everyone who has an opinion can discuss it.

Zeta36 commented 7 years ago

The PR is in #155, but the question is simple. The general equation for rescaling between ranges is:

NewValue = ((OldValue - OldMin) * (NewMax - NewMin)) / (OldMax - OldMin) + NewMin

So, for the decoding phase, [0, 255] to [-1, 1], we have to do:

NewValue = ((OldValue - 0) * (1 - (-1))) / (255 - 0) + (-1) = (OldValue * 2) / 255 - 1

And we are doing this correctly right now: signal = 2 * (casted / mu) - 1

For the encoding phase, [-1, 1] to [0, 255]:

NewValue = ((OldValue - (-1)) * (255 - 0)) / (1 - (-1)) + 0 = ((OldValue + 1) * 255) / 2

But we are doing (I don't know why): (signal + 1) / 2 * mu + 0.5. This is wrong; the right way is just: ((signal + 1) * mu) / 2

The code right now is encoding and decoding in different ranges.

Regards, Samu.

jyegerlehner commented 7 years ago

@Zeta36 I was referring to your question about 512 quantization levels instead of 256; PR #155 does not change that. I think making the number of quantization levels configurable is a good idea for those who need high fidelity, but it does increase the size of the tensors, which uses more memory and runs a bit slower. The DeepMind paper said 256 was good enough for speech.

jyegerlehner commented 7 years ago

@atgambardella TL;DR: If your sine wave spans -1.0 to 1.0, I suggest multiplying its amplitude by 0.2 or so and re-running your test.

Rationale: The mu-law encoding is inherently low-resolution at the extreme ends of its range, and the range is -1 to 1. It can only represent 256 values (with the current quantization_channels value), and the resolution is concentrated in the middle of the range, around zero. The four representable values nearest 1.0 are 1.0, 0.957, 0.916, and 0.877: those are pretty big steps between values. By contrast, in the middle near zero we get -6.44772721e-04, -4.50432097e-04, -2.64362287e-04, -8.62116940e-05: much higher resolution.

So a sine wave that goes all the way to 1 would necessarily have a step change to 0.957 as it decreases. In the frequency domain this step change manifests as high-frequency energy (a pop or hiss). By staying in a more moderate range where the encoding has higher resolution, the quantization steps are much smaller in magnitude and thus induce less audible noise.

I'm not experienced with audio, but I don't think well-produced audio normally gets too close to the extremes of the amplitude range; if it ever goes over, that's what's called "clipping", right?
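To see this concretely, one can decode all 256 levels back to floats (a numpy sketch of the repo's mu_law_decode formula):

import numpy as np

mu = 255
levels = np.arange(256)
signal = 2.0 * (levels / mu) - 1.0  # map [0, 255] onto [-1, 1]
decoded = np.sign(signal) * ((1.0 + mu) ** np.abs(signal) - 1.0) / mu

print(decoded[252:])     # near 1.0: roughly 0.877, 0.916, 0.957, 1.0
print(decoded[124:128])  # near zero: steps of only ~1e-4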

lemonzi commented 7 years ago

Audio does get close to the extreme, but not very frequently; the distribution of values is clustered around 0. Actually, most noise models assume that this distribution is gaussian and centered at zero. If it goes over the extreme there's clipping indeed, although all professional systems include limiters and compressors that implement a "softmax" that makes sure this doesn't happen because of spurious peaks.

AFAIK though these logarithmic compressors like mu-law that work with low bit depths are more common in legacy speech-related systems, and nowadays mp3 and the like use linearly-encoded PCM followed by an adaptive quantization in the frequency domain.


jyegerlehner commented 7 years ago

Thanks for that info lemonzi.

most noise models assume that this distribution is gaussian and centered at zero.

That makes sense to me. Note, though, that the values of a sine wave are distributed such that most of the probability mass is at the extremes, where the wave changes most slowly; the middle around zero has the least (which is where a Gaussian distribution would have the most). So I think my suggestion to scale the amplitude of the sine wave into a region where the mu-law has higher resolution is still valid.
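A quick numpy check of that claim (a sketch; ten equal-width bins over [-1, 1]):

import numpy as np

t = np.linspace(0.0, 2.0 * np.pi, 100000, endpoint=False)
counts, _ = np.histogram(np.sin(t), bins=10, range=(-1.0, 1.0))
print(counts)  # the first and last bins dominate: the mass piles up at the extremes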

atgambardella commented 7 years ago

Wouldn't scaling the sine wave also change how it sounds, though? Plus, anything in an actual audio signal that deviates far from 0 would be hurt significantly if we don't have an automatic fix for this.

What do you think of the two-hot vector idea? Doing this would allow us to have the benefits of quantization, and when generating would allow us to represent values with arbitrary resolution, even close to -1 and 1.

Other than that I think some sort of upsampling or smoothing technique would be preferable to scaling the sine wave, no?

most noise models assume that this distribution is gaussian and centered at zero.

Are we talking about natural-sounding audio (instruments, voice, etc) or noise as in literal noise (white, pink, etc)?

nakosung commented 7 years ago

Why don't we use a soft one-hot encoding? We could have (0, 0, 0.2, 0.8, 0) instead of (0, 0, 0, 1, 0). With a soft one-hot encoding we can keep the residuals produced by quantization (mu-law encode to 1024 levels, then downsample to 256 levels).
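A sketch of this variant, extending the two-hot snippet above (illustrative code: mu-law quantize at 1024 levels, then spread each fine level over the two nearest of the 256 coarse levels):

import numpy as np

def mu_law_quantize(audio, channels):
    # Numpy version of the repo's mu_law_encode; audio is a float array in [-1, 1].
    mu = channels - 1
    magnitude = np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    return ((signal + 1) / 2 * mu + 0.5).astype(np.int64)

def soft_one_hot(audio, fine=1024, coarse=256):
    q = mu_law_quantize(audio, fine)     # fine levels in [0, 1023]
    pos = q * (coarse - 1) / (fine - 1)  # fractional coarse level
    lo = np.floor(pos).astype(np.int64)
    hi = np.minimum(lo + 1, coarse - 1)
    frac = pos - lo
    target = np.zeros((len(audio), coarse))
    rows = np.arange(len(audio))
    target[rows, lo] += 1 - frac  # weight for the lower coarse level
    target[rows, hi] += frac      # quantization residual goes to the upper one
    return target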

jyegerlehner commented 7 years ago

Wouldn't scaling the sine wave also change how it sounds though?

Only how loud it is. I imagine the pitch might be the same.

lemonzi commented 7 years ago

@nakosung That would be interesting to try. Using pure one-hot encoding would be more efficient if we were using sparse_softmax_cross_entropy_with_logits, but we are not using it so there would be no performance loss.