kyutai-labs / moshi

Difference in encoded and decoded wav by mimi when using real audio instead of torch.randn #69

Closed karynaur closed 7 hours ago

karynaur commented 12 hours ago

Backend impacted

The PyTorch implementation

Operating system

Linux

Hardware

CPU

Description

from huggingface_hub import hf_hub_download
import torch

from moshi.models import loaders, LMGen

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(32)  # up to 32 for mimi, but limited to 8 for moshi.

wav = torch.randn(1, 1, mimi.sample_rate * 10)  # should be [B, C=1, T]

with torch.no_grad():
    codes = mimi.encode(wav)  # [B, K = 32, T], since 32 codebooks are enabled above
    decoded = mimi.decode(codes)

The output for this is: decoded.shape, wav.shape = (torch.Size([1, 1, 240000]), torch.Size([1, 1, 240000])). This works perfectly!

from huggingface_hub import hf_hub_download
import torch

from moshi.models import loaders, LMGen

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(32)  # up to 32 for mimi, but limited to 8 for moshi.

# wav = torch.randn(1, 1, mimi.sample_rate * 10)  # should be [B, C=1, T]
import torchaudio
arr, org_sr = torchaudio.load('/hdd5/intern_aditya/data/vox/vox1/speaker/id10001/1zcIwhmdeo4_00001.wav')
wav = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=mimi.sample_rate).reshape(1, 1, -1)

# take only the first 10 seconds
wav = wav[:, :, :mimi.sample_rate * 10]

# normalize it to mean 0 and std 1
wav = (wav - wav.mean()) / wav.std()

with torch.no_grad():
    codes = mimi.encode(wav)  # [B, K = 32, T], since 32 codebooks are enabled above
    decoded = mimi.decode(codes)

Output: decoded.shape, wav.shape = (torch.Size([1, 1, 195840]), torch.Size([1, 1, 194882]))

Why do the sizes differ, and what is the extra information in decoded?

Extra information

Installed with pip install moshi

adefossez commented 8 hours ago

Yes, this is expected: you need to pad the last frame so that the total length is a multiple of 1920 (the frame size of Mimi). We cannot do it automatically, as we can never know whether something is the last frame or not!
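
A minimal sketch of the padding described above, assuming zero-padding on the right is acceptable for the clip. The helper name pad_to_frame_multiple is hypothetical and not part of moshi, and the frame size of 1920 is taken from the reply rather than read from the model. With the 194882-sample clip from the report, the padded length comes out to 195840, which matches the decoded length.

```python
import torch
import torch.nn.functional as F

FRAME_SIZE = 1920  # Mimi frame size in samples, as stated in the reply above


def pad_to_frame_multiple(wav: torch.Tensor, frame_size: int = FRAME_SIZE) -> torch.Tensor:
    """Zero-pad the time axis of a [B, C, T] tensor up to a multiple of frame_size.

    Hypothetical helper for illustration; not part of the moshi package.
    """
    remainder = wav.shape[-1] % frame_size
    if remainder == 0:
        return wav  # already frame-aligned, nothing to do
    # F.pad takes (left, right) amounts for the last dimension; default mode pads with zeros.
    return F.pad(wav, (0, frame_size - remainder))


# Stand-in for the real clip: same length as wav in the report.
wav = torch.randn(1, 1, 194882)
padded = pad_to_frame_multiple(wav)
print(padded.shape)  # torch.Size([1, 1, 195840])
```

Under these assumptions, one would pass the padded tensor to mimi.encode and, if the original duration matters, trim the decoded output back with decoded[..., :wav.shape[-1]]. The extra samples in decoded correspond to the padded tail of the final frame.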

adefossez commented 7 hours ago

I have clarified this point in the code snippet: https://github.com/kyutai-labs/moshi/tree/main/moshi#api