lucidrains / e2-tts-pytorch

Implementation of E2-TTS, "Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS", in PyTorch
MIT License

Fix target mask at inference time #10

Closed: lucasnewman closed this 1 month ago

lucasnewman commented 1 month ago

I'm still working on getting the model to sample correctly, but I noticed a couple of issues while debugging:

1. The passed-in conditioning + target duration was ignored, so without the duration predictor in place the model never generated beyond the conditioning sequence.
2. The target mask needs to exclude the conditioning, like what's done at training time with the random span mask.
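For reference, the target mask fix in point 2 looks roughly like this (a sketch, not the actual implementation; the helper name and shapes are illustrative):

```python
import torch

def build_masks(cond_len: int, duration: int):
    # hypothetical helper: cond_mask marks the conditioning prefix within the
    # full duration, and the target positions are everything after it, so the
    # target mask excludes the conditioning (mirroring the training-time span mask)
    cond_mask = torch.zeros(duration, dtype=torch.bool)
    cond_mask[:cond_len] = True   # positions holding the conditioning audio
    target_mask = ~cond_mask      # positions the model must actually generate
    return cond_mask, target_mask

cond_mask, target_mask = build_masks(cond_len=3, duration=8)
# cond_mask:   [T, T, T, F, F, F, F, F]
# target_mask: [F, F, F, T, T, T, T, T]
```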

I also added a fast path for when CFG is disabled, so we don't compute and then throw away the unconditioned prediction.
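The fast-path idea is roughly this (a hedged sketch, not the repo's actual API; the function and parameter names are assumptions):

```python
import torch

def cfg_predict(model, x, cond, cfg_strength: float = 0.0):
    # hypothetical sketch of classifier-free guidance with a fast path:
    # when guidance is off, skip the unconditional forward pass entirely
    pred = model(x, cond=cond)
    if cfg_strength == 0.0:
        return pred  # fast path: single forward pass, nothing thrown away
    null_pred = model(x, cond=None)  # unconditioned prediction
    # push the conditioned prediction away from the unconditioned one
    return pred + (pred - null_pred) * cfg_strength

# tiny stand-in "model" that records how many forward passes happen
calls = []
def toy_model(x, cond=None):
    calls.append(cond)
    return x * 2 if cond is not None else x

out = cfg_predict(toy_model, torch.ones(4), cond="prompt", cfg_strength=0.0)
```

With `cfg_strength == 0` only one forward pass runs; with guidance enabled, both the conditioned and unconditioned passes are needed.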

lucidrains commented 1 month ago

@lucasnewman thanks for catching the duration bug! :bug:

the CFG change also makes sense, forgot about that :facepalm:

so `cond_mask` refers to the prefix conditioning length, while `mask` is the self-attention mask that should encompass both the condition and the target sequence that follows (cond sequence length + target length == duration length)

the way you have it, the transformer would not be able to attend to the condition
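concretely, the distinction looks something like this (a hedged sketch; the tensor shapes and variable names are assumptions, not the repo's exact code):

```python
import torch

# in a batch of padded sequences, the self-attention `mask` covers every
# valid frame -- condition plus target, i.e. the full duration -- while
# `cond_mask` marks only the conditioning prefix
durations = torch.tensor([8, 6])   # cond length + target length per item
cond_lens = torch.tensor([3, 2])   # prefix conditioning lengths
max_len = int(durations.max())
positions = torch.arange(max_len)  # (max_len,)

mask = positions[None, :] < durations[:, None]       # attends over cond + target
cond_mask = positions[None, :] < cond_lens[:, None]  # conditioning prefix only

# removing the condition from `mask` would stop attention to the prefix:
# every conditioning position must remain inside the attention mask
assert (mask & cond_mask).equal(cond_mask)
```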

lucasnewman commented 1 month ago

Ah, ok, my mistake! You can close this if you want!

lucidrains commented 1 month ago

@lucasnewman the other two changes look good! let me merge those in :pray: