openai / jukebox

Code for the paper "Jukebox: A Generative Model for Music"
https://openai.com/blog/jukebox/

Learning rate annealing: key errors #124

Open gogobd opened 4 years ago

gogobd commented 4 years ago

I was trying to anneal (cool off) the learning rate, following the examples for building models from scratch, but when I run

# python train.py --hps=small_vqvae,small_upsampler,all_fp16,cpu_ema --name=small_upsampler --sample_length=262144 --bs=4 --audio_files_dir=[………] --labels=False --train --test --aug_shift --aug_blend --restore_vqvae=logs/small_vqvae/checkpoint_latest.pth.tar --prior --levels=2 --level=0 --weight_decay=0.01 --save_iters=1000 --restore_prior=logs/small_prior/checkpoint_latest.pth.tar --lr_use_linear_decay --lr_start_linear_decay=547 --lr_decay=18

but I get

RuntimeError: Error(s) in loading state_dict for SimplePrior:
Missing key(s) in state_dict: "conditioner_blocks.0.x_emb.weight", "conditioner_blocks.0.cond.model.0.weight", "conditioner_blocks.0.cond.model.0.bias", "conditioner_blocks.0.cond.model.1.0.blocks.0.model.1.weight", [………] , "conditioner_blocks.0.cond.model.3.1.weight", "conditioner_blocks.0.cond.model.3.1.bias", "conditioner_blocks.0.ln.weight", "conditioner_blocks.0.ln.bias".
> /opt/miniconda/lib/python3.7/site-packages/torch/nn/modules/module.py(830)load_state_dict()

I don't know which other file I need to restore (I'm using 'logs/small_prior/checkpoint_latest.pth.tar'). Can someone help me with this?
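PyTorch's `load_state_dict` uses `strict=True` by default, so it raises when the model expects parameters the checkpoint does not contain. Grouping the missing names by prefix shows at a glance whether a whole submodule (here `conditioner_blocks.0`) is absent, which would suggest the checkpoint was saved from a model built with a different configuration than the one being restored into. The following is a minimal pure-Python sketch; `summarize_missing` is a hypothetical helper, not part of Jukebox. In a real session the two key lists would come from `model.state_dict().keys()` and from the checkpoint loaded with `torch.load(path, map_location="cpu")`; which sub-key of the checkpoint dict holds the weights (possibly `"model"`) is an assumption, so print the checkpoint's top-level keys to confirm.

```python
from collections import Counter


def _group_by_prefix(keys, depth):
    """Count state_dict keys by their first `depth` dot-separated components."""
    return Counter(".".join(key.split(".")[:depth]) for key in keys)


def summarize_missing(model_keys, checkpoint_keys, depth=2):
    """Compare the keys a model expects with the keys a checkpoint provides.

    Returns two Counters keyed by prefix: keys missing from the checkpoint
    (what strict loading reports as "Missing key(s)") and keys the checkpoint
    has but the model does not ("Unexpected key(s)").
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = model_keys - checkpoint_keys
    unexpected = checkpoint_keys - model_keys
    return _group_by_prefix(missing, depth), _group_by_prefix(unexpected, depth)


if __name__ == "__main__":
    # Illustrative key names: the three conditioner keys come from the error
    # above; "example.weight" is a made-up stand-in for a shared parameter.
    model_keys = [
        "example.weight",
        "conditioner_blocks.0.x_emb.weight",
        "conditioner_blocks.0.ln.weight",
        "conditioner_blocks.0.ln.bias",
    ]
    checkpoint_keys = ["example.weight"]
    missing, unexpected = summarize_missing(model_keys, checkpoint_keys)
    print("missing:", dict(missing))        # missing: {'conditioner_blocks.0': 3}
    print("unexpected:", dict(unexpected))  # unexpected: {}
```

If every missing key shares one prefix and nothing is unexpected, the checkpoint most likely lacks that submodule entirely rather than storing it under different names.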

ObscuraDK commented 4 years ago

I have been using these parameters. I don't know if they're right, but I've started to get pretty decent output.

mpiexec -n 3 python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior --sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir=/home/vertigo/jukebox/learning2 --labels=False --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 --restore_prior=/home/vertigo/jukebox/logs/pretrained_vqvae_small_prior/checkpoint_latest.pth.tar --lr_use_linear_decay --lr_start_linear_decay=0 --lr_decay=0.9

btrude commented 4 years ago

--lr_decay=0.9

@ObscuraDK This means you are only invoking the LR decay for the last fraction of a step during training (i.e., it is essentially doing nothing unless you are working with an absolutely massive dataset). 1 step = 1/x iteration during training. I can't speak to the effective number of steps, since I don't know what you are training on, but I am finding that with small datasets the default LR is probably too high, and the decay should maybe persist for the entire duration of training (I have yet to test this). You might find this link helpful for understanding what's going on here: https://www.jeremyjordan.me/nn-learning-rate/
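To make the point about `--lr_decay` concrete, here is a minimal sketch of a linear-decay multiplier. It assumes that `--lr_use_linear_decay` holds the learning-rate multiplier at 1.0 until `--lr_start_linear_decay` and then ramps it linearly to 0.0 over `--lr_decay` further steps; that is a reading of the flag names, so check the optimizer and scheduler setup in `jukebox/train.py` to confirm the exact formula (and whether training stops once the multiplier reaches 0). The function name `linear_decay_scale` is hypothetical.

```python
def linear_decay_scale(step, start_decay, decay_steps):
    """Multiplier applied to the base learning rate at `step`.

    Stays at 1.0 until `start_decay`, then falls linearly to 0.0 over
    `decay_steps` further steps and is clamped at 0.0 afterwards.
    """
    progress = max(0, step - start_decay) / decay_steps
    return max(0.0, 1.0 - progress)


if __name__ == "__main__":
    # Values from the first post: --lr_start_linear_decay=547 --lr_decay=18
    for step in (500, 547, 556, 565):
        print(step, linear_decay_scale(step, 547, 18))
    # Values from the command above: --lr_start_linear_decay=0 --lr_decay=0.9
    print(1, linear_decay_scale(1, 0, 0.9))
```

Under this reading, `--lr_decay` is a length measured in steps: with `--lr_start_linear_decay=547 --lr_decay=18` the multiplier is 1.0 at step 547, 0.5 at step 556 and 0.0 from step 565 on, while with `--lr_decay=0.9` it is already 0.0 one step after the decay starts, so the whole ramp fits inside less than a single step.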

leonardog27 commented 3 years ago

We are training a level 2 prior using Colab. We have a group on Discord: https://discord.com/invite/6At7WwM. What special setup do we need to apply to the dataset for lyrics and non-lyrics training?