lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

How to run inference? #213

Closed: LianaN closed this issue 12 months ago

LianaN commented 12 months ago

Hello, I trained a model on a custom Spanish audio dataset (wav files only, no transcripts). How can I properly load the trained model from checkpoints and run inference, i.e. text-to-speech?

This is my current code, but it generates a constant noise signal (a sort of beep) in out.wav.

import torch
import torchaudio
from audiolm_pytorch import (
    HubertWithKmeans, SoundStream, SemanticTransformer,
    CoarseTransformer, FineTransformer, AudioLM
)

# semantic tokens come from HuBERT + k-means quantization
wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
).cuda()

# neural codec, restored from its training checkpoint
soundstream = SoundStream.init_and_load_from("./results/soundstream.8.pt").cuda()

# the three transformers of the AudioLM hierarchy
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
).cuda()

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
).cuda()

# assemble all components for end-to-end generation
audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav_with_text_condition = audiolm(text = ['Hola, que tal?'])

output_path = "out.wav"
sample_rate = 44100  # should match the sample rate the codec was trained on
torchaudio.save(output_path, generated_wav_with_text_condition.cpu(), sample_rate)
LWprogramming commented 12 months ago

You need to actually train the transformers first (see #211 on why they have to be trained separately, and here for the training script I use, which might be helpful to adapt). You can save and load checkpoints using the trainers' (or transformers') save()/load() methods.
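For reference, a minimal sketch of that pattern (the semantic_trainer variable and the checkpoint path are placeholders, and it assumes a version of audiolm-pytorch where the transformers expose load()):

# after training: persist the weights through the trainer
semantic_trainer.save('./semantic.transformer.pt')

# at inference time: rebuild the transformer with the same config,
# then load the checkpoint directly on the module, without a trainer
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()
semantic_transformer.load('./semantic.transformer.pt')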

LianaN commented 12 months ago

I tried loading the trained transformers as follows, but I am getting AssertionError: only one Trainer can be instantiated at a time for training. I trained each transformer in its own notebook, which worked well, but I'm unclear how to properly assemble all the pieces for inference.

from audiolm_pytorch import (
    HubertWithKmeans, SoundStream, AudioLM,
    SemanticTransformer, SemanticTransformerTrainer,
    CoarseTransformer, CoarseTransformerTrainer,
    FineTransformer, FineTransformerTrainer
)

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
).cuda()

soundstream = SoundStream.init_and_load_from(soundstream_ckpt).cuda()

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()
semantictransformer_trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)
semantictransformer_trainer.load(semantictransformer_ckpt)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
).cuda()
# instantiating a second trainer in the same process raises the AssertionError
coarsetransformer_trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
coarsetransformer_trainer.load(coarsetransformer_ckpt)

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
).cuda()
finetransformer_trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    grad_accum_every = 32,
    data_max_length_seconds = 4,
    num_train_steps = 9
)
finetransformer_trainer.load(finetransformer_ckpt)

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)
LWprogramming commented 12 months ago

I'd check out #211 and https://github.com/LWprogramming/audiolm-pytorch-training/blob/3374046d801451aec5a441d37fae2e9b7b8b05af/audiolm_pytorch_demo_laion.py#L369 (load checkpoints from the transformer itself; don't create multiple trainers when evaluating).
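Concretely, applied to the snippet above, that means dropping the trainers entirely at inference time (a sketch reusing the same transformer and checkpoint variables as above; assumes a version where the transformers expose load()):

# rebuild each transformer with its training-time config (as above),
# then load the checkpoints on the modules themselves, no trainers involved
semantic_transformer.load(semantictransformer_ckpt)
coarse_transformer.load(coarsetransformer_ckpt)
fine_transformer.load(finetransformer_ckpt)

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(text = ['Hola, que tal?'])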

tiansiyuan commented 12 months ago

@LianaN Give version 1.2.15 a try.

LianaN commented 12 months ago

Thanks, I was able to load from the transformers using version 1.2.20. With earlier versions I was getting the error:

AttributeError: 'SemanticTransformer' object has no attribute 'load'
tiansiyuan commented 12 months ago

> Thanks, I was able to load from the transformers using version 1.2.20. With earlier versions I was getting the error:
>
> AttributeError: 'SemanticTransformer' object has no attribute 'load'

Try using pickle to save trained transformers into files and load them back before inference.
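A minimal sketch of that workaround (pickling the whole trained module; assumes the transformer objects are picklable as-is):

import pickle

# persist the fully trained transformer object to disk
with open('semantic_transformer.pkl', 'wb') as f:
    pickle.dump(semantic_transformer.cpu(), f)

# load it back before assembling AudioLM for inference
with open('semantic_transformer.pkl', 'rb') as f:
    semantic_transformer = pickle.load(f)
semantic_transformer.cuda().eval()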