vincentherrmann / pytorch-wavenet

An implementation of WaveNet with fast generation
MIT License
968 stars 225 forks

ZeroDivisionError in dilate() #6

Open prashanth-chandran opened 6 years ago

prashanth-chandran commented 6 years ago

Hi,

First of all, thank you so much for this implementation of WaveNet! :)

When trying to re-train the model, I ran into the following error.

```
Traceback (most recent call last):
  File "train_script.py", line 84, in <module>
    continue_training_at_step=0)
  File "/home/pytorch-wavenet/wavenet_training.py", line 68, in train
    output = self.model(x)
  File "/home/workspace/anaconda2/envs/deeplearn_2.7/lib/python2.7/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/pytorch-wavenet/wavenet_model.py", line 188, in forward
    dilation_func=self.wavenet_dilate)
  File "/home/pytorch-wavenet/wavenet_model.py", line 156, in wavenet
    s = dilate(x, 1, init_dilation=dilation)
  File "/home/pytorch-wavenet/wavenet_modules.py", line 25, in dilate
    new_l = int(np.ceil(l / dilation_factor) * dilation_factor)
ZeroDivisionError: long division or modulo by zero
```

I haven't changed any of the parameters in the training script.

The arguments given to the dilate function are as follows:

This happens when the dilate function is called from line 156 of wavenet_model.py. It seems strange to me that the output dilation is less than init_dilation.

Can you also kindly explain why we dilate when x.size(2) != 1?

Thanks for your help!

Prashanth

vincentherrmann commented 6 years ago

Hi Prashanth, thanks for telling me! Unfortunately I don't have time to look into it today, but hopefully I will tomorrow :)

prashanth-chandran commented 6 years ago

Hey Vincent, Sure. Thanks :) !

vincentherrmann commented 6 years ago

I can't reproduce this error, but it looks like you're using Python 2. If so, that is probably the problem: Python 2 uses classic (integer) division for ints, so 1/2 = 0, and then dividing by that result, as in 1/(1/2), is a division by zero. You need Python 3 (anaconda3) for this implementation to work. I'll add that to the README. A quick fix might be to cast init_dilation to float, but there are probably other things that won't work in Python 2.
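To make the division issue concrete, here is a minimal sketch that runs under Python 3 and imitates Python 2's integer `/` with `//`. The two helper functions are hypothetical and are not part of the repository. The line `dilation_factor = dilation / init_dilation` is an assumption inferred from the explanation above, and `math.ceil` stands in for the `np.ceil` seen in the traceback.

```python
import math


def padded_length_true_division(l, dilation, init_dilation):
    """Python 3 behaviour: `/` on ints yields a float."""
    # Assumed computation of the factor (not copied from the repository).
    dilation_factor = dilation / init_dilation  # 1 / 2 -> 0.5
    return int(math.ceil(l / dilation_factor) * dilation_factor)


def padded_length_floor_division(l, dilation, init_dilation):
    """Imitates Python 2, where `/` on two ints truncates like `//`."""
    dilation_factor = dilation // init_dilation  # 1 // 2 -> 0
    # Dividing by a factor of 0 raises ZeroDivisionError.
    return int(math.ceil(l // dilation_factor) * dilation_factor)


if __name__ == "__main__":
    # Same arguments as the failing call: dilate(x, 1, init_dilation=2)
    print(padded_length_true_division(10, 1, 2))
    try:
        padded_length_floor_division(10, 1, 2)
    except ZeroDivisionError as err:
        print("ZeroDivisionError:", err)
```

Under Python 2 itself, casting to float as suggested above, or adding `from __future__ import division` at the top of the module, would give the first behaviour, though other parts of the code may still need Python 3.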

The check for x.size(2) != 1 is just there as a precaution. The dilate function shifts indices from dimension 2 to dimension 0, so if there is just one index in dimension 2 we can't dilate. But this shouldn't happen anyway.
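As a rough illustration of what "shifting indices from dimension 2 to dimension 0" means, the sketch below uses plain Python lists for a single channel. It is a simplification and not the repository's `dilate` code: the function name `dilate_1d` is hypothetical, there is no zero padding, and the length is assumed to be a multiple of the dilation.

```python
def dilate_1d(samples, dilation):
    """Split one sequence into `dilation` interleaved sub-sequences.

    The time axis shrinks by a factor of `dilation`, while the number
    of sequences (the batch axis) grows by the same factor.
    """
    if len(samples) % dilation != 0:
        raise ValueError("length must be a multiple of the dilation")
    # Sub-sequence i takes samples i, i + dilation, i + 2 * dilation, ...
    return [samples[i::dilation] for i in range(dilation)]


if __name__ == "__main__":
    print(dilate_1d([0, 1, 2, 3, 4, 5], 2))  # [[0, 2, 4], [1, 3, 5]]
```

With only one sample on the time axis there is nothing left to redistribute across the batch axis, which is the case the precautionary check guards against.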