An unofficial PyTorch implementation of VALL-E, utilizing the EnCodec encoder/decoder.
Besides a working PyTorch environment, the only hard requirement is `espeak-ng` for phonemizing text:
- Linux users can install `espeak`/`espeak-ng` through their package manager.
- Windows users need to install `espeak-ng`, and may also need to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`.
  - Running `set PHONEMIZER_ESPEAK_LIBRARY="C:\Program Files\eSpeak NG\libespeak-ng.dll"` beforehand should fix this.
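For example, on Debian/Ubuntu-like systems the package is typically available as `espeak-ng` (package names vary by distro, so treat this as a sketch):

```bash
# install the phonemizer backend from the system package manager
sudo apt install espeak-ng
```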
AMD systems with ROCm are mostly supported, but performance will vary.
Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.

I've tested this repo under Python versions 3.10.9, 3.11.3, and 3.12.3.
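A typical clean setup might look like this sketch (the venv name is arbitrary):

```bash
# create a dedicated virtual environment and install from the GitHub mirror
python3 -m venv venv
source ./venv/bin/activate
pip install git+https://github.com/e-c-k-e-r/vall-e
```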
My pre-trained weights can be acquired from here.
A script to set up a proper environment and download the weights can be invoked with `./scripts/setup.sh`. This will automatically create a `venv`, and download the `ar+nar-llama-8` weights and config file to the right place.

When inferencing, either through the web UI or CLI, if no model is passed, the default model will be downloaded automatically instead, and should automatically update.
Training is very dependent on the quality and quantity of your dataset.
To quickly test if a configuration works, you can run `python -m vall_e.models.ar_nar --yaml="./data/config.yaml"`; a small trainer will overfit a provided utterance.
If you already have a dataset (for example, your own large corpus, or one for finetuning), you can use it instead.
1. Set up a `venv` with https://github.com/m-bain/whisperX/.
   - Using other variants like `faster-whisper` is an exercise left to the user at the moment.

   ```
   python3 -m venv venv-whisper
   source ./venv-whisper/bin/activate
   pip3 install torch torchvision torchaudio
   pip3 install git+https://github.com/m-bain/whisperX/
   ```
2. Populate your source voices under `./voices/{group name}/{speaker name}/` (see the example layout after this list).
3. Run `python3 -m vall_e.emb.transcribe`. This will generate a transcription with timestamps for your dataset.
   - To use a different model or batch size, adjust the script's `model_name` and `batch_size` variables.
4. Run `python3 -m vall_e.emb.process`. This will phonemize the transcriptions and quantize the audio.
5. Run `python3 -m vall_e.emb.similar`. This will calculate the top-k most similar utterances for each utterance, for use with sampling.
6. Copy `./data/config.yaml` to `./training/config.yaml`. Customize the training configuration and populate your `dataset.training` list with the values stored under `./training/dataset/list.json`.
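For reference, a populated voices folder might look like the sketch below; the group, speaker, and file names are purely illustrative:

```
./voices/
  my-group/
    my-speaker/
      utterance-001.wav
      utterance-002.wav
      utterance-003.wav
```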
Refer to `./vall_e/config.py` for additional configuration details.

Two dataset formats are supported:
- Plain files: data is stored under `./training/data/{group}/{speaker}/{id}.{enc|dac}` as a NumPy file, where `enc` is for the EnCodec/Vocos backend, and `dac` for the Descript-Audio-Codec backend. Metadata to speed up dataset pre-loading can be generated with `python3 -m vall_e.data --yaml="./training/config.yaml" --action=metadata`.
- HDF5: convert from the plain-file format with `python3 -m vall_e.data --yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation), and be sure to also set `use_hdf5` in your config YAML.
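The exact placement of `use_hdf5` in the YAML isn't shown here, so treat this as a sketch and verify the nesting against `./vall_e/config.py`:

```yaml
# assumed nesting; confirm against ./vall_e/config.py
dataset:
  use_hdf5: True
```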
For single GPUs, simply run `python3 -m vall_e.train --yaml="./training/config.yaml"`.
For multiple GPUs, or exotic distributed training:
- with the `deepspeed` backend, simply running `deepspeed --module vall_e.train --yaml="./training/config.yaml"` should handle the gory details.
- with the `local` backend, simply run `torchrun --nnodes=1 --nproc-per-node={NUMOFGPUS} -m vall_e.train --yaml="./training/config.yaml"`.
You can enter `save` to save the state at any time, or `quit` to save and quit training.

The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
Some additional flags can be passed as well:
- `--eval`: only run the evaluation / validation pass, then exit afterwards.
- `--eval-random-text-prompts`: use random text prompts for the evaluation pass, rather than the provided text prompts in the dataset.

Finetuning can be done by training the full model, or using a LoRA.
Finetuning the full model is done the same way as training a model, but be sure to have the weights in the correct spot, as if you're loading them for inferencing.
For training a LoRA, add the following block to your `config.yaml`:

```yaml
loras:
- name : "arbitrary name" # whatever you want
  rank: 128 # dimensionality of the LoRA
  alpha: 128 # scaling factor of the LoRA
  training: True
```
And that's it. Training the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should be, as the LoRA weights are initialized to appropriately random values. I found a rank and alpha of 128 to work fine.
To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`. You should be able to have the LoRA weights loaded from a training checkpoint automagically for inferencing, but export them just to be safe.
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"`.

You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss.nll stats.acc`.
As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
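A minimal sketch of that change, assuming `trainer` is a top-level block of the YAML (verify against `./vall_e/config.py`):

```yaml
trainer:
  backend: local  # switch from deepspeed to the local backend
```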
Creature comforts like `float16`, `amp`, and multi-GPU training should work under the `local` backend, but extensive testing still needs to be done to ensure it all functions.
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures:
- `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
- `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
- `bitnet`: using this implementation of BitNet's transformer. Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
- `transformer`: a basic attention-based transformer implementation, with attention heads + feed-forwards.
- `retnet`: using TorchScale's RetNet implementation, a retention-based approach can be used instead.
- `retnet-hf`: using syncdoth/RetNet with a HuggingFace-compatible RetNet model.
- `mamba`: using state-spaces/mamba (needs to mature).
For audio backends:
- `encodec`: a tried-and-tested EnCodec to encode/decode audio.
- `vocos`: a higher quality EnCodec decoder. Encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`.
- `descript-audio-codec`: boasts better compression and quality, but has issues with model convergence.
`llama`-based models also support different attention backends:
- `torch.nn.functional.scaled_dot_product_attention`-based attention:
  - `math`: torch's SDPA's `math` kernel
  - `mem_efficient`: torch's SDPA's memory-efficient (`xformers`-adjacent) kernel
  - `cudnn`: torch's SDPA's `cudnn` kernel
  - `flash`: torch's SDPA's flash attention kernel
- `xformers`: facebookresearch/xformers's memory-efficient attention
- `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
- `flash_attn_v100`: uses ZRayZzz/flash-attention-v100's Flash Attention for Volta (but doesn't work currently)
- `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
- `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
- `transformers` Llama*Attention implementations:
  - `eager`: default `LlamaAttention`
  - `sdpa`: integrated `LlamaSdpaAttention` attention model
  - `flash_attention_2`: integrated `LlamaFlashAttention2` attention model
  - `auto`: determine the best fit from the above

The wide support for various backends is solely while I try and figure out which is the "best" for a core foundation model.
ROCm/flash-attention currently does not support Navi3 cards (gfx11xx), so first-class support for Flash Attention is a bit of a mess on Navi3. Using the `howiejay/navi_support` branch can get inference support, but not training support (due to some error being thrown during the backwards pass), by:
- modifying `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h` so it reads:

  ```c
  #if defined(__HIPCC_RTC__)
  #define __HOST_DEVICE__ __device__ static
  #else
  #include <climits>
  #define __HOST_DEVICE__ __host__ __device__ static inline
  #endif
  ```

- then installing with `pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-build-isolation`
To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`.

This will export the latest checkpoints, for example, under `./training/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used and training stats.

Despite being called `fp32.pth`, you can export it to a different precision type with `--dtype=float16|bfloat16|float32`.

You can also export to `safetensors` with `--format=sft`, and `fp32.sft` will be exported instead.
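For example, combining the flags above to export half-precision `safetensors` weights:

```bash
# export the latest checkpoint as float16 safetensors
python -m vall_e.export --yaml=./training/config.yaml --dtype=float16 --format=sft
```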
To synthesize speech: `python -m vall_e <text> <ref_path> <out_path> --yaml=<yaml_path>` (or `--model=<model_path>`).
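A concrete invocation might look like the following, where the text, reference clip, and output path are placeholders for your own files:

```bash
python -m vall_e "The quick brown fox jumps over the lazy dog." \
  ./voices/my-group/my-speaker/reference.wav \
  ./output.wav \
  --yaml=./training/config.yaml
```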
Some additional flags you can pass are:
- `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
- `--task`: task to perform. Defaults to `tts`, but accepts `stt` for transcriptions.
- `--max-ar-steps`: maximum steps for inferencing through the AR model. Each second is 75 steps.
- `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
- `--ar-temp`: sampling temperature to use for the AR pass. During experimentation, `0.95` provides the most consistent output, but values close to it work fine.
- `--nar-temp`: sampling temperature to use for the NAR pass. During experimentation, the lower the value, the better. Set to `0` to enable greedy sampling.
- `--input-prompt-length`: the maximum duration the input prompt can be (~6 seconds is fine; longer durations lead to slower generations for "better" accuracy, as long as the model was trained against such input prompt durations)

And some experimental sampling flags you can use too (your mileage will definitely vary, but most of these are bandaids for a bad AR; a combined example follows this list):
- `--input-prompt-prefix`: (AR only) treats the input prompt as the initial response prefix, but... (see `tts-c`).
- `--min-ar-temp`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temp)`.
- `--top-p`: limits the sampling pool to the top tokens whose cumulative probability amounts to P% of the probability distribution.
- `--top-k`: limits the sampling pool to the top K values in the probability distribution.
- `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
- `--repetition-penalty-decay`: decays the above penalty based on how far back in the sequence the token appeared.
- `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is very finicky due to the AR already being well correlated with the length.
- `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling; this is effectively just greedy sampling across `B` spaces.
- `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling.
- `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling, applied to the maximum surprise.
- `--dry-multiplier`: (AR only) performs DRY sampling; this sets the scalar factor.
- `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
- `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
- `--layer-skip`: enables early-exit layer skipping if the model is confident enough (for compatible models)
- `--layer-skip-exit-layer`: maximum layer to use
- `--layer-skip-entropy-threshold`: the maximum the logits' entropy (confidence) needs to be before exiting early
- `--layer-skip-varentropy-threshold`: the maximum the logits' varentropy (confidence spread) needs to be before exiting early
- `--refine-on-stop`: (AR only) uses the last step's logits for the entire final output sequence, rather than the step-by-step iterative sequence.
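As an illustrative combination of the flags above (the values here are arbitrary, not recommendations, and the exact argument syntax should be double-checked against the CLI itself):

```bash
# 450 AR steps is roughly 6 seconds at 75 steps per second
python -m vall_e "Hello there, this is a test." ./reference.wav ./out.wav \
  --yaml=./training/config.yaml \
  --max-ar-steps 450 \
  --ar-temp 0.95 --nar-temp 0.0 \
  --top-k 64
```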
The `ar+nar-tts+stt-llama-8` model has received additional training for a speech-to-text task against EnCodec-encoded audio.
Currently, the model only transcribes back into the IPA phonemes it was trained against, as an additional model or external program is required to translate the IPA phonemes back into text.
A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass:
- `--yaml=./path/to/your/config.yaml`: will load the targeted YAML
- `--model=./path/to/your/model.sft`: will load the targeted model weights
- `--listen 0.0.0.0:7860`: will set the web UI to listen to all IPs at port 7860. Replace the IP and port to your preference.
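For example, to serve the UI on your LAN with a specific set of weights:

```bash
# the model path is illustrative; point it at your own exported weights
python3 -m vall_e.webui --model=./models/ar+nar-llama-8/fp32.sft --listen 0.0.0.0:7860
```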
The model can be prompted in creative ways to yield some interesting behaviors.
Synthesizing speech is simple:
- `Input Prompt`: The guiding text prompt. Each new line will be its own generated audio, to be stitched together at the end.
- `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine.
- `Output`: The resultant audio.
- `Inference`: Button to start generating the audio.
- `Basic Settings`: Basic sampler settings for most uses.
- `Sampler Settings`: Advanced sampler settings that are common for most text LLMs, but need experimentation.

All the additional knobs have a description that can be correlated to the above CLI flags.
Speech-To-Text phoneme transcriptions for models that support it can be done using the `Speech-to-Text` tab.
This tab currently only features exploring a dataset already prepared and referenced in your `config.yaml`. You can select a registered voice, and have it randomly sample an utterance.
In the future, this should contain the necessary niceties to process raw audio into a dataset to train/finetune through, without needing to invoke the above commands to prepare the dataset.
So far, this only allows you to load a different model without needing to restart. The previous model should seamlessly unload, and the new one will load in place.
`stt` (Speech-to-Text) seems to be working fine for the most part.

Interleaving by summing embedding tokens is also being explored, for example: `<RVQ 0-7><RVQ 0>` => `<RVQ 0-7><RVQ 0-1>` => `<RVQ 0-7><RVQ 0-2>` (etc.)
Despite how lightweight it is in comparison to other TTS's I've meddled with, there are still some caveats, be it with the implementation or model weights:
- `tts-c` (VALL-E continuous) mode, or modifying an input prompt enough to where its quantized representation differs enough from the output response the prompt derives from.
- `model.experimental.p_rvq_levels: [0,0,0,0,0,0,0,1,2,3,4,5,6,7]` seems to help?
- `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
Unless otherwise credited/noted in this README or within the designated Python file, this repository is licensed under AGPLv3.
EnCodec is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
This implementation was originally based on enhuiz/vall-e, but has been heavily, heavily modified over time. Without it I would not have had a good basis to muck around and learn.
```bibtex
@article{wang2023neural,
  title={Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers},
  author={Wang, Chengyi and Chen, Sanyuan and Wu, Yu and Zhang, Ziqiang and Zhou, Long and Liu, Shujie and Chen, Zhuo and Liu, Yanqing and Wang, Huaming and Li, Jinyu and others},
  journal={arXiv preprint arXiv:2301.02111},
  year={2023}
}
```

```bibtex
@article{defossez2022highfi,
  title={High Fidelity Neural Audio Compression},
  author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
  journal={arXiv preprint arXiv:2210.13438},
  year={2022}
}
```