tomlepaine / fast-wavenet

Speedy Wavenet generation using dynamic programming :zap:
GNU General Public License v3.0

If the filter width is 3, how do we do fast inference? #19

Open: weixsong opened this issue 6 years ago

weixsong commented 6 years ago

In the new paper, Google uses a filter width of 3 to increase the receptive field.
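To make the receptive-field gain concrete: in a stack of causal dilated convolutions, a layer with filter width k and dilation d reaches (k - 1) * d samples further into the past, and these reaches add up across layers. So for the same dilations, going from width 2 to width 3 roughly doubles the receptive field. This is a minimal sketch that ignores the initial causal layer; the dilation schedule below is a hypothetical example, not necessarily the one used in the paper.

```python
def receptive_field(filter_width, dilations):
    """Receptive field, in samples, of a stack of causal dilated convolutions.

    A layer with filter width k and dilation d reaches (k - 1) * d samples
    further into the past, and these reaches add up across layers.
    """
    return (filter_width - 1) * sum(dilations) + 1


# Hypothetical schedule: dilations 1, 2, 4, ..., 512, repeated 3 times.
dilations = [2 ** i for i in range(10)] * 3

print(receptive_field(2, dilations))  # 3070
print(receptive_field(3, dilations))  # 6139
```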

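How, then, can we do fast inference with a filter width of 3? My idea is to use two queues. Since the dilation schedule is unchanged (it still doubles from layer to layer), a layer with dilation d now needs its inputs from both d and 2d steps back. The first queue stores the first half of these intermediate values (the d most recent inputs) and the second queue stores the second half (the d inputs before those); the output of the first queue is then enqueued into the second queue.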
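The chained-queue idea can be checked outside TensorFlow with a minimal pure-Python sketch, where `collections.deque` stands in for `tf.FIFOQueue` and `dilated_taps_width3` / `naive_taps` are hypothetical helper names. For dilation d, the first queue's output is x[t - d]; after passing through a second queue of the same length it becomes x[t - 2d], which are exactly the two past taps a width-3 filter needs.

```python
from collections import deque


def dilated_taps_width3(inputs, dilation):
    """Yield (x[t - 2d], x[t - d], x[t]) for every step t, using two chained FIFOs.

    Values before the start of the sequence are zeros, like the
    tf.zeros(...) that enqueue_many pre-loads into each queue.
    """
    q1 = deque([0.0] * dilation)  # the d most recent inputs
    q2 = deque([0.0] * dilation)  # the d inputs that already left q1
    for x in inputs:
        current_state = q1.popleft()  # x[t - d]
        q1.append(x)
        pre_state = q2.popleft()      # x[t - 2d]
        q2.append(current_state)      # q1's output is enqueued into q2
        yield pre_state, current_state, x


def naive_taps(inputs, dilation):
    """Reference: read the same three taps directly from the full sequence."""
    def at(i):
        return inputs[i] if i >= 0 else 0.0
    return [(at(t - 2 * dilation), at(t - dilation), at(t))
            for t in range(len(inputs))]


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
print(list(dilated_taps_width3(xs, 2)))
# [(0.0, 0.0, 1.0), (0.0, 0.0, 2.0), (0.0, 1.0, 3.0), (0.0, 2.0, 4.0), (1.0, 3.0, 5.0)]
assert list(dilated_taps_width3(xs, 2)) == naive_taps(xs, 2)
```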

For example:

        # `q` and `init` are the causal layer's queue and its init op,
        # defined just above (not shown).
        current_state = q.dequeue()  # previous input, x[t - 1]
        push = q.enqueue([current_layer])
        init_ops.append(init)
        push_ops.append(push)

        pre_state = None
        if self.filter_width == 3:
            # Second queue, chained after `q`: it is fed whatever leaves `q`,
            # so its output is the input from two steps back, x[t - 2].
            q2 = tf.FIFOQueue(
                1,
                dtypes=tf.float32,
                shapes=(self.batch_size, self.quantization_channels))
            init2 = q2.enqueue_many(
                tf.zeros((1, self.batch_size, self.quantization_channels)))

            pre_state = q2.dequeue()
            push2 = q2.enqueue([current_state])

            # These ops must be run together with init_ops / push_ops.
            init_ops2.append(init2)
            push_ops2.append(push2)

        if self.filter_width == 2:
            current_layer = self._generator_causal_layer(
                current_layer, current_state)
        elif self.filter_width == 3:
            current_layer = self._generator_causal_layer(
                current_layer, current_state, pre_state)
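With `pre_state` available, the single-step convolution behind `_generator_causal_layer` needs a third term: a width-2 step multiplies the queued state and the current input by the filter's two taps, and a width-3 step adds the oldest sample times the first tap. Below is a minimal pure-Python sketch of that arithmetic, assuming each tap `weights[k]` is an (in_channels, out_channels) matrix and that tap 0 multiplies the oldest sample, as in a training-time causal convolution. `generator_conv3` and `vec_mat` are hypothetical names, not functions from the repository.

```python
def vec_mat(v, w):
    """Row vector v (C_in) times matrix w (C_in x C_out), i.e. matmul on one batch row."""
    return [sum(v[i] * w[i][j] for i in range(len(v))) for j in range(len(w[0]))]


def generator_conv3(input_vec, state_vec, pre_state_vec, weights):
    """One generation step of a width-3 causal convolution (hypothetical helper).

    weights[k] is a C_in x C_out matrix for tap k, with tap 0 the oldest sample:
      pre_state (x[t - 2d]) @ weights[0]
      + state (x[t - d]) @ weights[1]
      + input (x[t]) @ weights[2]
    """
    oldest = vec_mat(pre_state_vec, weights[0])
    middle = vec_mat(state_vec, weights[1])
    newest = vec_mat(input_vec, weights[2])
    return [a + b + c for a, b, c in zip(oldest, middle, newest)]


w = [[[1.0]], [[10.0]], [[100.0]]]  # 3 taps, 1 input channel, 1 output channel
print(generator_conv3([1.0], [2.0], [3.0], w))  # [123.0]
```

In the TensorFlow version the same three products would be `tf.matmul` calls on `(batch, channels)` tensors; the detail that matters is that the tap order used at generation time matches the order used when the weights were trained.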

...
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):

                    q = tf.FIFOQueue(
                        dilation,
                        dtypes=tf.float32,
                        shapes=(self.batch_size, self.residual_channels))
                    init = q.enqueue_many(
                        tf.zeros((dilation, self.batch_size,
                                  self.residual_channels)))

                    current_state = q.dequeue()  # layer input at t - dilation
                    push = q.enqueue([current_layer])
                    init_ops.append(init)
                    push_ops.append(push)

                    pre_state = None
                    if self.filter_width == 3:
                        # Chained queue of the same length: it is fed the
                        # values leaving `q`, so it yields t - 2 * dilation.
                        q2 = tf.FIFOQueue(
                            dilation,
                            dtypes=tf.float32,
                            shapes=(self.batch_size, self.residual_channels))
                        init2 = q2.enqueue_many(
                            tf.zeros((dilation, self.batch_size,
                                      self.residual_channels)))

                        pre_state = q2.dequeue()
                        push2 = q2.enqueue([current_state])

                        init_ops2.append(init2)
                        push_ops2.append(push2)

                    output, current_layer = self._generator_dilation_layer(
                        current_layer, current_state, layer_index, dilation,
                        global_condition_batch, local_condition, pre_state)
                    outputs.append(output)

Does that make sense?
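One way to check the per-layer bookkeeping in the dilated-stack snippet is to compare sample-by-sample generation against running the full sequence through the stack. The sketch below is pure Python and deliberately simplified: one channel per layer and no gating, residual, skip or conditioning paths, so only the queue logic is exercised. `FastStack` and `naive_stack` are hypothetical names, and `deque` stands in for `tf.FIFOQueue`.

```python
from collections import deque


def naive_stack(xs, dilations, weights):
    """Reference: run the full sequence through each width-3 layer in turn."""
    h = list(xs)
    for d, (w0, w1, w2) in zip(dilations, weights):
        prev = h
        h = [w0 * (prev[t - 2 * d] if t >= 2 * d else 0.0)
             + w1 * (prev[t - d] if t >= d else 0.0)
             + w2 * prev[t]
             for t in range(len(prev))]
    return h


class FastStack:
    """Sample-by-sample generation with two chained FIFOs per layer."""

    def __init__(self, dilations, weights):
        self.weights = weights
        # One (q, q2) pair per layer, each pre-filled with `dilation` zeros.
        self.queues = [(deque([0.0] * d), deque([0.0] * d)) for d in dilations]

    def step(self, x):
        h = x  # input to the first layer at time t
        for (q, q2), (w0, w1, w2) in zip(self.queues, self.weights):
            current_state = q.popleft()  # this layer's input at t - d
            q.append(h)
            pre_state = q2.popleft()     # this layer's input at t - 2d
            q2.append(current_state)
            h = w0 * pre_state + w1 * current_state + w2 * h
        return h


dilations = [1, 2, 4, 8]
weights = [(0.5, -0.25, 1.0), (0.1, 0.2, 0.3), (-1.0, 0.5, 2.0), (0.3, 0.3, 0.3)]
xs = [float(i % 5) for i in range(40)]

stack = FastStack(dilations, weights)
fast = [stack.step(x) for x in xs]
slow = naive_stack(xs, dilations, weights)
assert all(abs(a - b) < 1e-9 for a, b in zip(fast, slow))
```

Because each layer's queues hold that layer's own inputs, the chaining composes across the stack: a layer with dilation d only ever needs 2d past values of its own input, regardless of what sits below it.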

KingStorm commented 6 years ago

I think that's the idea.

twidddj commented 6 years ago

@weixsong Hi, we have considered this issue. You can find the method in our repository.