kyutai-labs / moshi


Value Differences between Streaming Mimi and Non-streaming Mimi #145

Closed shinshoji01 closed 2 weeks ago

shinshoji01 commented 3 weeks ago

Due diligence

Topic

The PyTorch implementation

Question

Hi, I'm using Mimi's semantic tokens in my research and want to extract them in advance. Since streaming extraction is time-consuming, I'd like to do it in parallel (non-streaming). However, the tokens I get from streaming and non-streaming encoding differ slightly. Is there any way to make them match? In other words, can I obtain the same values as the streaming path while encoding in parallel?

How to reproduce this issue:

from huggingface_hub import hf_hub_download
import torch
import librosa
import numpy as np
from moshi.models import loaders

device = "cuda"
sr = 24000
audiopath = "sample.wav"

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.

audio, _ = librosa.load(path=audiopath, sr=sr, mono=True)
wav = torch.tensor(audio).reshape(1, 1, -1)
wav = wav.to(device)

# Zero-pad to a whole number of frames (no-op if already aligned).
frame_size = int(mimi.sample_rate / mimi.frame_rate)
padnum = (frame_size - wav.shape[-1] % frame_size) % frame_size
if padnum > 0:
    wav = torch.nn.functional.pad(wav, (0, padnum), "constant", 0)
with torch.no_grad():
    ### Streaming ###
    streaming_codes = []
    with mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            framecodes = mimi.encode(frame)
            assert framecodes.shape[-1] == 1, framecodes.shape
            streaming_codes.append(framecodes.cpu().numpy())
    streaming_codes = np.concatenate(streaming_codes, axis=2)[0]

    ### Non-streaming ###
    codes = mimi.encode(wav).cpu().numpy()[0]

print((streaming_codes == codes).mean())  # 0.976986... instead of 1.00

bl = streaming_codes[0] != codes[0]
print(np.arange(len(bl))[bl])  # indices where the first (semantic) codebook differs
# [ 151  159  181  226  248  471  597  610  650  682  768  802  863  900
#   906  908  930 1072 1121 1150 1179 1185 1192 1505 1519 1561 1620 1742
#  1934 1967 2012 2245 2379 2408 2430 2559 2584 2704 2708 2712 2754 2772
#  2927 3069 3103 3186 3353 3370 3376 3393 3584 3588 3722 3728 3954 3970
#  4048 4154 4165 4240 4242 4373 4556 4670 4737 4810 4867 4899 4940 4954
#  5029 5043 5098 5200 5277 5288 5380 5382 5561 5721 5788 5888 6094 6196
#  6237 6341 6489 6620 6763 6790 6793 6890 7099 7188 7227 7298 7505 7654
#  7692 7708 7804 7853 7942 8008 8009 8039 8050 8101 8218 8225 8507 8663
#  8672 8798]
LaurentMazare commented 3 weeks ago

The streaming codes should be roughly in line with the non-streaming ones; the differences you're seeing are most likely due to numerical instability around the 1d convolution implementations. I just gave this a try on a sample file and got no differences on the semantic tokens, only some differences at the higher codebook levels. I don't know of an easy way to get exactly the same values.
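
For reference, a minimal per-level comparison along those lines, reusing the variables from the reproduction script above (`streaming_codes` and `codes` are both shaped `[num_codebooks, num_frames]`), could look like this:

# A minimal sketch, assuming the variables from the reproduction script above.
# Level 0 is the semantic codebook; levels 1+ are the acoustic codebooks.
for level in range(streaming_codes.shape[0]):
    agreement = (streaming_codes[level] == codes[level]).mean()
    print(f"codebook {level}: {agreement:.4f} agreement")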

shinshoji01 commented 2 weeks ago

OK, thank you very much. I will use the non-streaming version.
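
For readers landing here with the same goal, a hedged sketch of that non-streaming route (the `dataset/*.wav` layout and the `_codes.npy` cache path are hypothetical; `mimi`, `sr`, and `device` come from the script above) might look like this:

import glob

# Hypothetical offline pass: encode each file in one non-streaming call
# and cache the resulting codes to disk for later use.
for path in glob.glob("dataset/*.wav"):  # hypothetical file layout
    audio, _ = librosa.load(path=path, sr=sr, mono=True)
    wav = torch.tensor(audio).reshape(1, 1, -1).to(device)
    with torch.no_grad():
        codes = mimi.encode(wav).cpu().numpy()[0]  # [num_codebooks, num_frames]
    np.save(path.replace(".wav", "_codes.npy"), codes)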