[Open] zhxgj opened this issue 4 years ago
I agree with the commenter above. It is so weird.
I found that `ctx.needs_input_grad[1]` is `False` during training VQ-VAE. Is this correct, and does it mean the embedding of the codebook does not update during training?
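For context, `ctx.needs_input_grad` inside a custom `torch.autograd.Function` simply reports, per forward input, whether a gradient for that input is being requested in this particular call. A minimal toy sketch (illustrative only, not the repo's code):

```python
import torch

# Toy custom Function showing how backward() can consult ctx.needs_input_grad.
class Mul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        # needs_input_grad[i] is True iff the i-th forward input requires grad.
        grad_x = grad_output * w if ctx.needs_input_grad[0] else None
        grad_w = grad_output * x if ctx.needs_input_grad[1] else None
        return grad_x, grad_w

x = torch.ones(3, requires_grad=True)
w = torch.ones(3)                  # requires_grad=False
Mul.apply(x, w).sum().backward()
print(x.grad)                      # tensor([1., 1., 1.])
print(w.grad)                      # None: needs_input_grad[1] was False
```

So a `False` entry just means the corresponding input tensor, as passed into *this* function, did not require a gradient.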
This part of the code was never executed! But when I printed `model.codebook.embedding.weight.data`, I found that it does get updated!
Actually, `ctx.needs_input_grad[0]` and `ctx.needs_input_grad[1]` are set to true and false alternately. On the 1st step, `ctx.needs_input_grad[0]` is true and `ctx.needs_input_grad[1]` is false. On the 2nd step, `ctx.needs_input_grad[0]` becomes false and `ctx.needs_input_grad[1]` becomes true. On the 3rd step, `ctx.needs_input_grad[0]` is true and `ctx.needs_input_grad[1]` is false. And so on...
This setting is reasonable because there are two "agents", namely the codebook and the autoencoder, each updating w.r.t. a different part of the loss function.
I debugged the code and found that `ctx.needs_input_grad[1]` is always false, rather than alternating between true and false. A basic fact: if a variable $A$ does not require a gradient along one path, that does not mean it will not be updated during optimization. The `requires_grad` attribute only describes whether a gradient with respect to that tensor should be computed along that particular path; the same parameter can still receive gradients through a different path in the graph.
Therefore, even though `ctx.needs_input_grad[1]` is always `False` in this function, the codebook can still be updated (its gradient reaches it through the codebook part of the loss, not through this function).
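The point above can be shown with a hypothetical miniature (all names here are made up for illustration, not taken from the repo): a parameter seen only through a detached view along one path still gets a gradient, and is still updated, through any other path that touches it directly.

```python
import torch

# `emb` stands in for the codebook embedding weight.
emb = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.SGD([emb], lr=0.1)

# Path 1: the reconstruction path sees a *detached* view of emb, so along
# this path no gradient w.r.t. emb is requested (inside a custom Function,
# needs_input_grad for it would be False here).
recon_like = (emb.detach() * 2.0).sum()

# Path 2: the codebook-loss path touches emb directly, so emb still
# receives a gradient from here.
codebook_like = (emb ** 2).sum()

(recon_like + codebook_like).backward()
opt.step()

print(emb.grad)  # tensor([2., 2., 2., 2.]) -- from path 2 only
print(emb.data)  # tensor([0.8, 0.8, 0.8, 0.8]) -- updated despite path 1
```

This mirrors the two-"agents" picture described earlier in the thread: the autoencoder and the codebook are trained by different terms of the loss, so a `False` in one function's `needs_input_grad` says nothing about whether the codebook updates overall.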
The code in question: https://github.com/ritheshkumar95/pytorch-vqvae/blob/8d123c0d043bebc8734d37785dd13dd20e7e5e0e/functions.py#L53