Closed by evanatyourservice 3 years ago
Thanks for filing this! Yes, I think it would probably be appropriate to set trainable=False for VectorQuantizerEMA.embeddings, since in TensorFlow this flag indicates whether an optimizer should update a mutable variable, and in VQ-EMA we update the embeddings as part of the forward pass during training.
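To illustrate the point about trainable=False, here is a minimal sketch (variable names and the update helper are illustrative, not the actual Sonnet implementation) of a codebook that is kept out of the optimizer's reach but still updated in place during the forward pass:

```python
import tensorflow as tf

num_embeddings, embedding_dim = 8, 4
decay = 0.99

# trainable=False keeps optimizers (and distributed gradient reduction)
# away from this variable; the EMA assignment below mutates it directly.
embeddings = tf.Variable(
    tf.zeros([embedding_dim, num_embeddings]),
    trainable=False,
    name="embeddings",
)

def ema_update(batch_means):
    # Exponential moving average toward the per-code batch means,
    # applied in place instead of through gradients.
    embeddings.assign(decay * embeddings + (1.0 - decay) * batch_means)

ema_update(tf.ones([embedding_dim, num_embeddings]))
```

Because the variable is non-trainable, it never appears in `model.trainable_variables`, so no gradient is ever computed or reduced for it.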
I think we want to leave VectorQuantizer.embeddings as is, since IIRC that one is updated by the optimizer.
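For contrast, a sketch of the non-EMA case (again with illustrative names and a simplified loss, not the actual Sonnet code): here the codebook is a regular trainable variable, and the codebook loss produces gradients that the optimizer applies.

```python
import tensorflow as tf

# trainable=True by default, so the optimizer owns the updates.
embeddings = tf.Variable(tf.random.normal([8, 16]))

encoder_out = tf.random.normal([4, 8])
with tf.GradientTape() as tape:
    # Simplified codebook loss: pull the codes toward the
    # (stop-gradient of the) encoder outputs.
    codebook_loss = tf.reduce_mean(
        (tf.stop_gradient(encoder_out)[..., None] - embeddings[None]) ** 2
    )

grad = tape.gradient(codebook_loss, embeddings)
# grad is a dense tensor here, so an optimizer step can update the codebook.
```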
If you don't mind I'll send a patch for this and ask someone who knows more about the network than me to review just to make sure we're not missing something subtle. Should land within a week.
Awesome, sounds good, thank you! Okay, yeah, I wasn't as familiar with the non-EMA version, so I just guessed on that one.
Hello! Thank you for this great library and the VQ-VAE implementation.
While using the EMA VQ-VAE with a distributed strategy (in this case TPU), I get a ValueError when reducing the grads across accelerators because one of the grads is None. It comes from the self.embeddings variable. The problem is fixed when trainable is set to False.
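A minimal sketch of the failure mode being described (shapes and names are assumed, not taken from the actual model): a variable that is updated by EMA rather than by gradients stays disconnected from the loss, so GradientTape returns None for it, and that None then trips up the cross-replica gradient reduction.

```python
import tensorflow as tf

# trainable defaults to True, so this EMA-updated variable is still
# handed to tape.gradient even though it never enters the loss graph.
embeddings = tf.Variable(tf.zeros([4, 8]), name="embeddings")
w = tf.Variable(tf.ones([8]))

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(w)  # embeddings does not contribute to the loss

grads = tape.gradient(loss, [w, embeddings])
# grads[1] is None; reducing this list across replicas raises a ValueError.
# Marking the variable trainable=False removes it from the trainable set,
# so it never reaches the gradient reduction in the first place.
```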
Please let me know if this is not the right way to fix this problem. If it is, I'd be happy to submit a pull request for the EMA and non-EMA VQ-VAEs, or you all can; let me know. Thanks, Evan