kyutai-labs / moshi

Apache License 2.0
6.77k stars 532 forks source link

The output audio usually contains meaningless content like "day, how are you doing?" or "Hello, how are you today?" during offline inference #149

Closed UltraEval closed 6 days ago

UltraEval commented 2 weeks ago

Due diligence

Topic

The PyTorch implementation

Question

Thanks for the nice work.

I wrote code that generates Moshi's spoken answer to an audio file, following the example at https://github.com/kyutai-labs/moshi/tree/main/moshi#api.

It looks like this:

import torch

from moshi.models import loaders, LMGen
from tqdm import tqdm
import torchaudio
import scipy.io.wavfile as wavfile
import sentencepiece

def load_wav(audio_file, target_sr):
    """Load a WAV file as a mono float32 waveform at ``target_sr``.

    Args:
        audio_file: Path to a WAV file readable by ``scipy.io.wavfile.read``.
        target_sr: Desired output sample rate in Hz.

    Returns:
        A float32 tensor of shape ``[1, T]``.

    ``wavfile.read`` returns raw integer PCM codes (e.g. int16 in
    [-32768, 32767]). Unlike ``torchaudio.load``, it does not normalize them,
    so integer samples are rescaled here to [-1, 1]. Float WAVs are passed
    through unchanged (assumed already in [-1, 1]). Multi-channel audio is
    averaged to mono. The signal is resampled up or down when its rate
    differs from ``target_sr``.
    """
    sample_rate, speech = wavfile.read(audio_file)
    dtype = speech.dtype
    speech = torch.from_numpy(speech).float()
    if dtype.kind in ('i', 'u'):
        # Full-scale magnitude for a `bits`-wide PCM sample is 2**(bits - 1).
        # scipy stores 24-bit PCM left-justified in int32, so this also holds.
        half_scale = float(2 ** (8 * dtype.itemsize - 1))
        if dtype.kind == 'u':
            # Unsigned PCM (8-bit WAV) is centred on half scale, not 0.
            speech = speech - half_scale
        speech = speech / half_scale
    if speech.dim() == 2:
        # scipy returns multi-channel data as [T, channels]; downmix to mono.
        speech = speech.mean(dim=1)
    speech = speech.unsqueeze(0)  # [1, T]
    if sample_rate != target_sr:
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech

