google-deepmind / sonnet

TensorFlow-based neural network library
https://sonnet.dev/
Apache License 2.0

'quantize' vector for inference #124

Closed wsr692 closed 4 years ago

wsr692 commented 5 years ago

I am running inference with the trained VQ embedding matrix, but the example ipynb only gives training instructions. I wonder whether the edits I made to the source code are valid.

In vqvae.py, the quantized vector is defined as follows:

LINE 98
quantized = self.quantize(encoding_indices)
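
(For context, self.quantize is essentially a codebook lookup. Below is a minimal standalone sketch of the idea; the [embedding_dim, num_embeddings] storage layout and the variable names are my assumptions, not necessarily the exact ones in vqvae.py:)

import tensorflow as tf

embedding_dim, num_embeddings = 64, 512
# Hypothetical codebook, stored as [embedding_dim, num_embeddings].
codebook = tf.Variable(tf.random.normal([embedding_dim, num_embeddings]))

def quantize(encoding_indices):
    # Transpose to [num_embeddings, embedding_dim] so each index selects a row.
    w = tf.transpose(codebook, [1, 0])
    # Output shape: encoding_indices.shape + [embedding_dim].
    return tf.nn.embedding_lookup(w, encoding_indices)

quantized = quantize(tf.constant([[3, 17], [42, 0]]))  # shape [2, 2, 64]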

The quantized vector then gets updated through the following lines:

LINE 100 to LINE 104
e_latent_loss = tf.reduce_mean((tf.stop_gradient(quantized) - inputs) ** 2)
q_latent_loss = tf.reduce_mean((quantized - tf.stop_gradient(inputs)) ** 2) 
loss = q_latent_loss + self._commitment_cost * e_latent_loss

quantized = inputs + tf.stop_gradient(quantized - inputs)

It seems that for inference, the non-updated version of this vector should be fed into the decoder instead.

So I made the following edits to the code to return the non-updated quantized vector as well.

LINE 98
quantized = self.quantize(encoding_indices)
quantized_frozen = quantized

LINE 111
    return {'quantize': quantized,
            'quantize_frozen': quantized_frozen,
            'loss': loss,
            'perplexity': perplexity,
            'encodings': encodings,
            'encoding_indices': encoding_indices,
            'wmatrix': w}
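
With that change, inference would look something like the sketch below. Here encoder, decoder, and vq_layer are placeholders for however the model was built and restored; I haven't verified the exact call signature:

# Hypothetical inference path; encoder/decoder/vq_layer are placeholders.
z_e = encoder(images)                        # continuous encoder output
vq_output = vq_layer(z_e, is_training=False)
z_q = vq_output['quantize_frozen']           # codebook vectors for the decoder
reconstructions = decoder(z_q)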

Thank you!

modanesh commented 4 years ago

I checked them, and they look valid.
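
Note that the straight-through line only reroutes gradients: in the forward pass, inputs + tf.stop_gradient(quantized - inputs) is numerically equal to quantized, so 'quantize' and 'quantize_frozen' carry identical values at inference anyway. A quick eager-mode check with toy values:

import tensorflow as tf

inputs = tf.Variable([1.0, 2.0, 3.0])
quantized = tf.constant([0.5, 2.25, 2.75])   # pretend codebook lookup

with tf.GradientTape() as tape:
    # Straight-through: the forward value equals `quantized` exactly...
    ste = inputs + tf.stop_gradient(quantized - inputs)
    loss = tf.reduce_sum(ste ** 2)

print(tf.reduce_all(ste == quantized).numpy())  # True
# ...but gradients flow to `inputs` as if ste were inputs:
# d(loss)/d(inputs) = 2 * ste = 2 * quantized.
print(tape.gradient(loss, inputs).numpy())      # [1.0, 4.5, 5.5]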