Stability-AI / stable-audio-tools

Generative models for conditional audio generation
MIT License
2.24k stars 197 forks source link

How do I run stable audio tools from just command prompt? #91

Open scaruslooner opened 3 weeks ago

scaruslooner commented 3 weeks ago

I don't want to run the Gradio GUI; I want to use the commands directly on the command line. Example: "stableaudiotools --prompt "tree blowing in the wind" --steps 77 --secondsstart 3 --secondstotal 8 --cfgscale 7.5"

SoftologyPro commented 3 weeks ago

Here is my standalone script based on their example code. You need to set up the models and prerequisites yourself first. You can edit/add other parameters as you want.

import sys

# Progress marker shown before the third-party imports below are loaded.
print("Imports ...", flush=True)

import torch
import torchaudio
import json
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict
from stable_audio_tools.training.utils import copy_state_dict
import argparse

# Progress marker shown once the imports have finished.
print("Parsing arguments ...", flush=True)

def parse_args():
    """Parse the command-line options of this script.

    The options are read from ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: ``prompt`` holds the text given with
        ``--prompt``, or ``None`` when the flag is omitted.
    """
    # The description is passed to the parser so that ``--help`` shows it
    # (previously a description string was assigned but never used).
    parser = argparse.ArgumentParser(
        description="Script to generate audio from text prompts",
    )
    parser.add_argument("--prompt", type=str, help="the prompt to generate the audio from")
    return parser.parse_args()

args2 = parse_args()

# Prefer the GPU when one is present; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print('Using device:', device, flush=True)
if cuda_available:
    # Only query CUDA device properties when a CUDA device is actually in
    # use; the query is not valid for the CPU fallback.
    print(torch.cuda.get_device_properties(device), flush=True)

# Forward slashes avoid the invalid "\c" / "\m" escape sequences of the
# backslash spelling and are accepted on Windows as well as POSIX systems.
path_model_config = './ckpt/model_config.json'
path_model_ckpt_path = './ckpt/model.ckpt'

sys.stdout.write(f"Creating model from {path_model_config}\n")
sys.stdout.flush()
# Load config from json file
with open(path_model_config) as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)

sys.stdout.write(f"Loading model checkpoint from {path_model_ckpt_path}\n")
sys.stdout.flush()
copy_state_dict(model, load_ckpt_state_dict(path_model_ckpt_path))

sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

model = model.to(device)

sys.stdout.write("Setting up conditioning\n")
sys.stdout.flush()

# Honour --prompt when it was given on the command line; keep the previous
# hard-coded prompt as the default so running without arguments is unchanged.
prompt = args2.prompt if args2.prompt is not None else "128 BPM tech house drum loop"

# Set up text and timing conditioning
conditioning = [{
    "prompt": prompt,
    "seconds_start": 0,
    "seconds_total": 180
}]

sys.stdout.write("Generating audio\n")
sys.stdout.flush()

# Generate stereo audio
output = generate_diffusion_cond(
    model,
    steps=100,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device
)

# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

sys.stdout.write("Saving audio\n")
sys.stdout.flush()

# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)

sys.stdout.write("Done\n")
sys.stdout.flush()
Gregorein commented 3 weeks ago

@scaruslooner grab this version, which has default params for ease of use. @SoftologyPro thanks for the code, btw; somehow my browser Gradio UI keeps spinning its gears and never produces an output file.

import sys
import argparse
import torch
import torchaudio
import json
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict
from stable_audio_tools.training.utils import copy_state_dict

# Progress marker emitted once the module-level imports above have run.
print("Imports ...", flush=True)

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Every option has a default, so the script can run with no arguments.

    Returns:
        argparse.Namespace: with attributes ``prompt``, ``steps``,
        ``secondsstart``, ``secondstotal``, ``cfgscale`` and ``file``.
    """
    cli = argparse.ArgumentParser(description="Script to generate audio from text prompts")

    # (flag, value type, default, help text) for each supported option.
    option_table = (
        ("--prompt", str, "tree blowing in the wind", "the prompt to generate the audio from"),
        ("--steps", int, 100, "the number of steps for the diffusion process"),
        ("--secondsstart", int, 0, "the start time in seconds"),
        ("--secondstotal", int, 10, "the total duration in seconds"),
        ("--cfgscale", float, 7.5, "the CFG scale value"),
        ("--file", str, "output.wav", "the output file name"),
    )
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    parsed = cli.parse_args()
    return parsed

def _announce(message):
    """Write a progress message to stdout and flush it straight away."""
    sys.stdout.write(message)
    sys.stdout.flush()


_announce("Parsing arguments ...\n")
args = parse_args()

# Pick the compute device: CUDA when available, CPU otherwise.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
print('Using device:', device, flush=True)
if has_cuda:
    print(torch.cuda.get_device_properties(device), flush=True)

# Locations of the model configuration and its weights.
config_path = './ckpt/model_config.json'
checkpoint_path = './ckpt/model.safetensors'

# Build the model from its JSON configuration.
_announce(f"Creating model from {config_path}\n")
with open(config_path) as config_file:
    model_config = json.load(config_file)
model = create_model_from_config(model_config)

# Copy the checkpoint weights into the freshly created model.
_announce(f"Loading model checkpoint from {checkpoint_path}\n")
checkpoint_state = load_ckpt_state_dict(checkpoint_path)
copy_state_dict(model, checkpoint_state)

sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

model = model.to(device)

# Text and timing conditioning taken from the command-line options.
_announce("Setting up conditioning\n")
conditioning_entry = {
    "prompt": args.prompt,
    "seconds_start": args.secondsstart,
    "seconds_total": args.secondstotal,
}
conditioning = [conditioning_entry]

# Run the conditional diffusion sampler.
_announce("Generating audio\n")
generated = generate_diffusion_cond(
    model,
    steps=args.steps,
    cfg_scale=args.cfgscale,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device,
)

# Fold the batch dimension into the sample axis: one continuous sequence.
sequence = rearrange(generated, "b d n -> d (b n)")

_announce("Saving audio\n")

# Peak-normalise, clip to [-1, 1], scale to 16-bit integers and move to CPU.
# The peak is taken from the tensor before the float32 conversion.
peak = torch.max(torch.abs(sequence))
normalised = sequence.to(torch.float32).div(peak).clamp(-1, 1)
output = normalised.mul(32767).to(torch.int16).cpu()
torchaudio.save(args.file, output, sample_rate)

_announce("Done\n")