class MoshiInference:
    """Offline (file in, file out) wrapper around Mimi + Moshi streaming inference.

    The whole input file is first encoded with Mimi into a list of codec
    frames. The frames are then fed to Moshi one at a time through ``LMGen``.
    The audio tokens Moshi returns are decoded back to a waveform, and its
    text tokens are converted to a transcript with the SentencePiece
    tokenizer.
    """

    def __init__(self, model_path, device='cuda', temp=0.8, temp_text=0.7):
        """Load Mimi, Moshi and the text tokenizer.

        Args:
            model_path: Directory containing the files named by
                ``loaders.MIMI_NAME``, ``loaders.MOSHI_NAME`` and
                ``loaders.TEXT_TOKENIZER_NAME``.
            device: Torch device the models are loaded on; encoder inputs
                are moved to it as well.
            temp: Sampling temperature for audio tokens (passed to ``LMGen``).
            temp_text: Sampling temperature for text tokens (passed to ``LMGen``).
        """
        self.device = device

        mimi_weight = '{}/{}'.format(model_path, loaders.MIMI_NAME)
        self.mimi = loaders.get_mimi(mimi_weight, device=device)
        # NOTE(review): 8 codebooks presumably matches what this Moshi
        # checkpoint consumes/produces — confirm against the loaders defaults.
        self.mimi.set_num_codebooks(8)

        moshi_weight = '{}/{}'.format(model_path, loaders.MOSHI_NAME)
        self.moshi = loaders.get_moshi_lm(moshi_weight, device=device)
        self.lm_gen = LMGen(self.moshi, temp=temp, temp_text=temp_text)  # this handles sampling params etc.

        text_tokenizer_path = '{}/{}'.format(model_path, loaders.TEXT_TOKENIZER_NAME)
        self.text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer_path)

    @staticmethod
    def _pad_to_frames(wav, frame_size, lead=0, trail=0):
        """Zero-pad the last axis of ``wav`` so it splits into whole frames.

        ``lead`` zero samples are prepended and ``trail`` appended. The end is
        then padded further so the length is a multiple of ``frame_size``,
        because the streaming encoder below consumes exactly one frame per
        call.
        """
        wav = torch.nn.functional.pad(wav, (lead, trail))
        remainder = wav.shape[-1] % frame_size
        if remainder:
            wav = torch.nn.functional.pad(wav, (0, frame_size - remainder))
        return wav

    def encode(self, audio_path, leading_silence=0.0, trailing_silence=0.0):
        """Encode an audio file into a list of Mimi code frames.

        Args:
            audio_path: WAV file to encode; it is resampled to Mimi's rate.
            leading_silence: Seconds of silence to prepend, e.g. to let the
                model finish its opening line before the user speaks.
            trailing_silence: Seconds of silence to append. ``infer`` steps
                the LM once per encoded frame and decodes at most one chunk
                per step, so the output can never be longer than the input.
                Without trailing silence the model has no time to answer
                after the user stops speaking.

        Returns:
            A list of code tensors, one per frame (each has last dimension 1).
        """
        sample_rate = int(self.mimi.sample_rate)
        frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)

        wav = load_wav(audio_path, sample_rate)  # [1, T]
        print('Input length:', wav.shape)
        wav = self._pad_to_frames(
            wav,
            frame_size,
            lead=int(round(leading_silence * sample_rate)),
            trail=int(round(trailing_silence * sample_rate)),
        )
        # Add a channel axis -> [B, C, T] = [1, 1, T]; move to the model device
        # once instead of per frame.
        wav = wav.unsqueeze(0).to(self.device)
        print('Input length:', wav.shape)

        all_codes = []
        with torch.no_grad(), self.mimi.streaming(batch_size=1):
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset:offset + frame_size]
                codes = self.mimi.encode(frame)
                assert codes.shape[-1] == 1, codes.shape
                all_codes.append(codes)
        return all_codes

    def infer(self, audio_path, output_path='output.wav', leading_silence=0.0, trailing_silence=0.0):
        """Run Moshi on ``audio_path`` and write its spoken reply to ``output_path``.

        Args:
            audio_path: Input WAV file.
            output_path: Where the generated audio is saved.
            leading_silence: Seconds of silence to prepend (see ``encode``).
            trailing_silence: Seconds of silence to append (see ``encode``).

        Returns:
            ``(out_wav, text)``: the generated waveform (``[channels, samples]``,
            on the model device) and the decoded transcript.

        Raises:
            RuntimeError: if the model produced no output frames at all.
        """
        all_codes = self.encode(
            audio_path,
            leading_silence=leading_silence,
            trailing_silence=trailing_silence,
        )
        wav_chunks = []
        text_chunks = []

        with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
            for code in tqdm(all_codes):
                tokens_out = self.lm_gen.step(code)
                if tokens_out is None:
                    # LMGen returned no output for this step (presumably its
                    # initial delay); nothing to decode yet.
                    continue
                # As indexed here: slot 0 is the text stream, slots 1: are the
                # audio codebooks decoded by Mimi.
                wav_chunks.append(self.mimi.decode(tokens_out[:, 1:]))

                text_token = tokens_out[0, 0, 0].item()
                # NOTE(review): ids 0 and 3 are presumably special/padding
                # tokens of the text vocabulary — confirm against the tokenizer.
                if text_token not in (0, 3):
                    piece = self.text_tokenizer.id_to_piece(text_token)
                    # SentencePiece marks word boundaries with U+2581.
                    text_chunks.append(piece.replace("▁", " "))

        if not wav_chunks:
            raise RuntimeError(
                'Moshi produced no output frames; the input is too short. '
                'Try passing trailing_silence.'
            )

        # Chunks are [B, C, t]; concatenate along time, then drop the batch
        # axis to get the [C, T] layout torchaudio.save expects.
        out_wav = torch.cat(wav_chunks, dim=-1).squeeze(0)
        text = ''.join(text_chunks)
        print('Output length:', out_wav.shape)
        print('Output text:', text)
        torchaudio.save(output_path, out_wav.cpu(), int(self.mimi.sample_rate))
        return out_wav, text

if __name__ == '__main__':
    # Script entry point: load the bf16 Moshiko checkpoint from a fixed
    # directory and answer the fixed input file test.wav.
    engine = MoshiInference('model/kyutai_moshiko-pytorch-bf16/')
    engine.infer('test.wav')

Sometimes the output is normal:

Output length: torch.Size([1, 211200])
Output text:  Hello, how are you today? Oh, that's a tough one. There are so many to choose from. What do you think?

But far too often, it just outputs:

111it [00:05, 21.25it/s]
Output length: torch.Size([1, 211200])
Output text:  Hey what's up?

Output length: torch.Size([1, 211200])
Output text:  Hello, how are you today?

Output text:  day, how are you doing?

All of the outputs above come from the same input file:

test.wav.zip

It's kind of weird; can you check it?

LaurentMazare commented 2 weeks ago

It's certainly hard to tell what is going on. One thing that might be worth checking is that your audio starts immediately, while the model always starts the conversation with an incipit like "Hello, how are you doing?". It may be better to delay your audio until the model has finished its first sentence (simply hardcoding a delay of about 5 s is probably good enough).

UltraEval commented 6 days ago

I tried two methods to process the audio file:

from pydub import AudioSegment
from pydub.playback import play
# Prepend 5000 ms of silence to test.wav so the user's speech begins only
# after the model's opening line; write the result to padded_audio.wav.
# NOTE(review): the offline loop emits at most one output frame per input
# frame, so appending silence after the question may also be needed for
# the model to have time to answer.
lead_in_ms = 5000
padded_audio = AudioSegment.silent(duration=lead_in_ms) + AudioSegment.from_wav("test.wav")
padded_audio.export("padded_audio.wav", format="wav")

However, the output remains the same as before. So be it.