Wataru-Nakata / miipher

Unofficial implementation of miipher
MIT License

About training #8

Open jjjanicehuang opened 3 weeks ago

jjjanicehuang commented 3 weeks ago

Hi Wataru,

I'm now at the training stage and have a few questions to ask:

Thanks!

Wataru-Nakata commented 3 weeks ago
  1. Yes, but it doesn't change anything. If you want to control the frequency of validation logging, there is an argument in the trainer to do so (see the sketch after this list).
  2. 3300 is set, but I didn't train it for 3300 epochs; I trained the model for 400k steps, which is the same as the original miipher paper.
  3. What do you mean by "not synthesized well"? Does it show any wav file?
  4. I will share my wandb log for the training of miipher. Maybe you can get some insight into how the training will go: https://wandb.ai/wataru9871/miipher/runs/f79mw0g1?nw=nwuserwataru9871
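
For reference, a minimal sketch of the relevant PyTorch Lightning Trainer arguments; the values are illustrative placeholders, not this repo's defaults, and the package import name may differ depending on your Lightning version.

import pytorch_lightning as pl

# Illustrative values only; set them to match your own config.
trainer = pl.Trainer(
    max_steps=400_000,          # train by step count, as in the original miipher setup
    max_epochs=-1,              # no epoch limit; stop on max_steps instead
    val_check_interval=5_000,   # run validation every 5000 training batches
)
# trainer.fit(model, datamodule=datamodule)
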
jjjanicehuang commented 3 weeks ago

Thanks for the hints. I did find trainer.py, and the script doesn't set max/min steps or max/min epochs, so I assume the model will figure that out itself. I also registered for wandb.ai and visualized the training process; I noticed that one epoch takes 20k steps. Given my dataset is only 100 hours with 20K+ audio samples, do you have any suggestions for speeding up the process? Waiting two weeks would be very long for me.

BTW, thanks for sharing the link.

jjjanicehuang commented 3 weeks ago

For example: will each epoch take roughly the same time to run? Will the process get faster as the epochs go on? If not, is there any way to make it faster?

Wataru-Nakata commented 3 weeks ago

Please refer to my training log on wandb for those questions.

jjjanicehuang commented 3 weeks ago

Thanks. One more question: in your datamodule.py, you specified 20000 for train and 3000 * 4 for val. What does the 4 mean?

To my understanding, if I have 24000 audio samples in total, I should assign a certain percentage to each group, i.e. 20000 for train and 4000 for val. Am I right?

def setup(self, stage: str):
    self.train_dataset = (
        wds.WebDataset(
            self.cfg.data.train_dataset_path,
            resampled=True,
            nodesplitter=wds.split_by_node,
        )
        .shuffle(1000)
        .decode(wds.torch_audio)
        # .decode(self.decode_phoneme_input)
        .repeat(2)
        # with_length only sets the nominal length reported by len(); it does not resize the data
        .with_length(20000 * self.cfg.data.train_batch_size)
    )

    self.val_dataset = (
        wds.WebDataset(
            self.cfg.data.val_dataset_path, nodesplitter=wds.split_by_node
        )
        .decode(wds.torch_audio)
        # .decode(self.decode_phoneme_input)
        .repeat(2)
        .with_length(3000 * 4 // self.cfg.data.val_batch_size)
    )

Wataru-Nakata commented 3 weeks ago

That's just a hyperparameter to limit the size of the validation and training datasets. In case you don't know, changing those hyperparameters will not make your training faster. If you want to make training faster, consider using the PyTorch profiler to find the bottleneck in the code, or use DDP to utilize multiple GPUs for training.
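
For reference, a minimal sketch of what that looks like with a standard PyTorch Lightning Trainer; the argument values are illustrative and not taken from this repo's training script.

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,            # e.g. 4 GPUs on one node
    strategy="ddp",       # distributed data parallel across those GPUs
    profiler="simple",    # prints a per-hook timing summary to help find bottlenecks
    max_steps=400_000,
)
# trainer.fit(model, datamodule=datamodule)
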

Wataru-Nakata commented 3 weeks ago

The train/val split is handled by the preprocessing stage, so if you have done the preprocessing correctly, it shouldn't be a problem.
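
For anyone unsure what that means in practice, here is a generic illustration (not this repo's preprocess script) of writing separate train/val WebDataset shards so the split is fixed before training; the corpus index and paths are placeholders.

import random
import webdataset as wds

# Placeholder corpus index: (utterance_id, wav_path) pairs from your own data.
items = [("utt0001", "/data/utt0001.wav"), ("utt0002", "/data/utt0002.wav")]
random.seed(0)
random.shuffle(items)
n_train = int(0.9 * len(items))  # e.g. 90% train / 10% val

for name, subset in (("train", items[:n_train]), ("val", items[n_train:])):
    sink = wds.ShardWriter(f"{name}-%06d.tar", maxcount=1000)
    for key, path in subset:
        with open(path, "rb") as f:
            sink.write({"__key__": key, "wav": f.read()})
    sink.close()
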

jjjanicehuang commented 3 weeks ago

OK, thank you for all the details!

jjjanicehuang commented 3 weeks ago

Btw, is there any setting related to restoring/resuming from a certain step? I.e., if the training crashes, the command could pick up from the latest ckpt instead of starting from scratch. Also, by default I noticed the model only saves a ckpt once per epoch, right?

Wataru-Nakata commented 3 weeks ago

Yes, that's right. If you want to load from a checkpoint, you can refer to the documentation on the PyTorch Lightning Trainer: https://lightning.ai/docs/pytorch/stable/common/trainer.html#fit
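
A minimal sketch of both parts with standard PyTorch Lightning (paths and step counts are placeholders, not this repo's config): a ModelCheckpoint callback that also saves every N steps, and resuming via ckpt_path.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Save "last.ckpt" plus a checkpoint every 10k training steps (placeholder values).
ckpt_cb = ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=10_000, save_last=True)

trainer = pl.Trainer(callbacks=[ckpt_cb], max_steps=400_000)
# Resume from the latest checkpoint instead of starting from scratch:
# trainer.fit(model, datamodule=datamodule, ckpt_path="checkpoints/last.ckpt")
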

jjjanicehuang commented 2 weeks ago

thanks for providing the link.

I noticed that layer 8 was specified in the config.yaml for the pre-trained WavLM-large model; may I ask why this layer was chosen? According to the paper, layer 8 is associated with Automatic Speaker Verification (ASV), but it's not the strongest layer for this function.

model:
  ssl_models:
    model:
      _target_: transformers.AutoModel.from_pretrained
      pretrained_model_name_or_path: "microsoft/wavlm-large"
    sr: 16_000
    layer: 8
Wataru-Nakata commented 2 weeks ago

We followed the original miipher paper, which used the 8th layer of w2v-BERT. However, other layers might be better; I didn't perform any investigation on this.
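
For anyone reading along, a minimal sketch using the standard Hugging Face transformers API (not this repo's own SSL wrapper) of how features from a specific WavLM layer can be pulled out:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("microsoft/wavlm-large").eval()

wav = torch.randn(1, 16_000)  # (batch, samples): one second of dummy 16 kHz audio

with torch.no_grad():
    out = model(wav, output_hidden_states=True)

# out.hidden_states is a tuple of (num_transformer_layers + 1) tensors;
# index 0 is the convolutional feature encoder output, so index 8 is the 8th transformer layer.
layer8_feats = out.hidden_states[8]  # shape: (1, num_frames, 1024)
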