Hi, I tried to use both codebook loss and commitment loss instead of the EMA update, but I was confused about how to apply the codebook loss.
If 'learnable_codebook' is True, then 'commit_quantize' is not detached from 'quantize',
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
commit_quantize = maybe_detach(quantize)
and 'commit_loss' will serve as both the codebook loss and the commitment loss.
commit_loss = F.mse_loss(commit_quantize, x)
So, is setting 'learnable_codebook' to True all I need to apply the codebook loss, or do I need to implement the codebook loss separately?
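For reference, this is how I understand the classic VQ-VAE formulation, which splits the two terms with a stop-gradient on each side. A minimal sketch only; the function name and the commitment_weight default are mine, not the library's:

import torch
import torch.nn.functional as F

def vq_losses(quantize, x, commitment_weight = 0.25):
    # codebook loss: stop-gradient on the encoder output, so only the
    # codebook (reached through `quantize`) receives gradient
    codebook_loss = F.mse_loss(quantize, x.detach())
    # commitment loss: stop-gradient on the quantized output, so only the
    # encoder (reached through `x`) receives gradient
    commitment_loss = F.mse_loss(quantize.detach(), x)
    return codebook_loss + commitment_weight * commitment_loss

If I read the gradients correctly, F.mse_loss(quantize, x) with no detach at all sends the same gradient to both sides, so it is gradient-equivalent to the two terms above with both weights set to 1, which would match the "serves as both" behavior.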
Also, if 'commit_loss' serves as the codebook loss, then I think 'commit_quantize' should be detached when the in-place update is used, so that the codebook is not updated twice!
# one step in-place update
if should_inplace_optimize and self.training and not freeze_codebook:
    if exists(mask):
        loss = F.mse_loss(quantize, x.detach(), reduction = 'none')

        loss_mask = mask
        if is_multiheaded:
            loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0])

        loss = loss[loss_mask].mean()
    else:
        loss = F.mse_loss(quantize, x.detach())

    loss.backward()
    self.in_place_codebook_optimizer.step()
    self.in_place_codebook_optimizer.zero_grad()

    # quantize again
    quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
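Concretely, the guard I have in mind would look something like this. It is only a sketch reusing the names from the snippets above, not code from the repo:

# the in-place optimizer step above has already pulled the codebook toward x
# once; detach here so commit_loss does not update the codebook a second time
if should_inplace_optimize and self.training and not freeze_codebook:
    commit_quantize = quantize.detach()
else:
    commit_quantize = maybe_detach(quantize)

commit_loss = F.mse_loss(commit_quantize, x)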