r9y9 / wavenet_vocoder

WaveNet vocoder
https://r9y9.github.io/wavenet_vocoder/

Could someone explain how these dimensions of Wavenet layers work, as they appear to mismatch? #200

Closed: jjoe1 closed this issue 4 years ago

jjoe1 commented 4 years ago

@r9y9 I was trying to understand the wavenet-vocoder implementation, and some of the layer dimensions don't seem to match what I understood from the WaveNet paper.

Could you shed some light on these dimensions, as maybe I'm missing something?

1) The first_conv layer is shown as Conv1d(1, 512). Isn't the input to this layer the mel-spectrogram frames, i.e. 80 float values per frame for up to 2,500 frames (2,500 being the maximum number of decoder steps, defined as max_iters in hparams.py)? If so, shouldn't in_channels for this Conv1d layer be 80 instead of 1? And why is out_channels 512?

  (first_conv): Conv1d(1, 512, kernel_size=(1,), stride=(1,))

2) Isn't the input to the WaveNet a sequence of mel frames of 80 floats each? Why is input_type set to "raw"?

3) Why is in_channels 80 for (conv1x1c) and 256 for (conv1x1_out) below? There doesn't seem to be anything generating 256-d inputs for (conv1x1_out). What exactly are their inputs (e.g. mel-spectrogram frames)? Is the WaveNet vocoder generating just 1 float value per input mel-spectrogram frame of 80 floats?

    (1): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
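
One possible reading of the 256 in (conv1x1_out) and (conv1x1_skip): a minimal sketch, assuming the block applies the gated activation described in the WaveNet paper. Under that assumption, the 512 channels produced by (conv) plus (conv1x1c) are split into two halves of 256 and combined as tanh(a) * sigmoid(b), so only 256 channels reach the two 1x1 convolutions. The function name `gated_activation` and the all-zero input are illustrative and not taken from the repository.

```python
import math

def gated_activation(h):
    """Split a channel vector in half and return tanh(first) * sigmoid(second)."""
    half = len(h) // 2
    a, b = h[:half], h[half:]
    return [math.tanh(x) * (1.0 / (1.0 + math.exp(-y))) for x, y in zip(a, b)]

# One time step of (conv)(x) + (conv1x1c)(c): 512 gate channels (dummy zeros).
h = [0.0] * 512
z = gated_activation(h)

print(len(h))  # 512 channels going into the gate
print(len(z))  # 256 channels coming out, the in_channels of conv1x1_out / conv1x1_skip
```

If this reading is right, (conv1x1c) maps the 80 mel channels to 512 only so that the conditioning can be added to the gate input before the split.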

The whole network shows as the following when I print it:

>>>p model
WaveNet(
  (first_conv): Conv1d(1, 512, kernel_size=(1,), stride=(1,))
  (conv_layers): ModuleList(
    (0): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (1): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (2): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(4,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (3): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(8,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (4): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(16,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (5): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(32,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (6): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (7): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (8): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(4,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (9): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(8,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (10): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(16,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (11): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(32,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (12): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (13): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (14): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(4,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (15): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(8,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (16): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(16,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (17): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(32,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (18): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (19): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(2,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (20): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(4,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (21): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(8,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (22): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(16,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
    (23): ResidualConv1dGLU(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(32,))
      (conv1x1c): Conv1d(80, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_out): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (conv1x1_skip): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    )
  )
  (last_conv_layers): ModuleList(
    (0): ReLU(inplace=True)
    (1): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
    (2): ReLU(inplace=True)
    (3): Conv1d(256, 30, kernel_size=(1,), stride=(1,))
  )
  (upsample_conv): ModuleList(
    (0): ConvTranspose2d(1, 1, kernel_size=(3, 4), stride=(1, 4), padding=(1, 0))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(1, 1, kernel_size=(3, 4), stride=(1, 4), padding=(1, 0))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(1, 1, kernel_size=(3, 4), stride=(1, 4), padding=(1, 0))
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(1, 1, kernel_size=(3, 4), stride=(1, 4), padding=(1, 0))
    (7): ReLU(inplace=True)
  )
)
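
The time-axis sizes implied by this printout can be checked with plain arithmetic. The sketch below applies the standard transposed-convolution output-length formula to the four (upsample_conv) layers and derives the causal padding and receptive field from the printed dilations. It assumes dilation 1 and no output padding in the ConvTranspose2d layers, and that each mel frame is meant to cover 4 * 4 * 4 * 4 = 256 audio samples (the hop size itself is not shown in this issue). The helper names are hypothetical.

```python
def conv_transpose_len(n, kernel, stride, padding):
    """Output length of a transposed convolution along one axis
    (dilation 1, no output_padding)."""
    return (n - 1) * stride - 2 * padding + (kernel - 1) + 1

def upsample_shape(mel_bins, frames, layers=4):
    """Apply the printed ConvTranspose2d layers:
    kernel (3, 4), stride (1, 4), padding (1, 0)."""
    for _ in range(layers):
        mel_bins = conv_transpose_len(mel_bins, kernel=3, stride=1, padding=1)
        frames = conv_transpose_len(frames, kernel=4, stride=4, padding=0)
    return mel_bins, frames

# Dilations as printed: 4 stacks of 1, 2, 4, 8, 16, 32 with kernel size 3.
kernel_size = 3
dilations = [2 ** i for i in range(6)] * 4
# The printed padding of each layer equals (kernel_size - 1) * dilation.
paddings = [(kernel_size - 1) * d for d in dilations]
receptive_field = 1 + sum(paddings)

print(upsample_shape(80, 10))  # (80, 2560)
print(paddings[:6])            # [2, 4, 8, 16, 32, 64]
print(receptive_field)         # 505
```

Under these assumptions, 10 mel frames become 2,560 conditioning steps, one per audio sample, and each output sample depends on at most 505 past input samples.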

r9y9 commented 4 years ago

I believe you can find the answers in the WaveNet paper(s). Please check https://github.com/r9y9/wavenet_vocoder#references.
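
For readers landing on this thread, here is a rough shape trace, assuming the usual WaveNet vocoder setup: the autoregressive input `x` is the waveform itself (one scalar per sample, which would explain in_channels=1 and input_type="raw"), while the 80-dimensional mel frames enter only as local conditioning `c` after upsampling. The total upsampling factor of 256 and the reading of the 30 output channels as 10 mixture components with 3 parameters each are assumptions, not something stated in this issue. `trace_shapes` is a hypothetical helper.

```python
def trace_shapes(mel_frames, hop=256):
    """Return (channels, time) at each stage, using sizes from the printout.
    `hop` is the assumed total upsampling factor, 4 * 4 * 4 * 4."""
    t = mel_frames * hop  # number of audio samples
    return {
        "waveform x": (1, t),
        "mel c": (80, mel_frames),
        "c after upsample_conv": (80, t),
        "first_conv(x)": (512, t),
        "residual output per block": (512, t),
        "sum of skip outputs": (256, t),
        "last_conv_layers output": (30, t),
    }

for name, shape in trace_shapes(mel_frames=10).items():
    print(f"{name:28s} {shape}")
```

Read this way, the network would emit one set of output-distribution parameters per audio sample, that is 256 per mel frame, rather than one float per frame.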