kyutai-labs / moshi


Question for streaming audio generation #118

Closed · sh-lee-prml closed this issue 1 month ago

sh-lee-prml commented 1 month ago

Topic

The PyTorch implementation

Question

Thanks for the nice work.

I'm trying to test streaming inference and parallel inference of Mimi.

The samples from parallel inference are good, but I found that the samples generated with streaming inference contain some noise when I feed a single frame at a time.

I followed the example from https://github.com/kyutai-labs/moshi/tree/main/moshi#api, as below:

    with torch.no_grad():
        codes = mimi.encode(audio.unsqueeze(1))  # [B, K = 8, T]
        out_wav_chunks = []
        # Decode the tokens one frame at a time.
        for idx in range(codes.shape[-1]):
            wav_chunk = mimi.decode(codes[:, :, idx:idx+1])
            out_wav_chunks.append(wav_chunk)

        out_wav = torch.cat(out_wav_chunks, dim=-1)
        resynthesis_audio = out_wav

Is it right that the input to the streaming decoder is a single-frame token, i.e. [B, 8, T=1]?

If not, I'd like to know the proper frame size T for streaming inference.
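
For reference, this is how I relate one token frame to waveform samples (assuming Mimi exposes sample_rate and frame_rate as in the README example):

    # Waveform samples covered by a single token frame.
    # Assumes mimi exposes sample_rate and frame_rate as in the README example.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)  # e.g. 24000 / 12.5 = 1920 samples (80 ms)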

Thanks again for the nice work.

Update

I have checked that streaming encoding has almost the same quality as parallel encoding.

I may have made a mistake about streaming decoding: I assumed Mimi also decodes the tokens into a waveform in a streaming fashion.

My last question is: does only Moshi generate the next token in a streaming fashion, while the resulting chunks of tokens are decoded to the waveform in parallel? Is that right?
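
A sketch of what I mean (token_stream is just a hypothetical source of [B, 8, T=1] frames from the LM):

    # Hypothesized flow: tokens arrive one frame at a time from the LM,
    # but the accumulated tokens are decoded to the waveform in one parallel call.
    streamed_codes = []
    for codes_frame in token_stream:  # hypothetical per-frame token source
        streamed_codes.append(codes_frame)
    all_codes = torch.cat(streamed_codes, dim=-1)  # [B, 8, T]
    out_wav = mimi.decode(all_codes)  # single parallel decode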

LaurentMazare commented 1 month ago

Could you try adding a "streaming context" for the decoding? E.g. something like:

    with mimi.streaming(batch_size=1):
        out_wav_chunks = []
        for idx in range(codes.shape[-1]):
            wav_chunk = mimi.decode(codes[:, :, idx:idx+1])
            out_wav_chunks.append(wav_chunk)
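
For context, a fuller sketch of the same idea, with model loading as in the README example (device, codebook count, and the random input are just placeholders):

    import torch
    from huggingface_hub import hf_hub_download
    from moshi.models import loaders

    # Load Mimi as in the README example.
    mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(mimi_weight, device="cpu")
    mimi.set_num_codebooks(8)

    wav = torch.randn(1, 1, 24000 * 10)  # [B, C=1, T], 10 s at 24 kHz

    with torch.no_grad():
        codes = mimi.encode(wav)  # [B, K=8, T]
        out_wav_chunks = []
        # The streaming context carries decoder state across calls, so each
        # decode only processes the one new frame instead of restarting.
        with mimi.streaming(batch_size=1):
            for idx in range(codes.shape[-1]):
                out_wav_chunks.append(mimi.decode(codes[:, :, idx:idx + 1]))
        out_wav = torch.cat(out_wav_chunks, dim=-1)

Without the streaming context, each decode call is treated as an independent sequence, which would explain the noise at frame boundaries.
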
sh-lee-prml commented 1 month ago

Thanks!

I missed adding that line. The samples are very good now :)