kyutai-labs / moshi

Apache License 2.0
6.77k stars 532 forks source link

runing pytorch implementation always got similar output eg.. what's' going on #155

Open hjc3613 opened 21 hours ago

hjc3613 commented 21 hours ago

Due diligence

Topic

The PyTorch implementation

Question

moshi allways output unexpected answers eg.. Hello, what's going on? here is my scripts:

from huggingface_hub import hf_hub_download
import torch
import os
import librosa
import numpy as np
from tqdm import tqdm
from moshi.moshi.models import loaders, LMGen
import soundfile as sf
import numpy as np
from subprocess import call
import sphn
import sentencepiece
device = torch.device('cpu')
MODEL_PATH = '/opt/ailab_mnt1/LLM_MODELS/moshi/moshika-pytorch-bf16'
mimi_weight = os.path.join(MODEL_PATH, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.
text_tokenizer = sentencepiece.SentencePieceProcessor(os.path.join(MODEL_PATH, loaders.TEXT_TOKENIZER_NAME))

moshi_weight = os.path.join(MODEL_PATH, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device=device)
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.
# 
def save_as_wav(y, sr, output_path):
    sf.write(output_path, y, sr)

def one_round_test(audio_path):

    wav, sample_sr = sphn.read(audio_path)
    sample_rate = mimi.sample_rate
    wav = sphn.resample(
            wav, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
        )

    wav  = torch.from_numpy(wav[None, :])
    mimi.to(device)
    # wave, sample_rate = torch.randn(1, 1, 24000 * 10), mimi.sample_rate
    with torch.no_grad():
        codes = mimi.encode(wav.to(device))  # [B, K = 8, T]
        # decoded = mimi.decode(codes)
        # save_as_mp3(decoded.numpy().squeeze(), sample_rate, audio_path.replace('.mp3', '_decoded.wav'))
        # Supports streaming too.
        frame_size = int(mimi.sample_rate / mimi.frame_rate)
        all_codes = []
        # with mimi.streaming(batch_size=1):
        for offset in tqdm(range(0, wav.shape[-1], frame_size), desc='mimi encoding...'):
                frame = wav[:, :, offset: offset + frame_size]
                if frame.shape[-1] < frame_size:
                    continue
                codes = mimi.encode(frame.to(device))
                assert codes.shape[-1] == 1, codes.shape
                all_codes.append(codes)

    # mimi.cuda()
    # moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)

    out_wav_chunks = []
    main_text = []
    # Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
    with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    # with torch.no_grad():
        for idx, code in tqdm(enumerate(all_codes), desc='lm_gen steping...', total=len(all_codes)):
            tokens_out = lm_gen.step(code.to(device))
            # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
            if tokens_out is not None:
                wav_chunk = mimi.decode(tokens_out[:, 1:])
                out_wav_chunks.append(wav_chunk)
                text_token = tokens_out[:, 0, 0][0].item()
                if text_token not in (0, 3):
                    _text = text_tokenizer.id_to_piece(text_token)
                    _text = _text.replace("▁", " ")
                    main_text.append(_text)
            # print(idx, end='\r')
    out_wav = torch.cat(out_wav_chunks, dim=-1)
    save_as_wav(out_wav.squeeze().cpu().numpy(), sample_rate, audio_path.replace('.wav', '_answer.wav'))
    print('generated_text:')
    print(''.join(main_text))

if __name__ == '__main__':
    wave_root = 'wave_data/tts_res'
    wave_files = os.listdir(wave_root)
    for file in wave_files:
        audio_path = os.path.join(wave_root, file)
        one_round_test(audio_path)

my audio files are generated using tts with the following questions:

questions = [
    'At what temperature does water boil?',
    'What is the largest organ in the human body?',
    'Which is the largest planet in the solar system?',
    'What is the approximate speed of light?',
    'Who discovered the double helix structure of DNA?',
    'What is the deepest ocean trench on Earth?',
    'What is the normal human body temperature?',
    'What is the first element in the periodic table?',
    'n which year did humans first land on the moon?',
    'What is the approximate total length of the Great Wall?',
    'What is the highest mountain on Earth?',
    'How many years ago did dinosaurs go extinct?',
    'What is the smallest bone in the human body?',
    'What is the longest river in the world?',
    'How many times does the human heart beat per minute on average?',
]

when I using the first py file run moshi, I got this result: image

note: I run the python script in cpu mode,but gpu mode also tested in online mode and got the very similar wav output