manmay-nakhashi closed this 1 year ago
@lucidrains can you review these changes?
@manmay-nakhashi thank you Manmay! will take a look later today and get it merged 🙌
@manmay-nakhashi could you also paste a small script for running this end to end in a comment in the PR?
almost there, i'll circle back to this tomorrow morning! thank you Manmay!
@manmay-nakhashi fast response! once you get a working script (like the examples in the readme), let's 🚢 it
import torch
from naturalspeech2_pytorch import (
    EncodecWrapper,
    Model,
    NaturalSpeech2,
    SpeechPromptEncoder
)
# use encodec as an example
codec = EncodecWrapper()
model = Model(
    dim = 128,
    depth = 6,
    dim_prompt = 512,
    cond_drop_prob = 0.25, # dropout prompt conditioning with this probability, for classifier free guidance
    condition_on_prompt = True
)
# natural speech diffusion model
diffusion = NaturalSpeech2(
    model = model,
    codec = codec,
    timesteps = 1000
).cuda()
# mock raw audio data
raw_audio = torch.randn(4, 327680).cuda()
prompt = torch.randn(4, 32768).cuda() # they randomly excised a range on the audio for the prompt during training, eventually will take care of this auto-magically
text = torch.randint(0, 100, (4, 100)).cuda()
mel_len = torch.tensor([120, 60, 80, 70]).cuda()
mel = torch.randn(4, 80, 120).cuda()
text_lens = torch.tensor([100, 50, 80, 100]).cuda() # each length must not exceed the padded text length of 100
pitch = torch.randn(4, 1, 120).cuda()
loss = diffusion(
    text = text,
    text_lens = text_lens,
    mel = mel,
    mel_len = mel_len,
    pitch = pitch,
    audio = raw_audio,
    prompt = prompt # pass in the prompt
)
loss.backward()
@manmay-nakhashi is the mel derived from the raw audio?
Yes
@manmay-nakhashi ok, so we could take care of this automatically within the framework
might be the next logical step!
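For reference, a minimal sketch of how a log-mel spectrogram could be derived from the raw audio, so the caller would not need to pass `mel` separately. This is plain NumPy and not the code in this PR; the function names and the parameters (`sr=24000`, `n_fft=1024`, `hop=256`, `n_mels=80`) are assumptions chosen for illustration.

```python
import numpy as np

def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(sr, n_fft, n_mels):
    # triangular filters whose centers are evenly spaced on the mel scale
    fft_freqs = np.linspace(0.0, sr / 2, n_fft // 2 + 1)
    hz_pts = mel_to_hz(np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2))
    fb = np.zeros((n_mels, n_fft // 2 + 1))
    for i in range(n_mels):
        lo, center, hi = hz_pts[i], hz_pts[i + 1], hz_pts[i + 2]
        rising = (fft_freqs - lo) / (center - lo)
        falling = (hi - fft_freqs) / (hi - center)
        fb[i] = np.clip(np.minimum(rising, falling), 0.0, None)
    return fb

def log_mel(audio, sr=24000, n_fft=1024, hop=256, n_mels=80):
    # audio: (batch, samples) -> log-mel: (batch, n_mels, frames)
    n_frames = 1 + (audio.shape[-1] - n_fft) // hop
    idx = np.arange(n_fft)[None, :] + hop * np.arange(n_frames)[:, None]
    frames = audio[:, idx] * np.hanning(n_fft)         # (batch, frames, n_fft)
    power = np.abs(np.fft.rfft(frames, axis=-1)) ** 2  # (batch, frames, n_fft // 2 + 1)
    mel = power @ mel_filterbank(sr, n_fft, n_mels).T  # (batch, frames, n_mels)
    return np.log(np.clip(mel, 1e-5, None)).transpose(0, 2, 1)
```

The number of frames follows from the audio length and the hop size, so `mel_len` could be computed in the same step instead of being passed in.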
@manmay-nakhashi where would the pitch come from?
@lucidrains there is a function called compute_pitch in naturalspeech2_pytorch; we can use that to extract pitch from the wav. In the paper they also use pitch extraction from pyworld.
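To make the idea of extracting pitch from a waveform concrete, here is a rough sketch of a frame-wise autocorrelation estimator. It is a stand-in for illustration only: it is neither `compute_pitch` nor pyworld, the names `autocorr_f0` and `pitch_contour` are hypothetical, and the sample rate, frame length and hop are assumed values. It also makes no voiced/unvoiced decision.

```python
import numpy as np

def autocorr_f0(frame, sr, fmin=60.0, fmax=400.0):
    # estimate the fundamental frequency of one frame from its autocorrelation peak
    frame = frame - frame.mean()
    ac = np.correlate(frame, frame, mode="full")[len(frame) - 1:]  # lags 0..N-1
    lag_min = int(sr / fmax)  # shortest period considered
    lag_max = int(sr / fmin)  # longest period considered
    lag = lag_min + np.argmax(ac[lag_min:lag_max + 1])
    return sr / lag

def pitch_contour(audio, sr=24000, frame_len=2048, hop=256):
    # audio: 1-D waveform -> one f0 estimate (in Hz) per frame
    n_frames = 1 + (len(audio) - frame_len) // hop
    return np.array([
        autocorr_f0(audio[i * hop : i * hop + frame_len], sr)
        for i in range(n_frames)
    ])
```

For a pure 200 Hz sine sampled at 24 kHz, the estimates should sit close to 200 Hz; real speech would need a more robust extractor such as the ones mentioned above.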
@manmay-nakhashi got it! i'll take care of that
ideally the researcher passes in as little information as possible and the rest is computed, while still having the option to pass everything in manually
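One way this "compute what is missing, accept what is given" pattern could look is sketched below. Everything here is hypothetical: `prepare_inputs`, `default_mel` and `default_pitch` are illustrative names, the extractors are zero-filled placeholders, and the hop size of 256 is an assumption. It is not how the library does it.

```python
import numpy as np

def default_mel(audio):
    # hypothetical placeholder for a real mel extractor: (batch, samples) -> (batch, 80, frames)
    return np.zeros((audio.shape[0], 80, audio.shape[1] // 256))

def default_pitch(audio):
    # hypothetical placeholder for a real pitch extractor: (batch, samples) -> (batch, 1, frames)
    return np.zeros((audio.shape[0], 1, audio.shape[1] // 256))

def prepare_inputs(audio, mel=None, pitch=None):
    # anything left as None is derived from the audio;
    # anything passed in explicitly is used untouched
    if mel is None:
        mel = default_mel(audio)
    if pitch is None:
        pitch = default_pitch(audio)
    if mel.shape[-1] != pitch.shape[-1]:
        raise ValueError("mel and pitch must have the same number of frames")
    return mel, pitch
```

With this shape, `prepare_inputs(audio)` would be the common case, while `prepare_inputs(audio, pitch=my_pitch)` would keep a manual override available.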
@lucidrains i have made a skeleton for diffusion conditioning, we can keep building on top of this.