lucidrains / naturalspeech2-pytorch

Implementation of Natural Speech 2, Zero-shot Speech and Singing Synthesizer, in Pytorch
MIT License

@wonwooo #37

Open CHK-0000 opened 6 months ago

CHK-0000 commented 6 months ago

Can you provide the training code for that model?


```python
import torch

from naturalspeech2_pytorch import Trainer, EncodecWrapper, Model, NaturalSpeech2, SpeechPromptEncoder

codec = EncodecWrapper()

def main():
    model = Model(
        dim = 128,
        depth = 6,
        dim_prompt = 512,
        cond_drop_prob = 0.25,
        condition_on_prompt = True
    )

    diffusion = NaturalSpeech2(
        model = model,
        codec = codec,
        timesteps = 50
    )

    raw_audio = torch.randn(4, 327680)
    prompt = torch.randn(4, 32768)

    text = torch.randint(0, 100, (4, 100))
    text_lens = torch.tensor([100, 50, 80, 100])

    # forwards and backwards

    loss = diffusion(
        audio = raw_audio,
        text = text,
        text_lens = text_lens,
        prompt = prompt,
    )

    loss.backward()

    # after much training

    generated_audio = diffusion.sample(
        length = 1024,
        text = text,
        prompt = prompt,
    )

    trainer = Trainer(
        diffusion_model = diffusion,
        folder = 'C:\\naturalspeech2-pytorch\\ansunghun',
        train_batch_size = 16,
        gradient_accumulate_every = 2,
        train_num_steps = 5,
        save_and_sample_every = 100,
    )

    trainer.train()
    trainer.save_checkpoint('C:\\naturalspeech2-pytorch\\ansunghun\\checkpoint.pt')

if __name__ == '__main__':
    from multiprocessing import freeze_support
    freeze_support()
    main()
```


Running this script fails inside `trainer.train()` with the following error:


```
Traceback (most recent call last):
  File "test.py", line 62, in <module>
    main()
  File "test.py", line 56, in main
    trainer.train()
  File "C:\naturalspeech2-pytorch\naturalspeech2_pytorch\naturalspeech2_pytorch.py", line 1875, in train
    loss = self.model(data)
  File "C:\Users\user\.conda\envs\svc\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\.conda\envs\svc\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\naturalspeech2-pytorch\naturalspeech2_pytorch\naturalspeech2_pytorch.py", line 1522, in forward
    text_max_length = text.shape[-1]
AttributeError: 'NoneType' object has no attribute 'shape'
```
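Judging from the traceback, `Trainer.train` calls `self.model(data)` with a single positional batch, so `text` keeps its default `None` inside `forward` and `text.shape[-1]` raises. The manual `diffusion(audio = ..., text = ..., ...)` call earlier in `main()` runs before that point without error because it passes `text` explicitly. The sketch below reproduces the pattern with pure-Python stand-ins and shows one shape a workaround could take: batches that carry padded token ids plus their true lengths, passed to the model by keyword. `Padded2D`, `StandInModel`, `collate`, and `train_step` are hypothetical names, not part of `naturalspeech2_pytorch`; a real version would build tensors and call the actual `NaturalSpeech2` model.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class Padded2D:
    """Hypothetical stand-in for a (batch, length) integer tensor."""
    rows: List[List[int]]

    @property
    def shape(self) -> Tuple[int, int]:
        return (len(self.rows), len(self.rows[0]) if self.rows else 0)


class StandInModel:
    """Hypothetical model whose forward() reads text the same way as the traceback."""

    def __call__(self, audio, text=None, text_lens=None, prompt=None):
        return self.forward(audio, text=text, text_lens=text_lens, prompt=prompt)

    def forward(self, audio, text=None, text_lens=None, prompt=None):
        # Same access as in the traceback; fails when only audio is passed.
        text_max_length = text.shape[-1]
        # A real model would compute the diffusion loss here; the stand-in
        # just returns the padded text length so the call can be checked.
        return text_max_length


def collate(items: Sequence[Tuple[List[float], List[int]]], pad_id: int = 0):
    """Batch (audio, token_ids) pairs: pad ids to a common length, keep true lengths."""
    audio = [a for a, _ in items]
    text_lens = [len(ids) for _, ids in items]
    max_len = max(text_lens)
    text = Padded2D([ids + [pad_id] * (max_len - len(ids)) for _, ids in items])
    return audio, text, text_lens


def train_step(model, items):
    audio, text, text_lens = collate(items)
    # Pass text and its lengths by keyword, as the manual diffusion(...) call does.
    return model(audio=audio, text=text, text_lens=text_lens)


model = StandInModel()
items = [([0.1, 0.2], [5, 6, 7]), ([0.3, 0.4], [8])]

try:
    model(items[0][0])  # audio only, like self.model(data) in the traceback
except AttributeError as err:
    print("audio only:", err)

print("audio + text:", train_step(model, items))  # padded length 3
```

Keeping `text_lens` alongside the padded ids means the true transcript lengths survive padding, which is the role `text_lens = torch.tensor([100, 50, 80, 100])` plays in the script above.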