yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models

ETA of training code publication #1

Closed spolezhaev closed 1 year ago

spolezhaev commented 1 year ago

Thank you for your work! Is there any ETA on when the training and inference code will become available?

yl4579 commented 1 year ago

It will likely be July or August. It depends on how well my other project goes and how many GPU resources I get in my lab. The GPUs are very busy during the summer as students are working on their projects full time, and some other projects will take priority before I can work on cleaning and testing the code. If you are interested, you can start with the StyleTTS w/ PL-BERT code and try to implement it yourself. I believe the most important parts are the adversarial training and style diffusion, so I will provide the code snippets here. The code is neither tested nor cleaned, but it was copied from the Jupyter notebook I ran the experiments with.

Style diffusion (you will need pip install audio-diffusion-pytorch==0.0.96, this specific version, as it has EDM; in the cleaned-up code I will try to incorporate it into the codebase):

import torch
from torch import nn, Tensor
from typing import Optional
from einops import reduce
from einops.layers.torch import Rearrange

from audio_diffusion_pytorch.modules import *

class Transformer1d(nn.Module):
    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        use_context_time: bool = True,
        use_rel_pos: bool = False,
        context_features_multiplier: int = 1,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
        context_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        embedding_max_length: int = 512,
    ):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=channels + context_embedding_features,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
                for _ in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            Conv1d(
                in_channels=channels + context_embedding_features,
                out_channels=channels,
                kernel_size=1,
            ),
        )

        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time

        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]

        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)

        return mapping

    def run(self, x, time, embedding, features):

        # Broadcast the noisy style vector over the text sequence and concatenate
        # it with the (PL-BERT) embeddings along the channel dimension.
        mapping = self.get_mapping(time, features)
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)

        for block in self.blocks:
            x = x + mapping  # inject time/feature conditioning before each block
            x = block(x)

        # Mean-pool over the sequence back to a single style token
        x = x.mean(axis=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)

        return x

    def forward(self, x: Tensor,
                time: Tensor,
                embedding_mask_proba: float = 0.0,
                embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None,
                embedding_scale: float = 1.0) -> Tensor:

        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            # Compute both normal and fixed embedding outputs
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            # Scale conditional output using classifier-free guidance
            return out_masked + (out - out_masked) * embedding_scale
        else:
            return self.run(x, time, embedding=embedding, features=features)

transformer = Transformer1d(
                num_layers=3,
                channels=256,
                num_heads=8,
                head_features=64,
                multiplier=2,
                context_embedding_features=768,
            )

from audio_diffusion_pytorch import (AudioDiffusionConditional, DiffusionSampler,
                                     KDiffusion, LogNormalDistribution,
                                     ADPM2Sampler, KarrasSchedule)

diffusion = AudioDiffusionConditional(
    in_channels=1,
    embedding_max_length=512,
    embedding_features=768,
    embedding_mask_proba=0.1, # conditional dropout of batch elements
    multipliers=[1, 2],
    channels=256,
    patch_size=16,
    factors=[2],
    attentions=[0, 1],
    num_blocks=[2]
)

# Replace the default UNet with the transformer above, then rebuild the
# EDM diffusion around it.
diffusion.diffusion.net = transformer
diffusion.unet = transformer
diffusion.diffusion = KDiffusion(
    net=diffusion.unet,
    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    sigma_data=0.2,
    dynamic_threshold=0.0
)

sampler = DiffusionSampler(
    model.diffusion.diffusion,  # the KDiffusion above, registered as model.diffusion in the notebook
    num_steps=5, 
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    clamp=False
)
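
Not from the notebook, but a minimal usage sketch of the pieces above: drawing one 256-dim style vector per utterance, with a random tensor standing in for the PL-BERT output (bert_dur), and the 128/128 acoustic/prosodic split taken from the training loop below.

import torch

bert_dur = torch.randn(4, 512, 768)       # placeholder PL-BERT hidden states
noise = torch.randn(4, 1, 256)            # one noisy style vector per utterance
s_pred = sampler(noise=noise,
                 embedding=bert_dur,
                 embedding_scale=1.0,
                 num_steps=5).squeeze(1)  # (4, 256)
s, s_dur = s_pred[:, :128], s_pred[:, 128:]  # acoustic style, prosodic style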

WavLM discriminator:

import torch
import torchaudio
from transformers import WavLMModel

class WavLMLoss(torch.nn.Module):

    def __init__(self, model, wd):
        """Initialize the WavLM-based SLM adversarial loss module."""
        super(WavLMLoss, self).__init__()
        self.wavlm = WavLMModel.from_pretrained(model)
        self.wd = wd  # the SLM discriminator head
        self.resample = torchaudio.transforms.Resample(24000, 16000)

    def wd_forward(self, wav, y_rec):
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
        y_rec_16 = self.resample(y_rec)
        y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states

        y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
        y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        y_d_rs = self.wd(y_embeddings)
        y_d_gs = self.wd(y_rec_embeddings)

        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

        loss_gen_f = torch.mean((1-y_df_hat_g)**2)  # LSGAN generator loss
        loss_rel = 0

        loss_gen_all = loss_gen_f + loss_rel

        return loss_gen_all

    def wd_discriminator(self, wav, y_rec):
        with torch.no_grad():
            wav_16 = self.resample(wav)
            wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
            y_rec_16 = self.resample(y_rec)
            y_rec_embeddings = self.wavlm(input_values=y_rec_16, output_hidden_states=True).hidden_states

            y_embeddings = torch.stack(wav_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
            y_rec_embeddings = torch.stack(y_rec_embeddings, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        y_d_rs = self.wd(y_embeddings)
        y_d_gs = self.wd(y_rec_embeddings)

        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

        r_loss = torch.mean((1-y_df_hat_r)**2)  # real -> 1
        g_loss = torch.mean((y_df_hat_g)**2)    # fake -> 0

        loss_disc_f = r_loss + g_loss        
        loss_rel = 0
        d_loss = loss_disc_f + loss_rel

        return d_loss.mean()

wl = WavLMLoss('microsoft/wavlm-base-plus', model.wd).to('cuda')
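
For reference, model.wd above is not defined in these snippets. A plausible stand-in (an assumption, not the exact module used in the paper) is a small 1-D conv discriminator sized to the stacked hidden states of wavlm-base-plus, i.e. 13 layers x 768 dims after the flatten in WavLMLoss:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

class SLMDiscriminator(nn.Module):
    """Hypothetical discriminator over stacked WavLM hidden states.

    Input: (batch, 13 * 768, frames), as produced in WavLMLoss above.
    Output: one logit per frame position, flattened per batch element.
    """
    def __init__(self, slm_hidden=768, slm_layers=13, channels=64):
        super().__init__()
        self.pre = weight_norm(nn.Conv1d(slm_hidden * slm_layers, channels, 1))
        self.convs = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels * 2, 5, padding=2)),
            weight_norm(nn.Conv1d(channels * 2, channels * 4, 5, padding=2)),
            weight_norm(nn.Conv1d(channels * 4, channels * 4, 5, padding=2)),
        ])
        self.post = weight_norm(nn.Conv1d(channels * 4, 1, 3, padding=1))

    def forward(self, x):
        x = self.pre(x)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
        return self.post(x).flatten(1)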

Adversarial training run:

    # Fragment from the notebook: assumes numpy as np, torch.nn.functional as F,
    # helpers like length_to_mask from the StyleTTS codebase, and the model /
    # optimizer / sampler objects defined above.
    for i, batch in enumerate(train_dataloader):
        waves = batch[0]
        batch = [b.to(device) for b in batch[1:]]
        texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, labels = batch

        # ... joint training code omitted 

        use_ind = np.random.rand() < 0.5

        if use_ind:
            ref_lengths = input_lengths
            ref_texts = texts

        text_mask = length_to_mask(ref_lengths).to(texts.device)
        bert_dur = model.bert(ref_texts, attention_mask=(~text_mask).int()).last_hidden_state
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2) 

        if use_ind and np.random.rand() < 0.5:
            s_preds = s_trg  # ground-truth style from the omitted joint-training code
        else:
            num_steps = np.random.randint(3, 5)
            s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to('cuda'),
                              embedding=bert_dur,
                              embedding_scale=1,
                              embedding_mask_proba=0.1,
                              num_steps=num_steps).squeeze(1)

        s_dur = s_preds[:, 128:]  # prosodic style
        s = s_preds[:, :128]      # acoustic style

        d, _ = model.predictor(d_en, s_dur,
                               ref_lengths,
                               torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to('cuda'),
                               text_mask)

        output_lengths = []
        attn_preds = []
        # d_gt (ground-truth durations) comes from the omitted joint-training code
        for _s2s_pred, _text_input, _text_length in zip(d, d_gt, ref_lengths):

            _s2s_pred_org = _s2s_pred[:_text_length, :]

            _s2s_pred = torch.sigmoid(_s2s_pred_org)
            _dur_pred = _s2s_pred.sum(axis=-1)
            _text_input = _text_input[:_text_length].long()

            l = int(torch.round(_s2s_pred.sum()).item())  # predicted total length in frames
            t = torch.arange(0, l).unsqueeze(0).expand((len(_s2s_pred), l)).to('cuda')
            loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2

            # Convolve each phoneme's duration logits with a Gaussian centered at
            # its cumulative location to get a soft, differentiable alignment,
            # then normalize over phonemes.
            sig = 1.5
            h = torch.exp(-0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (sig)**2)

            out = torch.nn.functional.conv1d(_s2s_pred_org.unsqueeze(0),
                                             h.unsqueeze(1),
                                             padding=h.shape[-1] - 1, groups=int(_text_length))[..., :l]
            attn_preds.append(F.softmax(out.squeeze(), dim=0))

            output_lengths.append(l)

        max_len = max(output_lengths)

        with torch.no_grad():
            t_en = model.text_encoder(ref_texts, ref_lengths, text_mask)

        s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to('cuda')
        for bib in range(len(output_lengths)):
            s2s_attn[bib, :ref_lengths[bib], :output_lengths[bib]] = attn_preds[bib]

        asr_pred = t_en @ s2s_attn

        _, p_pred = model.predictor(d_en, s_dur,
                                    ref_lengths,
                                    s2s_attn,
                                    text_mask)

        # training segment length (in frames), clamped to [200, 250]
        mel_len = max(int(min(output_lengths) / 2 - 1), 200)
        mel_len = min(mel_len, 250)

        en = []
        p_en = []
        sp = []
        l = []

        F0_fakes = []
        N_fakes = []

        for bib in range(len(output_lengths)):
            mel_length_pred = output_lengths[bib]
            mel_length_gt = int(mel_input_length[bib].item() / 2)
            if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
                continue

            sp.append(s_preds[bib])

            random_start = np.random.randint(0, mel_length_pred - mel_len)
            en.append(asr_pred[bib, :, random_start:random_start+mel_len])
            p_en.append(p_pred[bib, :, random_start:random_start+mel_len])

            l.append(labels[bib])

        if len(sp) <= 1:
            continue

        sp = torch.stack(sp)

        en = torch.stack(en)
        p_en = torch.stack(p_en)

        labels = torch.stack(l)

        F0_fake, N_fake = model.predictor.F0Ntrain(p_en, sp[:, 128:])  # prosodic style
        y_pred = model.decoder(en, F0_fake, N_fake, sp[:, :128])       # acoustic style

        wav = y_rec_gt_pred  # reconstructed ground truth from the omitted joint-training code

        optimizer.zero_grad()
        d_loss = wl.wd_discriminator(wav.detach().squeeze(), y_pred.detach().squeeze()).mean()

        d_loss.backward()
        optimizer.step('wd')

        # generator loss
        optimizer.zero_grad()

        loss_gen_lm = wl.wd_forward(wav.squeeze(), y_pred.squeeze())

        loss_gen_lm = loss_gen_lm.mean()

        loss_gen_lm.backward(retain_graph=True)

        # per-module gradient norms, used to rescale exploding gradients below
        total_norm = {}
        for key in model.keys():
            total_norm[key] = 0
            parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
            for p in parameters:
                param_norm = p.grad.detach().data.norm(2)
                total_norm[key] += param_norm.item() ** 2
            total_norm[key] = total_norm[key] ** 0.5

        if total_norm['predictor'] > 20: # rescale everything if the predictor's gradient norm explodes
            for key in model.keys():
                for p in model[key].parameters():
                    if p.grad is not None:
                        p.grad *= 2 * (1 / total_norm['predictor']) 

        # scale down the SLM-adversarial gradients for the duration and diffusion modules
        for p in model.predictor.duration_proj.parameters():
            if p.grad is not None:
                p.grad *= 1e-2

        for p in model.predictor.lstm.parameters():
            if p.grad is not None:
                p.grad *= 1e-2

        for p in model.diffusion.parameters():
            if p.grad is not None:
                p.grad *= 1e-2

        optimizer.step('bert_encoder')
        optimizer.step('bert')
        optimizer.step('predictor')
        optimizer.step('diffusion')
        optimizer.step('style_encoder')
        optimizer.step('predictor_encoder') # this is the prosodic style encoder, will rename it later in cleaned-up code
        optimizer.step('decoder')

I will leave this issue open in case someone else is interested in the implementation.

spolezhaev commented 1 year ago

Thank you for such a quick response! The style diffusion part was the primary concern for the implementation. Will try to reimplement the paper, thanks!

talipturkmen commented 1 year ago

Hello, thank you for sharing the code. Could you also share the duration predictor part?

yl4579 commented 1 year ago

@talipturkmen It is a very simple change:

Change self.duration_proj = LinearNorm(d_hid, 1) to self.duration_proj = LinearNorm(d_hid, L), where L = 50.
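
For context, a runnable sketch of what that change buys (nn.Linear standing in for LinearNorm, d_hid = 512 assumed): instead of regressing one scalar duration per phoneme, the predictor emits up to 50 per-frame logits whose sigmoid-sum is the duration, exactly as decoded in the training loop above.

import torch
from torch import nn

duration_proj = nn.Linear(512, 50)            # was nn.Linear(512, 1)

x = torch.randn(4, 77, 512)                   # (batch, phonemes, d_hid)
dur_logits = duration_proj(x)                 # (4, 77, 50)
dur = torch.sigmoid(dur_logits).sum(dim=-1)   # expected duration in frames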

shreyasinghal-17 commented 1 year ago

Any updates?

yl4579 commented 1 year ago

@spolezhaev Sorry for keeping you waiting. I'm now halfway done with my other projects and have started working on code cleaning. I will make sure the code is available by the end of this month.

nivibilla commented 1 year ago

Just wanted to say: amazing work. This is reaching almost Tortoise-TTS-level quality, but with the super-fast inference of StyleTTS. Can't wait to try this model out and maybe even fine-tune it!

WendongGan commented 1 year ago

Looking forward to it!

yl4579 commented 1 year ago

@nivibilla Not sure if you are referring to this one, but perceptually speaking it doesn't even sound better than VITS with language models (BERT), so do you mind explaining why people are so excited about it, especially given its insanely slow inference speed? Is it because it was trained on millions of hours of speech, so people like its zero-shot speaker adaptation ability?

ghost commented 1 year ago

No, because one can fine-tune it too, on any voice. God forbid, if it had low VRAM usage and some genuine way to speed up inference, I doubt there would be any need to look for any other low-resource human-level speech synthesis model.

nivibilla commented 1 year ago

@yl4579 The original author removed the fine-tuning code and nerfed it on purpose. However, there are other branches that people have made, and even in my own fine-tuning tests it can capture the style and nuances of a voice really well. Out of the box it isn't that good, but fine-tuning on just an hour of data or so yields the most natural-sounding TTS.

yl4579 commented 1 year ago

@nivibilla @exllama-fan I think that makes sense when the base model is large enough and pre-trained on enormous datasets, but unfortunately I don't believe you can fine-tune StyleTTS 2 to get similar performance, especially with only one hour of data, because the biggest model I have so far was trained on only 585 hours of data (LibriTTS-R), incomparable to Tortoise TTS trained on millions of hours. I simply don't have that huge amount of data (I believe Tortoise TTS was trained by someone from OpenAI?). I do plan to train a model on Multilingual LibriSpeech with speech restoration like that used in LibriTTS-R, but our lab has other priorities for GPUs, so I'm not sure when I will have time for this.

astricks commented 1 year ago

@yl4579 I am planning on writing a multilingual extension for this model myself. In addition to adding a language embedding, from what I can tell, both the WavLM model as well as the PL-BERT model will have to be replaced with multilingual versions?

I have access to some GPU resources for a few weeks, happy to train a model and share back results.

yl4579 commented 1 year ago

For those who are waiting for the code, I apologize for the delay, but I'm having difficulties reproducing the results I got from Jupyter notebooks. After a few weeks of code cleaning, I found there were substantial performance differences between models trained with the original notebooks and with the cleaned code. I'm still investigating the causes here, so the code release will be delayed. If anyone is interested in reproducing the results with the notebooks I have and helping me clean the code, please email me at yl4579@columbia.edu and I'm happy to provide the Jupyter notebooks and the cleaned code I used for the experiments.

bhairavmehta95 commented 1 year ago

just shot you a note from my gmail @yl4579 - happy to help clean + retrain

astricks commented 1 year ago

@yl4579 +1, sent you an email, happy to help clean the code and reproduce the results.

WendongGan commented 1 year ago

+1,happy to help clean the code and reproduce the results.

yl4579 commented 1 year ago

@WendongGan I didn't get your email. Please email me at yl4579@columbia.edu.

WendongGan commented 1 year ago

@yl4579 I'm sorry! I just sent it, thank you for checking.

thanhkm commented 1 year ago

@yl4579 +1, just sent you an email. Happy to help!

lovebeatz commented 1 year ago

Is there any progress with the code-cleaning? Any further update on code-release?

yl4579 commented 1 year ago

@lovebeatz So far only one person has confirmed that the stage 1 (acoustic pretraining) code is probably fine, but nobody has reported any success in fixing the second-stage training (joint training), which shows a discrepancy in F0 and norm loss between the Jupyter notebook and the cleaned code. I probably won't have time to work on this until the end of this month, but hopefully someone can get it fixed sooner.

yl4579 commented 1 year ago

This issue is inexplicably weird. I’m not sure if I should release the problematic code to the public anyway and mark it as WIP or just wait for volunteers who have emailed me to work on it a little bit longer.

lovebeatz commented 1 year ago

So the notebook code works well? Do you have the pre-trained model ready? Also, what about the inference code?

yl4579 commented 1 year ago

The inference code seems to work; at least I didn't notice any clear degradation in the quality of synthesized speech between the cleaned code and the notebook I used to run the experiments. I couldn't find the exact checkpoints I used for the experiments in the paper, but I do have trained checkpoints as a reference. I can share those with you, though I'm not sure how useful they really are. You can email me for the inference code if you need it.

lovebeatz commented 1 year ago

I am looking to fine-tune using a notebook and infer using the cleaned code.

Also, I want to know: will a separate fine-tuned model be required for every new voice?

yl4579 commented 1 year ago

Unfortunately, the cleaned code still doesn't work at this time because it produces higher losses with worse quality. The notebook is uncleaned and uncommented, so I don't think you can use it to fine-tune anything at this point unless you know how to read the code and modify it for your own purpose.

As for voices, I assume you mean speakers, so no, if you have multiple speakers you want to fine-tune with, you only need one model for as many speakers as you want. This is also the exact point of zero-shot speaker adaptation.

lovebeatz commented 1 year ago

But I read somewhere, as you mentioned, that for inference StyleTTS 2 won't require any speaker reference, unlike StyleTTS.

Also, given what I saw in StyleTTS: does this new version handle sentence breaks in speech? StyleTTS wouldn't pause at a full stop (if there was a sentence afterwards).

yl4579 commented 1 year ago

@lovebeatz StyleTTS 2 doesn't require any reference for single-speaker models, but it still needs a reference from the target speaker for multispeaker models, because it needs to know which speaker you are about to synthesize. If your goal is to train a single-speaker model, like on the LJSpeech dataset, you don't need a reference.
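
In code terms, a sketch only: it assumes the multispeaker model is built with the context_features hook of Transformer1d above active, and that mel_ref is a mel spectrogram of the target speaker (both assumptions, not confirmed wiring).

# Multispeaker: encode a reference from the target speaker and pass it
# through the `features` conditioning of the style diffusion sampler.
with torch.no_grad():
    ref = torch.cat([model.style_encoder(mel_ref.unsqueeze(1)),       # acoustic (128)
                     model.predictor_encoder(mel_ref.unsqueeze(1))],  # prosodic (128)
                    dim=-1)
s_pred = sampler(noise=torch.randn(1, 1, 256).to('cuda'),
                 embedding=bert_dur,
                 features=ref,        # omitted for single-speaker models
                 embedding_scale=1.0,
                 num_steps=5).squeeze(1)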

As for the pauses, yes, StyleTTS 2 does have sentence breaks. I just synthesized your sentences above and confirmed it can (and surprisingly StyleTTS w/ PL-BERT can't even do that for some reason):

StyleTTS w/ PL-BERT: https://drive.google.com/file/d/1llMmllk9QyGYXBqsbRVjKzPQxcY7XQvW/view?usp=sharing
StyleTTS 2: https://drive.google.com/file/d/1b_Bp9sLAOKPv9HMmAQD4SdrBC3GKHg_W/view?usp=sharing

VladisIove commented 1 year ago

@yl4579 +1, I sent you an email. Can I help?

yl4579 commented 1 year ago

Funny enough, I might have identified where the problem is. I made a mistake in the notebook on the F0 model, and this mistake actually positively affected the model training for StyleTTS 2 (not StyleTTS). Hopefully, the code will be ready by the end of this month, though there's no guarantee because there might be some other mistakes, especially for the SLM adversarial training and diffusion part.

PleinAcces commented 1 year ago

Hi @yl4579 I sent you an e-mail, happy if I can help

Souvic commented 1 year ago

So that means the final code will have that mistake, right? Also, what is the deviation from the paper that helped you, if I may ask? Seems interesting if it improves performance.

yl4579 commented 1 year ago

@Souvic In the notebook, I accidentally fixed the F0 model (put it inside the with torch.no_grad() block), and it actually produces a better model in the second stage, where the F0 loss is much lower than when the F0 model is tuned. This is not the case for StyleTTS; I think it could be because the additional waveform decoder and extra discriminators destabilize the F0 model training.
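
Concretely, the accidental fix amounts to freezing the pretrained F0 network when computing pitch targets in the second stage (a sketch; pitch_extractor and mels follow the naming conventions of the StyleTTS codebase and may differ in the notebook):

# Intended: the F0 model is fine-tuned jointly, so its outputs carry gradients.
# F0_real, _, _ = model.pitch_extractor(mels.unsqueeze(1))

# Accidental version that trains better for StyleTTS 2: the F0 model stays
# fixed because its forward pass runs under no_grad.
with torch.no_grad():
    F0_real, _, _ = model.pitch_extractor(mels.unsqueeze(1))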

However, there are still some problems in the SLM adversarial training code that need fixing; hopefully I can fix them by the end of this month so the code can be released by then.

PleinAcces commented 1 year ago

@yl4579 Do you plan to release the notebook too?

yl4579 commented 1 year ago

@PleinAcces The notebook is too messy, so I don't think it is a good idea to make it public. If you want to take a look at the notebook, though, I can add you to the private code-cleaning repo.

PleinAcces commented 1 year ago

@yl4579 Yes please!

yl4579 commented 1 year ago

@PleinAcces Done, but I’ve fixed the code now so I believe the actual working code will be ready soon. Hopefully by the end of this month!

yl4579 commented 1 year ago

I'm still having trouble making the code work with accelerate (DDP). Do you think it is acceptable to release the DP version (less efficient in terms of RAM use) first and work on DDP later?

yl4579 commented 1 year ago

DDP works perfectly fine for the first stage, but the pain comes in the second stage, especially when the SLM adversarial training run starts. It uses the most RAM, and the F0 loss is still higher with DDP for some reason.

lovebeatz commented 1 year ago

If it doesn't affect inference or the quality of the model, I think it should be released, if you are looking for an opinion.

yl4579 commented 1 year ago

It's less efficient in terms of RAM (and also speed, because you have to reduce the batch size), and you can't easily do mixed precision this way either, but it's indeed just a matter of engineering, so it can be dealt with by people more fluent in programming than me. So I'm thinking of releasing the DP version first if I can't fix the code by the end of the month, and seeing if anyone else is interested in fixing the DDP version.
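
For those unfamiliar with the tradeoff, the DP stopgap is just single-process replication (a sketch, assuming model is the dict-like collection of modules used in the snippets above):

import torch.nn as nn

# DataParallel: replicate each module across GPUs inside one process.
# Simpler than DDP, but slower (gradients are reduced on one device) and
# more RAM-hungry, which is why the batch size has to shrink.
for key in model.keys():
    model[key] = nn.DataParallel(model[key]).to('cuda')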

ghost commented 1 year ago

Can you share the DDP training code? I'm training with DDP, but it's super slow.

yl4579 commented 1 year ago

@primepake I will make a new repo with the broken DDP code if I can't fix it by the end of this month so that other people can work on it later.

Pydataman commented 1 year ago

Looking forward to it!

yl4579 commented 1 year ago

Unfortunately, whenever I turn DP into DDP, the F0 loss is consistently higher, which is very weird. I will probably have to release the DP code first and see if an expert in DDP can fix it later.

teopapad92 commented 1 year ago

Thanks for the effort guys

danielmsu commented 1 year ago

Hi @yl4579, thank you for your work! I would like to reproduce the results from the paper and can also try to help with the DDP issue. May I ask you to add me to the cleaning repo, please?

yl4579 commented 1 year ago

@danielmsu The DP code is almost done, and I'll probably push it in a couple of days. I will make another public repo with the non-working DDP code for those with expertise to lend a hand.

yl4579 commented 1 year ago

@danielmsu Actually, the broken DDP code doesn't need a separate repo. I just opened a new issue #7 for this problem and copy-pasted the code there. The code can be tested under this repo directly.