lucidrains / naturalspeech2-pytorch

Implementation of Natural Speech 2, Zero-shot Speech and Singing Synthesizer, in Pytorch
MIT License
1.29k stars 101 forks

adding a diffusion conditioning #24

Closed manmay-nakhashi closed 1 year ago

manmay-nakhashi commented 1 year ago

@lucidrains i have made a skeleton for diffusion conditioning, we can keep building on top of this.

manmay-nakhashi commented 1 year ago

@lucidrains can you review these changes?

lucidrains commented 1 year ago

@manmay-nakhashi thank you Manmay! will take a look later today and get it merged 🙌

lucidrains commented 1 year ago

@manmay-nakhashi could you also paste a small script for running this end to end in a comment in the PR?

lucidrains commented 1 year ago

almost there, i'll circle back to this tomorrow morning! thank you Manmay!

lucidrains commented 1 year ago

@manmay-nakhashi fast response! once you get a working script (like the examples in the readme), let's 🚢 it

manmay-nakhashi commented 1 year ago

import torch
from naturalspeech2_pytorch import (
    EncodecWrapper,
    Model,
    NaturalSpeech2,
    SpeechPromptEncoder
)

# use encodec as an example

codec = EncodecWrapper()

model = Model(
    dim = 128,
    depth = 6,
    dim_prompt = 512,
    cond_drop_prob = 0.25,                  # drop out the prompt conditioning with this probability, for classifier-free guidance
    condition_on_prompt = True
)

# natural speech diffusion model

diffusion = NaturalSpeech2(
    model = model,
    codec = codec,
    timesteps = 1000
).cuda()

# mock raw audio data

raw_audio = torch.randn(4, 327680).cuda()
prompt = torch.randn(4, 32768).cuda()          # they randomly excised a range of the audio for the prompt during training, eventually this will be taken care of auto-magically
text = torch.randint(0, 100, (4, 100)).cuda()
text_lens = torch.tensor([100, 50, 80, 100]).cuda()   # lengths must not exceed the text sequence length of 100
mel = torch.randn(4, 80, 120).cuda()
mel_len = torch.tensor([120, 60, 80, 70]).cuda()
pitch = torch.randn(4, 1, 120).cuda()

loss = diffusion(
    text = text,
    text_lens = text_lens,
    mel = mel,
    mel_len = mel_len,
    pitch = pitch,
    audio = raw_audio,
    prompt = prompt                            # pass in the prompt
)
loss.backward()

lucidrains commented 1 year ago

@manmay-nakhashi is the mel derived from the raw audio?

manmay-nakhashi commented 1 year ago

Yes

lucidrains commented 1 year ago

@manmay-nakhashi ok, so we could take care of this automatically within the framework

might be the next logical step!
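A minimal sketch of what "deriving the mel from the raw audio" could look like inside the framework. This is a hand-rolled numpy version for illustration only, not the repo's actual extractor; the sample rate, FFT size, hop size, and 80 mel bands are placeholder assumptions:

```python
import numpy as np

def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(n_mels, n_fft, sr):
    # triangular filters spaced evenly on the mel scale
    mel_pts = np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2)
    bins = np.floor((n_fft + 1) * mel_to_hz(mel_pts) / sr).astype(int)
    fb = np.zeros((n_mels, n_fft // 2 + 1))
    for i in range(1, n_mels + 1):
        l, c, r = bins[i - 1], bins[i], bins[i + 1]
        for k in range(l, c):
            fb[i - 1, k] = (k - l) / max(c - l, 1)
        for k in range(c, r):
            fb[i - 1, k] = (r - k) / max(r - c, 1)
    return fb

def mel_spectrogram(audio, sr = 24000, n_fft = 1024, hop = 256, n_mels = 80):
    # frame the signal, window it, take the power spectrum, project onto mel filters
    window = np.hanning(n_fft)
    n_frames = 1 + (len(audio) - n_fft) // hop
    frames = np.stack([audio[i * hop : i * hop + n_fft] * window for i in range(n_frames)])
    power = np.abs(np.fft.rfft(frames, axis = -1)) ** 2
    mel = mel_filterbank(n_mels, n_fft, sr) @ power.T
    return np.log(mel + 1e-5)   # (n_mels, n_frames)

audio = np.random.randn(32768)
mel = mel_spectrogram(audio)    # shape (80, 125) with these settings
```

With the mel computed from `raw_audio` like this, the user would only need to pass the audio tensor and the framework could fill in `mel` (and `mel_len`) itself.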

lucidrains commented 1 year ago

@manmay-nakhashi where would the pitch come from?

manmay-nakhashi commented 1 year ago

@lucidrains there is a function called compute_pitch in naturalspeech2_pytorch; we can use that to extract pitch from the wav. In the paper they also use pitch extraction from pyworld.

lucidrains commented 1 year ago

@manmay-nakhashi got it! i'll take care of that

ideally the researcher should have to pass in as little information as possible, with the rest computed automatically, while still having the option to pass everything in manually
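That "compute it unless the caller supplies it" idea can be sketched with a small default pattern. The feature extractors here are trivial stand-ins, not the library's real functions; only the control flow is the point:

```python
import numpy as np

def default(val, fn):
    # use the caller-supplied value if given, otherwise compute it lazily
    return val if val is not None else fn()

# stand-in feature extractors (hypothetical; the real ones would derive
# mel and pitch from the raw audio inside the framework)
def compute_mel(audio):
    return np.random.randn(audio.shape[0], 80, 120)

def compute_pitch(audio):
    return np.random.randn(audio.shape[0], 1, 120)

def forward(audio, mel = None, pitch = None):
    # fill in any feature the researcher did not pass in manually
    mel = default(mel, lambda: compute_mel(audio))
    pitch = default(pitch, lambda: compute_pitch(audio))
    return mel, pitch

audio = np.random.randn(4, 32768)
mel, pitch = forward(audio)            # everything derived automatically
mel2, _ = forward(audio, mel = mel)    # or passed in explicitly and left untouched
```

This keeps the training call down to `diffusion(audio = raw_audio, prompt = prompt, text = text, ...)` in the common case while preserving full manual control.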