lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

Why is Encodec only encoding 1 frame? #275

Closed: sivannavis closed this issue 5 months ago

sivannavis commented 5 months ago

Hi, thanks for the code! When I'm running naturalspeech2 with Encodec from this repo, I found that this code ensures the number of time frames equals 1, but I don't quite understand why that has to be the case. Could you help me understand this?

I checked in https://github.com/lucidrains/audiolm-pytorch/blob/42da76b644eb3e16559382333488fd0fdd719611/audiolm_pytorch/encodec.py#L117 that encoded_frames is always a list of length 1. This is because, when the audio enters Encodec's encode function, segment_length is always set to the length of the whole audio. I'm wondering if this is a problem with how EncodecWrapper is initialized? Thanks!

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frame is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames
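
To make the single-frame behavior concrete, here is a tiny sketch (toy numbers, not from the repo) of the fallback path above: when `self.segment_length` is `None`, both `segment_length` and `stride` become the full input length, so `range(0, length, stride)` yields only offset 0 and the loop body runs exactly once.

    # toy numbers for illustration: a 3-second clip at 24 kHz
    length = 24_000 * 3
    segment_length = stride = length          # the fallback taken when self.segment_length is None
    offsets = list(range(0, length, stride))
    print(offsets)                            # [0] -> the loop runs once, so encoded_frames has length 1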
sivannavis commented 5 months ago

Oh I see, the segment is different from the model's frame rate, so even though the length of the list is 1, the codes are already of shape (batch, quantizer, time frame).
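
For anyone else landing here, a minimal sketch of the shapes (assuming the stock 24 kHz EnCodec model from facebookresearch/encodec, which I believe is built without segmenting by default and produces about 75 code frames per second; exact numbers depend on the checkpoint and target bandwidth):

    import torch
    from encodec import EncodecModel

    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)            # 6 kbps should correspond to 8 codebooks
    wav = torch.randn(1, 1, 24_000 * 3)        # (batch, channels, samples): 3 seconds of dummy audio

    encoded_frames = model.encode(wav)         # a list with a single (codes, scale) tuple
    codes, scale = encoded_frames[0]
    print(len(encoded_frames))                 # 1 -> one segment covering the whole clip
    print(codes.shape)                         # expected roughly [1, 8, 225] -> (batch, quantizer, time frame)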