facebookresearch / encodec

State-of-the-art deep learning based audio codec supporting both mono 24 kHz audio and stereo 48 kHz audio.
MIT License

Channel-mismatch `RuntimeError` when extracting embedding with 24kHz model #24

Closed rakuzen25 closed 1 year ago

rakuzen25 commented 1 year ago

🐛 Bug Report

Following the "Extracting discrete representations" section in README, I tried to extract the encoded embedding myself. However, running the exact code snippet gave me an error: RuntimeError: Given groups=1, weight of size [32, 1, 7], expected input[1, 2, 144006] to have 1 channels, but got 2 channels instead.

To Reproduce

from encodec import EncodecModel
from encodec.utils import convert_audio

import torchaudio
import torch

# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)

# Load and pre-process the audio waveform
wav, sr = torchaudio.load("test.wav")
wav = wav.unsqueeze(0)
wav = convert_audio(wav, sr, model.sample_rate, model.channels)

# Extract discrete codes from EnCodec
with torch.no_grad():
    encoded_frames = model.encode(wav)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]

where test.wav is any WAV file. I tried with one on the sample page.

Expected behavior

I should be able to get the representation with shape [B, n_q, T], as described in the comment in the code itself.

Actual Behavior

Full traceback:

Traceback (most recent call last):
  File "/home/ubuntu/test.py", line 18, in <module>
    encoded_frames = model.encode(wav)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/encodec/model.py", line 144, in encode
    encoded_frames.append(self._encode_frame(frame))
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/encodec/model.py", line 161, in _encode_frame
    emb = self.encoder(x)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/encodec/modules/seanet.py", line 144, in forward
    return self.model(x)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/encodec/modules/conv.py", line 210, in forward
    return self.conv(x)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/encodec/modules/conv.py", line 120, in forward
    x = self.conv(x)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1208, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 313, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ubuntu/miniconda3/envs/dev/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 309, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 1, 7], expected input[1, 2, 144006] to have 1 channels, but got 2 channels instead

Your Environment

rakuzen25 commented 1 year ago

Forgot to add in the original comment - 48kHz model works fine (i.e. if I do model = EncodecModel.encodec_model_48khz()).

rakuzen25 commented 1 year ago

I think I've found the problem. wav = wav.unsqueeze(0) should happen after the convert_audio step, because convert_audio assumes the first dimension of wav (wav.shape[0]) is the channel dimension, when in fact after unsqueezing it is the batch dimension.

https://github.com/facebookresearch/encodec/blob/194329839fd812433992272fc5e7a889176e6fd1/encodec/utils.py#L79-L82
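To make the ordering problem concrete, here is a minimal sketch that does not call EnCodec at all. It assumes, as the comment above suggests, that the mono downmix averages over the first axis; `downmix_first_axis` is a hypothetical stand-in for that step, not the library's actual code, and the random tensor stands in for a stereo file loaded with torchaudio.

```python
import torch


def downmix_first_axis(wav: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a mono downmix that treats axis 0 as channels."""
    return wav.mean(0, keepdim=True)


# Fake stereo clip shaped [channels, time], like the output of torchaudio.load
stereo = torch.randn(2, 16000)

# Order from the failing snippet: add the batch axis first.
# Axis 0 is now the batch axis (size 1), so the "downmix" averages over it
# and both channels survive.
batched_first = downmix_first_axis(stereo.unsqueeze(0))

# Swapped order: downmix while axis 0 is still the channel axis,
# then add the batch axis.
converted_first = downmix_first_axis(stereo).unsqueeze(0)

print(tuple(batched_first.shape))    # still has 2 channels
print(tuple(converted_first.shape))  # has 1 channel
```

With the first order the tensor reaches the encoder with two channels, which matches the shape `[1, 2, 144006]` reported in the traceback for the mono 24 kHz model.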

Should I make a PR for this?

adefossez commented 1 year ago

Exactly, sorry about that. I fixed the convert_audio function to be more robust! I also updated the README to swap the order of the two lines, for people who would be using the older version. Thanks for reporting!
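The thread does not show the patch itself. As an illustration only, one way a downmix can be made tolerant of leading batch dimensions is to index the channel axis from the end rather than from the start. The function name `downmix_to_mono` and its checks are assumptions made for this sketch, not the actual change to convert_audio.

```python
import torch


def downmix_to_mono(wav: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: average channels on the second-to-last axis.

    Accepts [C, T] as well as [..., C, T], so a leading batch axis is harmless.
    """
    if wav.dim() < 2:
        raise ValueError("Audio tensor must have at least 2 dimensions")
    if wav.shape[-2] not in (1, 2):
        raise ValueError("Audio must be mono or stereo")
    return wav.mean(-2, keepdim=True)


unbatched = downmix_to_mono(torch.randn(2, 16000))   # input [C, T]
batched = downmix_to_mono(torch.randn(1, 2, 16000))  # input [B, C, T]

print(tuple(unbatched.shape))
print(tuple(batched.shape))
```

Counting axes from the end means the result has a single channel whether or not unsqueeze(0) was applied beforehand.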

rakuzen25 commented 1 year ago

No worries, thanks for the fix!