archinetai / audio-diffusion-pytorch

Audio generation using diffusion models, in PyTorch.
MIT License

Could you provide an example recipe? #51

Open · gandolfxu opened this issue 1 year ago

gandolfxu commented 1 year ago
  1. Based on an open source dataset
  2. Detailed training parameters

gandolfxu commented 1 year ago

Is the following example right?

import librosa
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm
from loguru import logger

from audio_diffusion_pytorch import DiffusionModel
from audio_diffusion_pytorch import UNetV0
from audio_diffusion_pytorch import VDiffusion
from audio_diffusion_pytorch import VSampler

LEN = 2 ** 18

class AudioDataset(Dataset):
    def __init__(self, fpath):
        self.file_list = []
        for line in open(fpath):
            line = line.strip()
            if not line:
                continue
            self.file_list.append(line)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_file = self.file_list[idx]
        audio, fs = torchaudio.load(audio_file)
        transform = torchaudio.transforms.Resample(fs, 48000)
        audio = transform(audio)
        if audio.shape[1] > LEN:
            offset = np.random.randint(0, audio.shape[1] - LEN)
        else:
            offset = 0
        return audio[:, offset:offset + LEN]

def collate_fn(batch):
    bsz = len(batch)
    out = torch.zeros(bsz, 2, LEN)
    for i, x in enumerate(batch):
        out[i, :, :x.shape[1]] = x # torch.from_numpy(x)
    return out

model = DiffusionModel(
    net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
    in_channels=2, # U-Net: number of input/output (audio) channels
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
    attention_heads=8, # U-Net: number of attention heads per attention item
    attention_features=64, # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion, # The diffusion method used
    sampler_t=VSampler, # The diffusion sampler used
)
model.to('cuda:0')

train_dataset = AudioDataset('data/train.list')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=8,
    pin_memory=True
)

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""))

for i in range(6):
    for audio in tqdm(train_dataloader, desc='epoch %d' % i):
        audio = audio.to('cuda:0')
        loss = model(audio)
        logger.info('loss = %f' % loss.item())
        loss.backward()
    torch.save(model.state_dict(), 'model_%d.pt' % i)

flavioschneider commented 1 year ago

Looks like it's missing the optimizer; I'd suggest following a basic PyTorch tutorial on how to set up the training loop.
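
For reference, a minimal sketch of the missing optimizer steps (assuming the model and dataloader from the script above, and Adam with an illustrative learning rate):

import torch

# Assumes `model` and `train_dataloader` are defined as in the script above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(6):
    for audio in train_dataloader:
        audio = audio.to('cuda:0')
        optimizer.zero_grad()              # clear gradients from the previous step
        loss = model(audio)                # DiffusionModel returns the diffusion loss
        loss.backward()                    # backpropagate
        optimizer.step()                   # apply the weight update
    torch.save(model.state_dict(), 'model_%d.pt' % epoch)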

deepak-newzera commented 1 year ago

@flavioschneider, can you please provide an example script showing how to train the model? Also, how can I get a dataset to train it on?

deepak-newzera commented 1 year ago

Is the example above (the same script posted by @gandolfxu) right?

@gandolfxu Did you get the correct script to train the model? If so, please help me out with it, and let me know which dataset you are using.

deepak-newzera commented 1 year ago

@flavioschneider Please let me know what kind of dataset can be used to train the model and how it should be structured.

jameshball commented 1 year ago

Hi @deepak-newzera I've written a simple training script here: https://github.com/jameshball/audio-diffusion/blob/master/train.py

It uses the LibriSpeech dataset and downloads it when you start the script.

You might need to change the data path defined at the top, and set up or remove the Weights & Biases (wandb) logging.
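
If you would rather not set up a wandb account, one way to effectively remove the logging (a sketch, assuming a reasonably recent wandb version; the wandb.run.resumed check further down may still need guarding) is to initialise wandb in disabled mode:

import wandb

# Disabled mode makes wandb.log and wandb.save no-ops, so the rest of the script needs no changes.
wandb.init(project="audio-diffusion", mode="disabled")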

I'm currently having an issue with training where I get NaNs (https://github.com/archinetai/audio-diffusion-pytorch/issues/52), but at least this code should give you something to start with.

deepak-newzera commented 1 year ago

@jameshball Thanks for your help. The dataset you provided seems to be a speech dataset, but I suppose it should be a music dataset, right?

kitchWWW commented 1 year ago

Hi @jameshball, it seems like the training script you linked goes to a 404. Do you have an updated link to this training loop that you could share?

jameshball commented 1 year ago

I've made that repo private now, but this is the version of the file I linked:

import torch
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
from audio_data_pytorch import LibriSpeechDataset, AllTransform

SAMPLE_RATE = 16000
BATCH_SIZE = 12
NUM_SAMPLES = 2**18

def create_model():
    return DiffusionModel(
        net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
        in_channels=1, # U-Net: number of input/output (audio) channels
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
        attention_heads=8, # U-Net: number of attention heads per attention item
        attention_features=64, # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion, # The diffusion method used
        sampler_t=VSampler, # The diffusion sampler used
    )

def main():
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataset = LibriSpeechDataset(
        root="E:/librispeech",
        transforms=AllTransform(
            random_crop_size=NUM_SAMPLES,
            mono=True,
        ),
    )

    print(f"Dataset length: {len(dataset)}")

    torchaudio.save("test.wav", dataset[0], SAMPLE_RATE)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    model = create_model().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    run_id = wandb.util.generate_id()
    if args.run_id is not None:
        run_id = args.run_id
    print(f"Run ID: {run_id}")

    wandb.init(project="audio-diffusion", resume=args.resume, id=run_id)

    epoch = 0
    step = 0

    if args.checkpoint is not None:
        checkpoint_path = args.checkpoint
    else:
        checkpoint_path = f"checkpoint-{run_id}.pt"

    # Resume: restore model/optimizer state and the epoch counter from a local or wandb-synced checkpoint
    if wandb.run.resumed:
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(wandb.restore(checkpoint_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        step = epoch * len(dataloader)

    # Mixed precision: GradScaler rescales the loss so fp16 gradients do not underflow
    scaler = torch.cuda.amp.GradScaler()

    # Training loop: log the averaged loss every 100 steps and generate an audio sample every 500 steps
    model.train()
    while epoch < 100:
        avg_loss = 0
        avg_loss_step = 0
        progress = tqdm(dataloader)
        for i, audio in enumerate(progress):
            optimizer.zero_grad()
            audio = audio.to(device)
            with torch.cuda.amp.autocast():
                loss = model(audio)
                avg_loss += loss.item()
                avg_loss_step += 1
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            progress.set_postfix(
                loss=loss.item(),
                epoch=epoch + i / len(dataloader),
            )

            if step % 500 == 0:
                # Turn noise into new audio sample with diffusion
                noise = torch.randn(1, 1, NUM_SAMPLES, device=device)
                with torch.cuda.amp.autocast():
                    sample = model.sample(noise, num_steps=100)

                torchaudio.save(f'test_generated_sound_{step}.wav', sample[0].cpu(), SAMPLE_RATE)
                del sample
                gc.collect()
                torch.cuda.empty_cache()

                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                    "generated_audio": wandb.Audio(f'test_generated_sound_{step}.wav', caption="Generated audio", sample_rate=SAMPLE_RATE),
                })

            if step % 100 == 0:
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                })
                avg_loss = 0
                avg_loss_step = 0

            step += 1

        epoch += 1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        wandb.save(checkpoint_path)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_id", type=str, default=None)
    return parser.parse_args()

if __name__ == "__main__":
    main()

jameshball commented 1 year ago

It has some hardcoded paths and wandb code that you might need to remove, but it worked nicely.
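
For example, the hardcoded dataset root could be passed in as a command-line argument instead (a sketch only; --data_root is an illustrative name, not part of the original script):

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_id", type=str, default=None)
    parser.add_argument("--data_root", type=str, default="E:/librispeech")  # added: dataset location
    return parser.parse_args()

# ...and in main(): dataset = LibriSpeechDataset(root=args.data_root, transforms=...)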

deepak-newzera commented 1 year ago

@jameshball Did you succeed in training the model? Is it producing sensible outputs? I would also like to train using your script, but instead of LibriSpeechDataset I want to train the model on a set of WAV files. If possible, can you guide me on how that can be done?
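
One possible starting point for a folder of WAV files (a rough sketch only, not something from this thread; wav_dir, crop_len, and target_sr are illustrative names) is a small Dataset similar to the one @gandolfxu posted above:

import glob
import os
import random

import torch
import torchaudio
from torch.utils.data import Dataset

class WavFolderDataset(Dataset):
    """Loads WAV files from a folder and returns fixed-length mono crops."""

    def __init__(self, wav_dir, crop_len=2**18, target_sr=16000):
        self.files = sorted(glob.glob(os.path.join(wav_dir, "**", "*.wav"), recursive=True))
        self.crop_len = crop_len
        self.target_sr = target_sr

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio, sr = torchaudio.load(self.files[idx])                 # (channels, samples)
        audio = torchaudio.functional.resample(audio, sr, self.target_sr)
        audio = audio.mean(dim=0, keepdim=True)                      # mix down to mono
        if audio.shape[1] >= self.crop_len:                          # random fixed-length crop
            offset = random.randint(0, audio.shape[1] - self.crop_len)
            audio = audio[:, offset:offset + self.crop_len]
        else:                                                        # pad short clips with zeros
            audio = torch.nn.functional.pad(audio, (0, self.crop_len - audio.shape[1]))
        return audio

In principle such a dataset can replace LibriSpeechDataset in the training script above, keeping in_channels=1 and the rest of the loop unchanged.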

dustyatx commented 1 year ago

@jameshball Did you train successfully? If so, can you share what you did and what you learned?