google-deepmind / sonnet

TensorFlow-based neural network library
https://sonnet.dev/
Apache License 2.0

ValueError when using VectorQuantizerEMA with a Distributed Strategy #209

Closed: evanatyourservice closed this issue 3 years ago

evanatyourservice commented 3 years ago

Hello! Thank you for this great library and the VQ-VAE implementation.

While using the EMA VQ-VAE (VectorQuantizerEMA) with a distributed strategy (in this case on TPU), I get a ValueError when the gradients are reduced across accelerators, because one of the gradients is None. That gradient belongs to the self.embeddings variable. The problem goes away when trainable is set to False:

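    # Updated by the EMA during the forward pass, so keep it out of the
    # trainable variables that get gradients and optimizer updates.
    self.embeddings = tf.Variable(
        initializer(embedding_shape, dtype), trainable=False, name='embeddings')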
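
A minimal sketch of why that gradient ends up None, assuming TensorFlow 2.x in eager mode. ToyEMAQuantizer and grads_for are hypothetical names, and the module is a stripped-down stand-in rather than Sonnet's actual VectorQuantizerEMA code. As in the EMA variant, the codebook reaches the loss only through tf.stop_gradient, so tape.gradient returns None for it whenever it is listed as trainable.

```python
import tensorflow as tf


class ToyEMAQuantizer(tf.Module):
    """Hypothetical, stripped-down stand-in for an EMA vector quantizer."""

    def __init__(self, num_embeddings, dim, trainable):
        super().__init__()
        # Codebook with one column per code, like the `embeddings` variable above.
        self.embeddings = tf.Variable(
            tf.random.normal([dim, num_embeddings], seed=0),
            trainable=trainable, name="embeddings")

    def __call__(self, inputs):
        # Squared distance from each input row to each codebook column.
        distances = (
            tf.reduce_sum(inputs**2, axis=1, keepdims=True)
            - 2 * tf.matmul(inputs, self.embeddings)
            + tf.reduce_sum(self.embeddings**2, axis=0, keepdims=True))
        indices = tf.argmin(distances, axis=1)  # non-differentiable
        quantized = tf.gather(tf.transpose(self.embeddings), indices)
        # EMA style: only a commitment loss, and the codebook enters it through
        # stop_gradient, so no gradient path reaches `embeddings`.
        loss = tf.reduce_mean((tf.stop_gradient(quantized) - inputs) ** 2)
        # Straight-through estimator for the encoder output.
        quantized = inputs + tf.stop_gradient(quantized - inputs)
        return quantized, loss


# Stand-in for the encoder weights that should receive gradients.
encoder = tf.Variable(tf.ones([3, 2]), name="encoder")


def grads_for(trainable):
    vq = ToyEMAQuantizer(num_embeddings=4, dim=2, trainable=trainable)
    variables = [encoder] + list(vq.trainable_variables)
    with tf.GradientTape() as tape:
        z = tf.matmul(tf.ones([5, 3]), encoder)
        _, loss = vq(z)
    return tape.gradient(loss, variables)


print(grads_for(trainable=True))   # encoder gradient, then None for the codebook
print(grads_for(trainable=False))  # only the encoder gradient
```

With trainable=True the codebook is collected as a trainable variable and its gradient comes back as None, which is the kind of value the cross-replica reduction in this issue rejected. With trainable=False it never appears in trainable_variables, so every gradient in the list is a real tensor.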

Please let me know if this is not the right way to fix the problem. If it is, I'd be happy to submit a pull request for both the EMA and non-EMA VQ-VAEs, or you can make the change yourselves; just let me know. Thanks, Evan

tomhennigan commented 3 years ago

Thanks for filing this! Yes, I think it would be appropriate to set trainable=False for VectorQuantizerEMA.embeddings. TensorFlow uses this flag to indicate whether a mutable variable should be updated by an optimizer, and in VQ-EMA we update the embeddings as part of the forward pass when training.
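
To make "updated as part of the forward pass" concrete, here is a minimal pure-Python sketch of one EMA codebook step in the style of the exponential-moving-average update described for VQ-VAE. ema_codebook_update and its argument names are hypothetical, not Sonnet API, and the sketch omits refinements a library implementation may add (such as smoothing the counts); it simply leaves codes with a zero count unchanged. The codebook is overwritten from running averages, so no gradient or optimizer is involved.

```python
def ema_codebook_update(codebook, cluster_size, embed_sum, batch, decay=0.99):
    """One EMA codebook step (hypothetical helper, illustration only).

    codebook:     K code vectors (lists of floats)
    cluster_size: K running (EMA) assignment counts
    embed_sum:    K running (EMA) sums of assigned encoder outputs
    batch:        encoder outputs for this step
    """
    k, dim = len(codebook), len(codebook[0])
    counts = [0.0] * k
    sums = [[0.0] * dim for _ in range(k)]
    for z in batch:
        # Assign z to its nearest code by squared Euclidean distance.
        i = min(range(k),
                key=lambda j: sum((a - b) ** 2 for a, b in zip(z, codebook[j])))
        counts[i] += 1.0
        sums[i] = [s + a for s, a in zip(sums[i], z)]
    # Exponential moving averages of counts and sums.
    new_size = [decay * n + (1 - decay) * c for n, c in zip(cluster_size, counts)]
    new_sum = [[decay * m + (1 - decay) * s for m, s in zip(ms, ss)]
               for ms, ss in zip(embed_sum, sums)]
    # Codes are assigned directly from the averages, not learned by gradient.
    new_codebook = [[m / n for m in ms] if n > 0 else list(code)
                    for ms, n, code in zip(new_sum, new_size, codebook)]
    return new_codebook, new_size, new_sum


codebook, sizes, sums = ema_codebook_update(
    codebook=[[0.0, 0.0], [10.0, 10.0]],
    cluster_size=[1.0, 1.0],
    embed_sum=[[0.0, 0.0], [10.0, 10.0]],
    batch=[[1.0, 1.0], [9.0, 9.0], [11.0, 11.0]],
    decay=0.5)
print(codebook)  # [[0.5, 0.5], [10.0, 10.0]]
```

Because the update is an assignment computed from the batch, an optimizer step on the same variable would fight the EMA, which is another reason to exclude it from the trainable set.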

I think we want to leave VectorQuantizer.embeddings as is, since IIRC it is updated by the optimizer.
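
For contrast, a minimal sketch of a non-EMA style VQ loss, assuming TensorFlow 2.x eager mode with a toy codebook and made-up inputs. The codebook term (quantized - stop_gradient(inputs))**2 lets gradients flow into the selected codes, which is why that variable has to stay trainable. The 0.25 commitment weight is an assumed value taken from the original VQ-VAE paper, not necessarily what Sonnet uses.

```python
import tensorflow as tf

# Toy codebook: 4 codes of dimension 2, trainable by default.
embeddings = tf.Variable(tf.random.normal([2, 4], seed=0), name="embeddings")
inputs = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # stand-in encoder outputs
commitment_cost = 0.25  # assumed value

with tf.GradientTape() as tape:
    distances = (
        tf.reduce_sum(inputs**2, axis=1, keepdims=True)
        - 2 * tf.matmul(inputs, embeddings)
        + tf.reduce_sum(embeddings**2, axis=0, keepdims=True))
    indices = tf.argmin(distances, axis=1)
    quantized = tf.gather(tf.transpose(embeddings), indices)
    # Codebook loss: gradient flows into the selected codes.
    q_latent_loss = tf.reduce_mean((quantized - tf.stop_gradient(inputs)) ** 2)
    # Commitment loss: gradient flows into the encoder outputs only.
    e_latent_loss = tf.reduce_mean((tf.stop_gradient(quantized) - inputs) ** 2)
    loss = q_latent_loss + commitment_cost * e_latent_loss

grad = tape.gradient(loss, embeddings)
# The codebook receives a real gradient here, unlike the EMA case.
print(grad is not None)
```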

If you don't mind, I'll send a patch for this and ask someone who knows more about the network than me to review it, just to make sure we're not missing something subtle. It should land within a week.

evanatyourservice commented 3 years ago

Awesome, sounds good, thank you! OK, yeah, I wasn't as familiar with the non-EMA version, so I just guessed on that one.