Stable Audio ControlNet

Fine-tune Stable Audio Open with a DiT ControlNet. On a 16GB VRAM GPU you can use an adapter that is 20% of the size of the full DiT with bs=1 and mixed fp16 precision (50% with a 24GB VRAM GPU). Work in progress, code is provided as-is!


Examples (MoisesDB model): https://drive.google.com/drive/folders/1C7Q1hvUXH-0eAC-XaH6KtmHBpQYzo36a?usp=drive_link

Quick start

To initialize a ControlNet based on the stable-audio-open checkpoint, retaining a fraction depth_factor of the DiT layers (e.g., depth_factor = 0.2 retains 20% of the 24 layers in the DiT, i.e., around 5 layers), conditioned on audio, call:

from main.controlnet.pretrained import get_pretrained_controlnet_model

model, model_config = get_pretrained_controlnet_model("stabilityai/stable-audio-open-1.0", controlnet_types=["audio"], depth_factor=0.2)
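
To sanity-check the adapter size (which drives the VRAM figures above), you can count its parameters; this is a minimal sketch assuming the attribute paths used in the training snippets below (model.model.controlnet for the adapter, model.model.model for the frozen DiT):

# Sketch: compare the adapter size to the full DiT (attribute paths as in the training snippets below).
n_adapter = sum(p.numel() for p in model.model.controlnet.parameters())
n_dit = sum(p.numel() for p in model.model.model.parameters())
print(f"ControlNet adapter: {n_adapter / 1e6:.1f}M params ({100 * n_adapter / n_dit:.0f}% of the DiT)")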

For training, first disable gradients on the frozen components:

model.model.model.requires_grad_(False)
model.conditioner.requires_grad_(False)
model.conditioner.eval()
model.pretransform.requires_grad_(False)
model.pretransform.eval()

Then pass only the ControlNet adapter parameters to the optimizer:

import torch

# only the ControlNet adapter is trained
params = list(model.model.controlnet.parameters())
optimizer = torch.optim.AdamW(params, lr=lr)

During training, pass the conditioning audio y in the conditioning dictionary:

x, y, prompts, start_seconds, total_seconds = batch
...
# obtain noised_inputs from x at time t
...
# we pass the ControlNet conditioning with id "audio" in the conditioning dictionary
conditioning = [{"audio": y[i:i+1],
                 "prompt": prompts[i],
                 "seconds_start": start_seconds[i],
                 "seconds_total": total_seconds[i]} for i in range(y.shape[0])]
output = model(x=noised_inputs,
               t=t.to(device),
               cond=model.conditioner(conditioning),
               cfg_dropout_prob=cfg_dropout_prob,
               device=device)
# compute diffusion loss with output
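
The two elided steps above (noising x and computing the loss) are not part of this repository's snippet; as a rough sketch, assuming a v-objective like the one stable-audio-tools uses for Stable Audio Open, and assuming x is raw audio that still needs to be encoded by the pretransform, they could look like this:

import torch
import torch.nn.functional as F

# Sketch only: v-objective noising and loss; the exact schedule/loss in this repo may differ.
latents = model.pretransform.encode(x)               # skip if the dataloader already yields latents
noise = torch.randn_like(latents)
t = torch.rand(latents.shape[0], device=device)      # diffusion times in [0, 1]
alphas = torch.cos(t * torch.pi / 2)[:, None, None]
sigmas = torch.sin(t * torch.pi / 2)[:, None, None]
noised_inputs = latents * alphas + noise * sigmas    # fed to model(x=...) above
targets = noise * alphas - latents * sigmas          # v-prediction target

# ... forward pass as above, producing `output` ...

loss = F.mse_loss(output, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()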

For inference you can call, for example:

from stable_audio_tools.inference.generation import generate_diffusion_cond

# define conditioning dictionary
assert batch_size == len(conditioning)

output = generate_diffusion_cond(
    model,
    steps=steps,
    batch_size=batch_size,
    cfg_scale=7.0,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device="cuda"
)
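
The conditioning list is built the same way as during training. Below is a hypothetical single-item conditioning for the elided "define conditioning dictionary" step (y_cond, the prompt text, and the duration are placeholders), followed by saving the result, assuming model_config carries a "sample_rate" entry as in stable-audio-tools configs and that generate_diffusion_cond returns decoded audio (its default behavior):

import torch
import torchaudio

# Hypothetical conditioning for a single example; y_cond is the control audio tensor.
conditioning = [{"audio": y_cond,                 # e.g. shape (1, channels, samples)
                 "prompt": "rock accompaniment with drums and bass",
                 "seconds_start": 0,
                 "seconds_total": 47}]
batch_size = len(conditioning)

# After generate_diffusion_cond, save the generated audio.
audio = output.to(torch.float32).clamp(-1, 1).cpu()
torchaudio.save("generation.wav", audio[0], model_config["sample_rate"])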

Implementation Details

The ControlNet architecture is implemented by defining two classes (in diffusion.py):

TODO

Demo

In the following we detail training a model for music source accompaniment generation on MusDB (audio ControlNet conditioning). Other examples with envelope (filtered RMS envelope) and chroma (chromagram mask for pitch control) controls are available as well. The audio and envelope controls are tested and work well; the chroma control still needs experimental confirmation.
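
For intuition, a filtered RMS envelope control can be computed along these lines (an illustrative sketch only; frame sizes and smoothing are not the exact values used in this repository):

import torch
import torch.nn.functional as F

def rms_envelope(audio: torch.Tensor, frame: int = 1024, hop: int = 512) -> torch.Tensor:
    # audio: (channels, samples) -> smoothed per-frame RMS of the mono mix
    mono = audio.mean(dim=0)
    frames = mono.unfold(0, frame, hop)            # (num_frames, frame)
    rms = frames.pow(2).mean(dim=-1).sqrt()
    kernel = torch.ones(1, 1, 9) / 9               # simple moving-average "filter"
    return F.conv1d(rms[None, None], kernel, padding=4)[0, 0]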


Setup

First install the requirements. torchaudio has to be installed as a nightly build; you can do it with pip3 install --pre torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118. Afterwards, copy .env.tmp to .env and replace the variables with your own (the example values below are random):

DIR_LOGS=/logs
DIR_DATA=/data

# Required if using wandb logger
WANDB_PROJECT=audioproject
WANDB_ENTITY=johndoe
WANDB_API_KEY=a21dzbqlybbzccqla4txa21dzbqlybbzccqla4tx

Afterwards, log in to Hugging Face with huggingface-cli login using a personal access token in order to be able to download the Stable Audio Open weights.

For the demo, since we are using the MP3 version of MusDB, it is also necessary to have libsndfile installed. You can install it with conda install conda-forge::libsndfile.

Dataset (MusDB18HQ)

First download the sharded version of MusDB18HQ from https://drive.google.com/drive/folders/1bwiJbRH_0BsxGFkH0No-Rg_RHkVR2gc7?usp=sharing and put the files test.tar and train.tar inside data/musdb18hq/.
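
A quick way to check that the shards are in place (a hypothetical helper, not part of the repository):

import os
import tarfile

# Sanity-check the MusDB18HQ shards before training.
for split in ("train", "test"):
    path = f"data/musdb18hq/{split}.tar"
    assert os.path.exists(path), f"missing {path}"
    with tarfile.open(path) as tar:
        print(split, len(tar.getnames()), "entries")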

Train

For training, run:

PYTHONUNBUFFERED=1 TAG=musdb-controlnet-audio python train.py exp=train_musdb_controlnet_audio \
datamodule.train_dataset.path=data/musdb18hq/train.tar \
datamodule.val_dataset.path=data/musdb18hq/test.tar

For resuming training from checkpoint.ckpt stored in ckpts, run:

PYTHONUNBUFFERED=1 TAG=musdb-controlnet-audio python train.py exp=train_musdb_controlnet_audio \
datamodule.train_dataset.path=data/musdb18hq/train.tar \
datamodule.val_dataset.path=data/musdb18hq/test.tar \
+ckpt=ckpts/checkpoint.ckpt

Inference

Pre-trained checkpoints

Checkpoints for the audio-conditioned ControlNet:

Inference Notebook

For performing inference with the model trained on MusDB (MoisesDB), you can run notebook/inference_musdb_audio_large.ipynb (notebook/inference_moisesdb_audio_large.ipynb). The notebook expects the checkpoints to be in the folder ckpts/musdb-audio (ckpts/moisesdb-audio). Inference can be performed on a 16GB VRAM GPU.
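
If you want to load a checkpoint outside the notebooks, the following is a rough sketch; the checkpoint path, the depth_factor, and the state-dict layout are assumptions (the notebooks show the exact procedure):

import torch
from main.controlnet.pretrained import get_pretrained_controlnet_model

model, model_config = get_pretrained_controlnet_model(
    "stabilityai/stable-audio-open-1.0", controlnet_types=["audio"], depth_factor=0.2)

# Hypothetical: Lightning checkpoints usually nest weights under "state_dict",
# possibly with a wrapper prefix, hence strict=False here.
ckpt = torch.load("ckpts/musdb-audio/checkpoint.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
model.load_state_dict(state_dict, strict=False)
model = model.to("cuda").eval()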

Credits