kaiidams / soundstream-pytorch

Unofficial SoundStream implementation in PyTorch with training code and a 16 kHz pretrained checkpoint
MIT License

The implementation of the Residual Vector Quantizer algorithm does not correspond to the description in the original paper. #2

Closed arishov1 closed 8 months ago

arishov1 commented 8 months ago

Thanks for sharing your code. From my understanding, implementing the algorithm as described in the paper would require modifying your code as follows:

```python
r = input.type_as(self.running_mean).detach()
quantized = torch.zeros_like(input).type_as(self.running_mean)

with torch.no_grad():
    for i in range(n):
        w = self.weight[i]
        # dist: [..., num_embeddings]
        dist = torch.cdist(r, w)
        k = torch.argmin(dist, axis=-1)
        codes.append(k)
        self._update_averages(i, r, k)
        q = F.embedding(k, w)
        r = r - q
        quantized = quantized + q
# quantized = input - r  # no longer needed: quantized is accumulated in the loop
commitment_loss = torch.mean(torch.square(input - quantized.detach()))
self.weight.data[:] = self.running_mean / torch.unsqueeze(self.eps + self.code_count, axis=-1)
return quantized, torch.stack(codes, input.ndim - 1), commitment_loss
```
kaiidams commented 8 months ago

Thank you for your feedback. After the loop, quantized = q0 + q1 + ... + qn in your code, while in the original quantized = input - r = input - (input - q0 - q1 - ... - qn) = q0 + q1 + ... + qn. I think the results are the same, except that your version needs an extra buffer to keep the temporary data.
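The equivalence can be checked numerically. Below is a minimal standalone sketch (with hypothetical shapes and random codebooks, and without the EMA codebook updates from the repo) showing that accumulating the per-stage quantized vectors `q_i` gives the same result as computing `input - r` after the residual loop:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: n quantizer stages, each with its own codebook.
n, num_embeddings, dim = 4, 8, 16
weight = torch.randn(n, num_embeddings, dim)  # stand-in for self.weight
x = torch.randn(32, dim)                      # stand-in for the input

r = x.clone()                   # residual
quantized = torch.zeros_like(x) # accumulated quantized output
for i in range(n):
    # Nearest codebook entry for the current residual.
    k = torch.argmin(torch.cdist(r, weight[i]), dim=-1)
    q = F.embedding(k, weight[i])
    r = r - q
    quantized = quantized + q

# x - r == (q0 + q1 + ... + qn) up to floating-point rounding.
assert torch.allclose(quantized, x - r, atol=1e-5)
```

So both formulations produce the same quantized output; the difference is only whether the sum is carried explicitly or recovered from the final residual.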