lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

When running the example code, I get an error saying the trainer was instantiated twice: #211

Closed LukasNel closed 1 year ago

LukasNel commented 1 year ago
#!/usr/bin/env python
# coding: utf-8

# In[1]:

# ## Setup
# 
# Includes:
# 
# - How to generate a placeholder dataset if you haven't already, just the basics to run "training" e2e on a tiny dataset
# - How to download a dataset from OpenSLR

# ### Imports & paths

# In[6]:

# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio

# define all dataset paths, checkpoints, etc
dataset_folder = "dev-clean"
soundstream_ckpt = "results/soundstream.8.pt" # this can change depending on number of steps
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

# ### Data

# In[ ]:

# Placeholder data generation
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):
  # code adapted from https://stackoverflow.com/a/33913403
  audio = []
  num_samples = duration_ms * (sample_rate / 1000.0)
  for x in range(int(num_samples)):
    audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))
  return audio

def save_wav(file_name, audio, sample_rate=44100.0):
  # Open up a wav file
  wav_file=wave.open(file_name,"w")
  # wav params
  nchannels = 1
  sampwidth = 2
  # 44100 is the industry standard sample rate - CD quality.  If you need to
  # save on file size you can adjust it downwards. The standard for low quality
  # is 8000 or 8kHz.
  nframes = len(audio)
  comptype = "NONE"
  compname = "not compressed"
  wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname))
  # WAV files here are using short, 16 bit, signed integers for the
  # sample size.  So we multiply the floating point data we have by 32767, the
  # maximum value for a short integer.  NOTE: It is theoretically possible to
  # use the floating point -1.0 to 1.0 data directly in a WAV file but not
  # obvious how to do that using the wave module in python.
  for sample in audio:
      wav_file.writeframes(struct.pack('h', int( sample * 32767.0 )))
  wav_file.close()
  return

def make_placeholder_dataset():
  # Make a placeholder dataset with a few .wav files that you can "train" on, just to verify things work e2e
  if os.path.isdir(dataset_folder):
    return
  os.makedirs(dataset_folder)
  save_wav(f"{dataset_folder}/example.wav", get_sinewave())
  save_wav(f"{dataset_folder}/example2.wav", get_sinewave(duration_ms=500))
  os.makedirs(f"{dataset_folder}/subdirectory")
  save_wav(f"{dataset_folder}/subdirectory/example.wav", get_sinewave(freq=330.0))

make_placeholder_dataset()

# In[6]:

# Get actual dataset. Uncomment this if you want to try training on real data

# full dataset: https://www.openslr.org/12
# We'll use https://us.openslr.org/resources/12/dev-clean.tar.gz development set, "clean" speech.
# We *should* train on, well, training, but this is just to demo running things end-to-end at all so I just picked a small clean set.

url = "https://us.openslr.org/resources/12/dev-clean.tar.gz"
filename = "dev-clean"
filename_targz = filename + ".tar.gz"
if not os.path.isfile(filename_targz):
  urllib.request.urlretrieve(url, filename_targz)
if not os.path.isdir(filename):
  # open file
  with tarfile.open(filename_targz) as t:
    t.extractall(filename)

# ## Training
# 
# Now that we have a dataset, we can train AudioLM.
# 
# **Note**: do NOT type "y" to overwrite previous experiments/checkpoints when running through the cells here unless you're ready to lose the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose "overwrite" then you lose the SoundStream checkpoint when you then train SemanticTransformer).

# ### SoundStream

# In[7]:
print("Training soundstream")

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
).cuda()
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

# In[8]:

trainer.train()

# In[17]:

print("Training semantic transformer")

# ### SemanticTransformer

# In[15]:

# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
if not os.path.isdir("hubert"):
  os.makedirs("hubert")
if not os.path.isfile(hubert_ckpt):
  hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
  urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
  hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
  urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

print("Training coarse transformer")

# ### CoarseTransformer

# In[ ]:

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

# ### FineTransformer

# In[ ]:

print("Training fine transformer")

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

# ## Inference

# In[ ]:

print("Generating audio")

# Everything together
audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

# In[ ]:

output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)

# In[ ]:

Error

8: saving to results
8: saving model to results
training complete
Training semantic transformer
Downloading (…)lve/main/config.json: 100%|██████████████████████████████████████████████████████| 605/605 [00:00<00:00, 382kB/s]
Traceback (most recent call last):
  File "audiolm.py", line 178, in <module>
    trainer = SemanticTransformerTrainer(
  File "<@beartype(audiolm_pytorch.trainer.SemanticTransformerTrainer.__init__) at 0x2b8db50c15e0>", line 116, in __init__
  File "/vast/palmer/home.grace/lan43/lam/inception/lib/python3.8/site-packages/audiolm_pytorch/trainer.py", line 586, in __init__
    check_one_trainer()
  File "/vast/palmer/home.grace/lan43/lam/inception/lib/python3.8/site-packages/audiolm_pytorch/trainer.py", line 59, in check_one_trainer
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
AssertionError: only one Trainer can be instantiated at a time for training

The code to check that seems to be fairly newly added.
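Judging only from the traceback (the names `check_one_trainer` and `ONE_TRAINER_INSTANTIATED` and the assertion message), the check appears to be a process-wide flag that is set when the first trainer is constructed. The following is a minimal sketch of that pattern, not the library's actual source; `DemoTrainer` is a hypothetical stand-in for the real trainer classes.

```python
# Simplified stand-in for the guard seen in the traceback.
# Assumption: the real check is a module-level flag; the library code may differ.
ONE_TRAINER_INSTANTIATED = False

def check_one_trainer():
    global ONE_TRAINER_INSTANTIATED
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True

class DemoTrainer:
    """Hypothetical placeholder for SoundStreamTrainer, SemanticTransformerTrainer, etc."""
    def __init__(self, name):
        check_one_trainer()  # runs on every construction, as in the traceback
        self.name = name

first = DemoTrainer("soundstream")   # first trainer in this process: allowed

try:
    DemoTrainer("semantic")          # second trainer in the same process: rejected
    second_error = None
except AssertionError as err:
    second_error = str(err)

print(second_error)
```

With a guard like this, reassigning the `trainer` variable does not help: the flag lives at module level and stays set for the lifetime of the Python process.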

LWprogramming commented 1 year ago

Ah yes this looks like my notebook from a while back-- I think we want to train one at a time now, see here. I use this script now, which trains them separately (see the last few lines of the shell script)

LukasNel commented 1 year ago

duude, that's awesome thanks so much. You should definitely link that repo to this one.

LWprogramming commented 1 year ago

haha, it's a bit rough around the edges because I keep changing it to try to make my training runs work properly. maybe if I get something working stably then we can talk to lucidrains :P

lucidrains commented 1 year ago

cool, it worked 😄

lucidrains commented 1 year ago

yea, like it said, you need to train each network one at a time (separate training script), then you can chain them all together

nischalj10 commented 5 months ago

Hey @lucidrains @LWprogramming, I am not able to wrap my head around the issue. I am running each notebook cell sequentially on a single T4 GPU, and I am still getting the AssertionError.
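Cells run one after another in a notebook still share a single Python process (the kernel), so a process-wide guard like the one named in the traceback would presumably stay set between cells. One way to follow the "separate training script" advice from a single driver is to launch each stage in its own interpreter. The sketch below only illustrates that idea: `STAGE_SNIPPET` and `run_stage` are hypothetical, and the snippet imitates the guard rather than importing audiolm_pytorch.

```python
import subprocess
import sys

# Hypothetical per-stage program. In practice this would be a real training
# script that builds exactly one trainer; here it only imitates the guard.
STAGE_SNIPPET = """
import sys

ONE_TRAINER_INSTANTIATED = False  # stand-in for the library's process-wide flag

def check_one_trainer():
    global ONE_TRAINER_INSTANTIATED
    assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
    ONE_TRAINER_INSTANTIATED = True

check_one_trainer()  # the single trainer this process is allowed to create
print('trained', sys.argv[1])
"""

def run_stage(stage_name):
    """Run one stage in a fresh Python process and return what it printed."""
    result = subprocess.run(
        [sys.executable, "-c", STAGE_SNIPPET, stage_name],
        capture_output=True,
        text=True,
        check=True,  # raise if the stage exits with an error
    )
    return result.stdout.strip()

# Each stage gets its own interpreter, so the flag starts out False every time.
outputs = [run_stage(stage) for stage in ("soundstream", "semantic", "coarse", "fine")]
print(outputs)
```

In a real setup, the `-c` snippet would be replaced by the path to an actual per-stage script, and a final inference step would load whatever checkpoints those scripts saved. In a notebook, restarting the kernel between stages has the same effect of starting from a fresh process.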