lucidrains / vector-quantize-pytorch

Vector (and Scalar) Quantization, in Pytorch
MIT License
2.12k stars 179 forks

learnable_codebook and in-place optimization #97

Open daraha76 opened 6 months ago

daraha76 commented 6 months ago

Hi, I tried to use both a codebook loss and a commitment loss instead of EMA updates, but I'm confused about how to apply the codebook loss.

If 'learnable_codebook' is True, then 'commit_quantize' is not detached from 'quantize':

            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

            commit_quantize = maybe_detach(quantize)

and 'commit_loss' will serve as both codebook loss and commitment loss.

commit_loss = F.mse_loss(commit_quantize, x)

So, is setting 'learnable_codebook' to True all I need to apply the codebook loss, or do I need to implement the codebook loss separately?
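For what it's worth, a minimal sketch (toy tensors, names mine, not the library's) of why the undetached 'commit_loss' acts as both losses: a single MSE between the quantized vectors and the encoder output sends gradients to both sides, while detaching the quantized side recovers a pure commitment loss.

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: x plays the encoder output, codebook the selected codes.
x = torch.randn(4, 8, requires_grad=True)
codebook = torch.randn(4, 8, requires_grad=True)

# Undetached: the one MSE term acts as commitment loss (grad w.r.t. x)
# and codebook loss (grad w.r.t. the codebook) simultaneously.
F.mse_loss(codebook, x).backward()
assert x.grad is not None and codebook.grad is not None

# Detaching the quantized side yields a pure commitment loss:
# only the encoder side receives gradient.
x2 = torch.randn(4, 8, requires_grad=True)
cb2 = torch.randn(4, 8, requires_grad=True)
F.mse_loss(cb2.detach(), x2).backward()
assert x2.grad is not None and cb2.grad is None
```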

Also, if 'commit_loss' serves as the codebook loss, then I think 'commit_quantize' should be detached when the in-place update is used, so that the codebook is not updated twice!

        # one step in-place update

        if should_inplace_optimize and self.training and not freeze_codebook:

            if exists(mask):
                loss = F.mse_loss(quantize, x.detach(), reduction = 'none')

                loss_mask = mask
                if is_multiheaded:
                    loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0])

                loss = loss[loss_mask].mean()

            else:
                loss = F.mse_loss(quantize, x.detach())

            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            # quantize again

            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
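To illustrate the double-update concern, here is a minimal sketch (toy modules and names of my own, not the library's internals, with nearest-neighbor lookup via torch.cdist): the in-place optimizer already moves the codebook toward x, so detaching the re-quantized output before the commit loss keeps the main optimizer from moving the codebook a second time.

```python
import torch
import torch.nn.functional as F

# Toy codebook with its own in-place optimizer, as in the snippet above.
codebook = torch.nn.Parameter(torch.randn(16, 8))
inplace_opt = torch.optim.SGD([codebook], lr=1e-2)

x = torch.randn(4, 8, requires_grad=True)  # stand-in encoder output

# one in-place step pulling the nearest codes toward x
idx = torch.cdist(x, codebook).argmin(dim=-1)
F.mse_loss(codebook[idx], x.detach()).backward()
inplace_opt.step()
inplace_opt.zero_grad(set_to_none=True)

# quantize again with the updated codebook
idx = torch.cdist(x, codebook).argmin(dim=-1)
quantize = codebook[idx]

# detach here so the commit loss only trains the encoder side;
# otherwise the main optimizer would update the codebook a second time
commit_loss = F.mse_loss(quantize.detach(), x)
commit_loss.backward()
```

After the backward pass, only x carries gradient from the commit loss; the codebook's gradient stays None, confirming it is updated only by the in-place optimizer.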