facebookresearch / encodec

State-of-the-art deep learning based audio codec supporting both mono 24 kHz audio and stereo 48 kHz audio.
MIT License
3.52k stars 304 forks source link

Encoding Long Audio Clips #71

Open aviaefrat opened 1 year ago

aviaefrat commented 1 year ago

I need the EnCodec tokens of long audio clips (hours long). Inputting such files as-is results in a CUDA OOM. I've seen that you "do not try to be smart about long files". Does chunking the long audio files naively (and concatenating the EnCodec tokens post hoc) produce results identical to inputting an entire file to the model? If not, how should I chunk my audio files?

julien-blanchon commented 6 months ago

Hey @aviaefrat, did you find an elegant solution to this?

foreverhell commented 4 months ago

I have tried to split the long audio files naively and concatenate the EnCodec tokens, but the produced results are not consistent with encoding the whole file, except for the first clip. I do not know how to keep them the same.

MiniXC commented 1 day ago

Here is an (admittedly not too clean) snippet that I successfully used for encoding/decoding long files.

from pathlib import Path
from argparse import ArgumentParser
import io
import gzip
import tarfile
import shutil

from encodec import EncodecModel
from encodec.compress import compress, decompress
from encodec.utils import convert_audio
import torchaudio
import torch
import numpy as np
from tqdm.auto import tqdm

# Shared 24 kHz EnCodec model used by compress_encodec and decompress_encodec.
# The script entry point sets its bandwidth and moves it to the chosen device.
model = EncodecModel.encodec_model_24khz()

# Length, in seconds, of each chunk that compress_encodec encodes on its own.
SECOND_CHUNKS = 1

# Device the helpers run on. Defaults to CPU (presumably where the freshly
# built model lives) so the functions also work when this module is imported
# rather than run as a script; the entry point overrides it from --device.
device = torch.device("cpu")

def compress_encodec(src_path, tgt_path):
    """Compress an audio file with EnCodec, chunking long inputs.

    The audio is converted to the model's sample rate and channel count.
    Any suffix on ``tgt_path`` is dropped (everything after the first "."
    in the file name). If the audio is longer than ``SECOND_CHUNKS``
    seconds, each chunk is encoded independently and stored as a member
    named ``<name>.<start_sample>.ecdc`` inside ``<tgt_path>.tar.gz``.
    Otherwise the single encoded stream is written gzip-compressed to the
    suffix-less ``tgt_path``.

    Args:
        src_path: Path of the audio file to read.
        tgt_path: Destination path; its suffix is ignored.
    """
    wav, sr = torchaudio.load(src_path)
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)

    # Feed the model tensors on whatever device it actually lives on; this
    # respects an explicit device index and does not depend on the
    # module-level ``device`` having been set.
    run_device = next(model.parameters()).device

    # Change the target path to not have a suffix.
    tgt_path = Path(tgt_path)
    if "." in tgt_path.name:
        tgt_path = tgt_path.parent / tgt_path.name.split(".")[0]

    chunk_len = model.sample_rate * SECOND_CHUNKS
    if wav.shape[1] > chunk_len:
        # Stream every encoded chunk straight into the archive from memory,
        # so no shared temporary files are created or left behind on error.
        with tarfile.open(tgt_path.with_suffix(".tar.gz"), "w:gz") as tar:
            for start in tqdm(range(0, wav.shape[1], chunk_len)):
                wav_part = wav[..., start:start + chunk_len].to(run_device)
                payload = compress(model, wav_part)
                member = tarfile.TarInfo(name=f"{tgt_path.name}.{start}.ecdc")
                member.size = len(payload)
                tar.addfile(member, io.BytesIO(payload))
    else:
        payload = compress(model, wav.to(run_device))
        with gzip.open(tgt_path, "wb") as f:
            f.write(payload)

def decompress_encodec(codes_path, tgt_path):
    """Decode a file produced by ``compress_encodec`` back into audio.

    A ``.tar.gz`` input is treated as an archive of independently encoded
    chunks whose member names end in ``.<start_sample>.ecdc``; the chunks
    are decoded in ascending start-sample order and concatenated along the
    time axis. Any other input is treated as a single encoded stream,
    gzip-wrapped (as ``compress_encodec`` writes it) or raw.

    Args:
        codes_path: Path of the ``.tar.gz`` archive or single encoded file.
        tgt_path: Path of the audio file to write.
    """
    codes_path = Path(codes_path)
    # Fall back to CPU when the module-level device has not been configured.
    run_device = device if device is not None else "cpu"

    if codes_path.name.endswith(".tar.gz"):
        wav_parts = []
        sample_rate = model.sample_rate
        with tarfile.open(codes_path, "r:gz") as tar:
            # Read members in memory instead of extracting them: nothing is
            # written to disk, so a crafted archive cannot escape a target
            # directory and no extracted files are left behind.
            members = [m for m in tar.getmembers() if m.isfile()]
            # The second-to-last dotted field is the chunk's start sample.
            members.sort(key=lambda m: int(Path(m.name).name.split(".")[-2]))
            for member in tqdm(members):
                codes = tar.extractfile(member).read()
                # decompress returns a (waveform, sample_rate) pair.
                wav_part, sample_rate = decompress(codes, run_device)
                wav_parts.append(wav_part)
        wav = torch.cat(wav_parts, dim=1)
    else:
        raw = codes_path.read_bytes()
        # compress_encodec gzip-wraps the single-file output; sniff the gzip
        # magic bytes so raw (unwrapped) streams still work too.
        if raw[:2] == b"\x1f\x8b":
            raw = gzip.decompress(raw)
        wav, sample_rate = decompress(raw, run_device)

    # Move the result to host memory before handing it to torchaudio.
    torchaudio.save(tgt_path, wav.cpu(), sample_rate)

if __name__ == "__main__":
    # Command-line entry point: compress src_path into tgt_path, or, with
    # --decompress, decode src_path back into an audio file at tgt_path.
    parser = ArgumentParser()
    parser.add_argument("src_path", type=Path)
    parser.add_argument("tgt_path", type=Path)
    parser.add_argument("--bandwidth", type=float, default=6.0)
    parser.add_argument("--decompress", action="store_true")
    # Prefer CUDA, but fall back to CPU so the default also works on
    # machines without a GPU.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_args()
    model.set_target_bandwidth(args.bandwidth)
    # Rebind the module-level globals read by the two helper functions.
    device = torch.device(args.device)
    model = model.to(device)
    if args.decompress:
        decompress_encodec(args.src_path, args.tgt_path)
    else:
        compress_encodec(args.src_path, args.tgt_path)