lucidrains / spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
MIT License

Change the `generate` code to check if it is a tensor, not specifically a torch.FloatTensor. #9

Closed: itsjamie closed this 11 months ago

itsjamie commented 11 months ago

The very specific check prevents the DatasetGenerator from running on the GPU.

Before this change:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following 
scalar types: Long, Int; but got torch.cuda.FloatTensor instead 
(while checking arguments for embedding)

This was due to the wav2vec encoding step being skipped, so the raw audio data was passed straight through to the embedding. You could run this code on the CPU, but it takes much longer.

This allows moving the process to the GPU.
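For context, a minimal sketch of the failure mode (an illustration, not the repository's exact code): the legacy torch.FloatTensor type is device-specific, so the same float tensor no longer matches the isinstance check once it has been moved to CUDA, and the branch that should run wav2vec on raw audio gets skipped.

import torch

audio = torch.randn(1, 16000)                 # raw float waveform

# on CPU the legacy type check passes ...
assert isinstance(audio, torch.FloatTensor)

if torch.cuda.is_available():
    audio_gpu = audio.cuda()
    # ... but on the GPU the tensor is a torch.cuda.FloatTensor,
    # so the check is False and the wav2vec encoding is skipped
    assert not isinstance(audio_gpu, torch.FloatTensor)
    assert isinstance(audio_gpu, torch.cuda.FloatTensor)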

itsjamie commented 11 months ago

Could change this to check the dtype rather than whether it subclasses Tensor.

Based on: https://pytorch.org/docs/stable/tensors.html

Looks like checking for a dtype of torch.float32 (aka torch.float) would match on either CPU or GPU?
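A device-agnostic version of that check could look something like the sketch below (an illustration of the suggestion, not the committed fix; the helper name is just for illustration). The dtype does not depend on the device, so torch.float32 matches on both CPU and CUDA, while integer token ids fall through.

import torch

def looks_like_raw_audio(source):
    # dtype is device-agnostic: a float32 tensor reports torch.float32
    # whether it lives on the CPU or the GPU (torch.float is an alias)
    return torch.is_tensor(source) and source.dtype == torch.float32

audio = torch.randn(1, 16000)
assert looks_like_raw_audio(audio)
if torch.cuda.is_available():
    assert looks_like_raw_audio(audio.cuda())

# semantic token ids, by contrast, are integer typed and fall through
ids = torch.randint(0, 512, (1, 128))
assert not looks_like_raw_audio(ids)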

lucidrains commented 11 months ago

@itsjamie hey Jamie, could you tell me the line number for that RuntimeError?

lucidrains commented 11 months ago

@itsjamie do you want to see if 0.4.5 fixes it? https://github.com/lucidrains/spear-tts-pytorch/commit/4437d697fe48dd0c0b7b8455870f04f2a3b4f130

itsjamie commented 11 months ago

Here's the stack:

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/spear_tts_pytorch/spear_tts_pytorch.py:1339, in SemanticToTextDatasetGenerator.forward(self, max_length, beam_search_decode, **generate_kwargs)
   1336 counter = 0
   1338 for audio, in self.dl:
-> 1339     audio_semantic_ids, text_ids = self.model.generate(
   1340         source = audio,
   1341         source_type = 'speech',
   1342         target_type = 'text',
   1343         return_source = True,
   1344         max_length = max_length,
   1345         beam_search_decode = beam_search_decode,
   1346         **generate_kwargs
   1347     )
   1349     for audio_semantic_id, text_id in zip(audio_semantic_ids, text_ids):
   1351         if exists(self.audio_pad_id):

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/spear_tts_pytorch/spear_tts_pytorch.py:91, in eval_decorator.<locals>.inner(self, *args, **kwargs)
     89 was_training = self.training
     90 self.eval()
---> 91 out = fn(self, *args, **kwargs)
     92 self.train(was_training)
     93 return out

File <@beartype(spear_tts_pytorch.spear_tts_pytorch.TextToSemantic.generate) at 0x7fafc5f18c10>:138, in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_94130629710528, __beartype_getrandbits, __beartype_object_140397048615664, __beartype_object_140397719292784, __beartype_object_140392996549888, *args, **kwargs)

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/spear_tts_pytorch/spear_tts_pytorch.py:646, in TextToSemantic.generate(self, source, source_type, target_type, temperature, filter_logits_fn, filter_fn_kwargs, source_mask, max_length, beam_search_decode, spec_decode, spec_decode_gamma, spec_decode_lenience, beam_size, return_source, return_target_mask, cond_scale)
    642     source_mask = source != source_pad_id
    644 # source embedding
--> 646 source_emb = source_token_emb(source)
    648 source_emb = self.source_transformer(source_emb, mask = source_mask)
    650 # decode target

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162, in Embedding.forward(self, input)
    161 def forward(self, input: Tensor) -> Tensor:
--> 162     return F.embedding(
    163         input, self.weight, self.padding_idx, self.max_norm,
    164         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2204     # Note [embedding_renorm set_grad_enabled]
   2205     # XXX: equivalent to
   2206     # with torch.no_grad():
   2207     #   torch.embedding_renorm_
   2208     # remove once script supports set_grad_enabled
   2209     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
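For reference, the error itself is the generic embedding constraint: F.embedding requires integer (Long or Int) indices, so feeding it float audio, on any device, produces exactly this message. A standalone illustration, unrelated to the repository's code:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings = 512, embedding_dim = 64)

ids = torch.randint(0, 512, (1, 8))
out = emb(ids)                 # fine: Long indices

try:
    emb(ids.float())           # float "indices" reproduce the error above
except RuntimeError as e:
    print(e)                   # Expected tensor for argument #1 'indices' ...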

itsjamie commented 11 months ago

> @itsjamie do you want to see if 0.4.5 fixes it? https://github.com/lucidrains/spear-tts-pytorch/commit/4437d697fe48dd0c0b7b8455870f04f2a3b4f130

I'll check it out in a few hours.

lucidrains commented 11 months ago

@itsjamie cool

are you doing this for voicebox or soundstorm TTS? just curious

itsjamie commented 11 months ago

I'm interested in Voicebox, but for no "good" reason other than I had to pick something and stick with it until I had something working.

For work, I'm interested in an alternative to elevenlabs for high quality on-device TTS.

I'm also interested in the backtranslation aspect of SPEAR-TTS just because we were interested in some opportunities to do low-resource language TTS.

For personal usage, I'm interested in fixing audio I record with the masking properties from Voicebox.

Right now I'm just trying to make sure I can train something and understand it with a smaller dataset, and then I'll either convince my work to let me scale up the training, or just commit to a few months of training on a 3080 :laughing:

lucidrains commented 11 months ago

@itsjamie sounds good, many are interested in this route 😄

ok, keep me updated!

itsjamie commented 11 months ago

Resolved by 4437d69. This allowed moving it onto the GPU to generate the dataset.

lucidrains commented 11 months ago

great! remember to share some of your results once / if you get something trained 😄

itsjamie commented 11 months ago

Any idea why this code...

generated_dataset = GeneratedAudioTextDataset(
    folder = './generated-audio-text-pairs'
)

finetune_decoder = TextToSemantic(
    dim = 256,
    num_text_token_ids = 32100,
    source_depth = 6,
    target_depth = 6,
    heads = 8,
    dim_head = 64,
    wav2vec = wav2vec,
    num_semantic_token_ids = wav2vec.codebook_size,
    attn_dropout = 0.5,
    ff_mult = 2,
    ff_dropout = 0.5
)

finetune_decoder.load('results/speech.speech.10000.pt')

trainer = SemanticToTextTrainer(
    model = finetune_decoder,
    dataset = generated_dataset,
    batch_size = 4,
    grad_accum_every = 4,
    lr = 2e-4,
    num_train_steps = 10_001,
    num_warmup_steps = 1_000,
    results_folder = 'results_s2t'
)

trainer.train()

Results in:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/jstackhouse/Coding/github.com/lucidrains/spear-tts-pytorch/demo.ipynb Cell 16 line 2
     14 finetune_decoder.load('results/speech.speech.10000.pt')
     15 trainer = SemanticToTextTrainer(
     16     model=finetune_decoder,
     17     dataset=generated_dataset,
   (...)
     23     results_folder='results_s2t'
     24 )
---> 25 trainer.train()

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/spear_tts_pytorch/trainer.py:571, in SemanticToTextTrainer.train(self, log_fn)
    569 def train(self, log_fn = noop):
    570     while self.steps < self.num_train_steps:
--> 571         logs = self.train_step()
    572         log_fn(logs)
    574     self.print('training complete')

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/spear_tts_pytorch/trainer.py:528, in SemanticToTextTrainer.train_step(self)
    525     semantic_token_ids, grapheme_token_ids = next(self.dl_iter)
    527     loss, _ = self.train_wrapper(semantic_token_ids = semantic_token_ids, grapheme_token_ids = grapheme_token_ids)
--> 528     self.accelerator.backward(loss / self.grad_accum_every)
    530     accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
    532 if exists(self.max_grad_norm):

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/accelerate/accelerator.py:1985, in Accelerator.backward(self, loss, **kwargs)
   1983     self.scaler.scale(loss).backward(**kwargs)
   1984 else:
-> 1985     loss.backward(**kwargs)

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File ~/Coding/github.com/lucidrains/spear-tts-pytorch/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

itsjamie commented 11 months ago

I was just testing it using the datasets that already existed in the codebase, before I wrote the dataset that takes the audio + transcript pairs from the MLS dataset.

And I'm concerned there is something off in the generated code.

Could it have been how I moved the model and dataloader to the GPU when I generated the tensors to be saved using the SemanticToTextDatasetGenerator?
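For what it's worth, that RuntimeError is the generic autograd one: it appears whenever backward is run a second time through a graph whose saved tensors have already been freed, typically because a tensor that still carries a graph is reused across training steps. A minimal standalone reproduction (purely illustrative, not a claim about where it happens in the trainer):

import torch

net = torch.nn.Linear(8, 1)
x = torch.randn(2, 8)
loss = net(x).pow(2).mean()

loss.backward()        # first backward frees the graph's saved tensors

try:
    loss.backward()    # reusing the same graph raises the error above
except RuntimeError as e:
    print(e)           # Trying to backward through the graph a second time ...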

lucidrains commented 11 months ago

@itsjamie hmm, i'm not sure, and it doesn't seem like you are using the early exit layer (for spec decoding)

do you want to try 0.4.6? i'll debug this this afternoon