Fine-tune Stable Audio Open with a DiT ControlNet. On a 16GB VRAM GPU you can train an adapter that is 20% of the size of the full DiT with batch size 1 and mixed fp16 precision (50% of the size with a 24GB VRAM GPU). Work in progress; the code is provided as-is!
Examples (MoisesDB model): https://drive.google.com/drive/folders/1C7Q1hvUXH-0eAC-XaH6KtmHBpQYzo36a?usp=drive_link
To initialize a ControlNet based on the stable-audio-open checkpoint, retaining a `depth_factor` fraction of the layers (e.g., `depth_factor = 0.2` keeps 20% of the 24 DiT layers, about 5 layers), conditioned on audio, call:
```python
from main.controlnet.pretrained import get_pretrained_controlnet_model

model, model_config = get_pretrained_controlnet_model("stabilityai/stable-audio-open-1.0", controlnet_types=["audio"], depth_factor=0.2)
```
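To sanity-check how large the adapter is relative to the frozen DiT, you could count parameters using the attribute paths that appear in the snippets below; treat this as a quick sketch, not part of the library API:

```python
# Rough size check: the adapter lives at model.model.controlnet,
# the frozen DiT at model.model.model (see the freezing snippet below).
dit_params = sum(p.numel() for p in model.model.model.parameters())
controlnet_params = sum(p.numel() for p in model.model.controlnet.parameters())
print(f"adapter is {100 * controlnet_params / dit_params:.1f}% of the DiT size")
```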
For training, first disable gradients on the frozen structures:

```python
# Freeze the pretrained DiT, the conditioner, and the autoencoder pretransform
model.model.model.requires_grad_(False)
model.conditioner.requires_grad_(False)
model.conditioner.eval()
model.pretransform.requires_grad_(False)
model.pretransform.eval()
```

and pass only the ControlNet adapter parameters to the optimizer:

```python
# Only the ControlNet adapter is trained
params = list(model.model.controlnet.parameters())
optimizer = torch.optim.AdamW(params, lr=lr)
```
During training, pass the conditioning audio `y` in the conditioning dictionary:

```python
x, y, prompts, start_seconds, total_seconds = batch
...
# obtain noised_inputs from x at time t
...
# pass the ControlNet conditioning with id "audio" in the conditioning dictionary
conditioning = [{"audio": y[i:i+1],
                 "prompt": prompts[i],
                 "seconds_start": start_seconds[i],
                 "seconds_total": total_seconds[i]} for i in range(y.shape[0])]
output = model(x=noised_inputs,
               t=t.to(device),
               cond=model.conditioner(conditioning),
               cfg_dropout_prob=cfg_dropout_prob,
               device=device)
# compute diffusion loss with output
```
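The two elided steps (noising and the loss) could look like the following v-objective sketch. This is an assumption about the training objective and schedule, and `x` is assumed to already be the latent (otherwise encode it first with `model.pretransform.encode`):

```python
import math
import torch
import torch.nn.functional as F

# Illustrative v-objective noising/loss; the repo's exact schedule may differ.
t = torch.rand(x.shape[0], device=device)            # diffusion times in [0, 1]
alphas = torch.cos(t * math.pi / 2)[:, None, None]
sigmas = torch.sin(t * math.pi / 2)[:, None, None]

noise = torch.randn_like(x)
noised_inputs = alphas * x + sigmas * noise           # noised latent fed to the model
targets = alphas * noise - sigmas * x                 # v-prediction target

# ... model forward as in the snippet above ...
loss = F.mse_loss(output, targets)
```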
For inference, call for example:

```python
from stable_audio_tools.inference.generation import generate_diffusion_cond

# define conditioning dictionary
assert batch_size == len(conditioning)

output = generate_diffusion_cond(
    model,
    steps=steps,
    batch_size=batch_size,
    cfg_scale=7.0,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device="cuda"
)
```
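The conditioning dictionary mirrors the training-time keys. A minimal sketch, assuming `y` holds the control audio and using placeholder prompt and timing values:

```python
# Illustrative conditioning; "audio" carries the ControlNet control signal,
# the prompt and timing values are placeholders.
conditioning = [{"audio": y[i:i+1],
                 "prompt": "rock accompaniment with drums and bass",
                 "seconds_start": 0,
                 "seconds_total": 47} for i in range(y.shape[0])]
batch_size = len(conditioning)
```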
The ControlNet architecture is implemented by defining two classes (in `diffusion.py`):

- `DiTControlNetWrapper`: extends `DiTWrapper`, which contains a `DiffusionTransformer`, with a `ControlNetDiffusionTransformer` defined in `controlnet.py`. The latter has a structure copied from the `DiffusionTransformer` (reducing the number of layers via a `controlnet_depth_factor`) and takes a `controlnet_cond` as input alongside all the inputs that we give to the `DiffusionTransformer`. The processed inner-layer `controlnet_embeds` are returned and given as input to the `DiffusionTransformer`, which contains a version of `ContinuousTransformer` modified to take `controlnet_embeds` as input; they are summed into the layers corresponding to the layers of the `ControlNetDiffusionTransformer`.
- `ConditionedControlNetDiffusionModelWrapper`: contains a `DiTControlNetWrapper` and handles the conditioning tensors which are passed to it: we add a new `controlnet_cond` type which is the conditioning tensor of the ControlNet adapter.
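A minimal sketch of the injection mechanism described above (class and function names, the `cond_proj` projection, and the layer mapping are illustrative, not the repo's exact code): the adapter runs a reduced, trainable copy of the DiT blocks on the input plus `controlnet_cond`, and its per-layer outputs are summed into the hidden states of the corresponding frozen DiT layers.

```python
import torch
import torch.nn as nn

class ControlNetAdapterSketch(nn.Module):
    """Illustrative only: a reduced copy of the pretrained transformer blocks
    (deep-copied in practice) that returns one embedding per retained layer."""
    def __init__(self, copied_blocks: nn.ModuleList, depth: int, cond_dim: int, dim: int):
        super().__init__()
        self.blocks = nn.ModuleList(copied_blocks[:depth])    # reduced number of layers
        self.cond_proj = nn.Linear(cond_dim, dim)             # projects controlnet_cond

    def forward(self, x: torch.Tensor, controlnet_cond: torch.Tensor) -> list[torch.Tensor]:
        h = x + self.cond_proj(controlnet_cond)
        controlnet_embeds = []
        for block in self.blocks:
            h = block(h)
            controlnet_embeds.append(h)                        # one embedding per retained layer
        return controlnet_embeds

def frozen_transformer_forward(frozen_blocks, x, controlnet_embeds):
    # Inside the modified ContinuousTransformer: adapter embeddings are added to
    # the layers they correspond to (the layer mapping here is illustrative).
    for i, block in enumerate(frozen_blocks):
        if i < len(controlnet_embeds):
            x = x + controlnet_embeds[i]
        x = block(x)
    return x
```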
In the following we detail training a model for music source accompaniment generation on MusDB (`audio` ControlNet conditioning). Other examples with `envelope` (filtered RMS envelope) and `chroma` (chromagram mask for pitch control) controls are available as well. The audio and envelope controls are tested and work well; the chroma control still needs experimental confirmation.
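As rough intuition for the `envelope` control, a filtered RMS envelope could be computed as below; the frame size, hop, and smoothing window are illustrative assumptions, not the repo's exact parameters:

```python
import torch
import torch.nn.functional as F

def rms_envelope(audio: torch.Tensor, frame: int = 1024, hop: int = 256, smooth: int = 9) -> torch.Tensor:
    """audio: (batch, channels, samples) -> (batch, 1, num_frames) smoothed RMS envelope."""
    mono = audio.mean(dim=1, keepdim=True)                 # collapse channels to mono
    frames = mono.unfold(-1, frame, hop)                   # (batch, 1, num_frames, frame)
    rms = frames.pow(2).mean(dim=-1).sqrt()                # per-frame RMS
    kernel = torch.ones(1, 1, smooth, device=audio.device) / smooth
    return F.conv1d(rms, kernel, padding=smooth // 2)      # moving average as the "filter"
```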
First install the requirements. `torchaudio` has to be installed as the nightly build; you can do it with `pip3 install --pre torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118`. Afterwards, copy `.env.tmp` to `.env` and replace the variables with your own (example values are random):
```
DIR_LOGS=/logs
DIR_DATA=/data

# Required if using wandb logger
WANDB_PROJECT=audioproject
WANDB_ENTITY=johndoe
WANDB_API_KEY=a21dzbqlybbzccqla4txa21dzbqlybbzccqla4tx
```
Afterwards, log in to Hugging Face with `huggingface-cli login` using a personal access token, in order to be able to download the Stable Audio Open weights.
For the demo, since we are using the mp3 version of MusDB, it is also necessary to have `libsndfile` installed. You can install it with `conda install conda-forge::libsndfile`.
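A quick way to verify mp3 decoding works after installing these dependencies (the file path is a placeholder, and using `soundfile`/libsndfile here is just one way to check):

```python
import soundfile as sf

# mp3 decoding goes through libsndfile here; the path is a placeholder
data, samplerate = sf.read("data/example_track.mp3")
print(data.shape, samplerate)
```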
First download the sharded version of MusDB18HQ from https://drive.google.com/drive/folders/1bwiJbRH_0BsxGFkH0No-Rg_RHkVR2gc7?usp=sharing and put the files `test.tar` and `train.tar` inside `data/musdb18hq/`.
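For example, assuming the two archives were downloaded to the current working directory:

```bash
mkdir -p data/musdb18hq
mv train.tar test.tar data/musdb18hq/
```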
For training run:

```bash
PYTHONUNBUFFERED=1 TAG=musdb-controlnet-audio python train.py exp=train_musdb_controlnet_audio \
    datamodule.train_dataset.path=data/musdb18hq/train.tar \
    datamodule.val_dataset.path=data/musdb18hq/test.tar
```
To resume training from `checkpoint.ckpt` stored in `ckpts`, run:

```bash
PYTHONUNBUFFERED=1 TAG=musdb-controlnet-audio python train.py exp=train_musdb_controlnet_audio \
    datamodule.train_dataset.path=data/musdb18hq/train.tar \
    datamodule.val_dataset.path=data/musdb18hq/test.tar \
    +ckpt=ckpts/checkpoint.ckpt
```
Checkpoints for the audio-conditioned ControlNet:

- `exp/train_musdb_controlnet_audio_large.yaml`
- `exp/train_moisesdb_controlnet_audio_large.yaml`

For performing inference with the model trained on MusDB (MoisesDB), you can run `notebook/inference_musdb_audio_large.ipynb` (`notebook/inference_moisesdb_audio_large.ipynb`). The notebook expects the checkpoints to be found in the folder `ckpts/musdb-audio` (`ckpts/moisesdb-audio`). Inference can be performed on a 16GB VRAM GPU.