openai / jukebox

Code for the paper "Jukebox: A Generative Model for Music"
https://openai.com/blog/jukebox/

Argument missing bug when sampling #257

Open jwb95 opened 2 years ago

jwb95 commented 2 years ago

I followed the training procedure in the README, using the following commands:

small_vqvae:
python jukebox/train.py --hps=small_vqvae --name=my_project --sample_length=262144 --bs=4 \
    --audio_files_dir=/path/to/my/wavs --labels=False --train --aug_shift --aug_blend

small_prior:
python jukebox/train.py --hps=small_vqvae,small_prior,all_fp16,cpu_ema --name=small_prior \
    --sample_length=2097152 --bs=4 --audio_files_dir=/path/to/my/wavs --labels=False --train --test --aug_shift --aug_blend --restore_vqvae=/path/to/small_vqvae.tar --prior --levels=2 --level=1 --weight_decay=0.01 --save_iters=1000

small_upsampler:
python jukebox/train.py --hps=small_vqvae,small_upsampler,all_fp16,cpu_ema --name=small_upsampler \
    --sample_length=262144 --bs=4 --audio_files_dir=/path/to/my/wavs --labels=False --train --test --aug_shift --aug_blend \
    --restore_vqvae=/path/to/my/small_vqvae.tar --prior --levels=2 --level=0 --weight_decay=0.01 --save_iters=1000
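As a rough check of the lengths used above: the sampling log further down reports "Raw to tokens:32" for level 0 and "Cond downsample:8", which puts the level-1 hop at 32 × 8 = 256 audio samples per token, consistent with its "multiple of 256" message. The sketch below assumes exactly those two hop lengths (the helper names are hypothetical, not Jukebox functions) and computes how many tokens each training window spans and how a requested duration in seconds rounds down to whole top-level tokens.

```python
# Hop lengths (audio samples per token), assumed from the sampling log:
# "Raw to tokens:32" at level 0 and "Cond downsample:8" from level 0 to level 1.
HOPS = {0: 32, 1: 32 * 8}
SR = 22050  # sample rate passed to sample.py via --sr


def tokens_per_example(sample_length: int, level: int) -> int:
    """Hypothetical helper: tokens covered by a window of `sample_length` samples."""
    hop = HOPS[level]
    assert sample_length % hop == 0, "sample_length should be a multiple of the hop"
    return sample_length // hop


def round_down_to_hop(seconds: float, hop: int, sr: int = SR) -> int:
    """Hypothetical helper: round a duration down to a whole number of tokens, in samples."""
    return int(seconds * sr) // hop * hop


# Training windows: small_upsampler at level 0, small_prior at level 1
print(tokens_per_example(262144, level=0))   # 8192
print(tokens_per_example(2097152, level=1))  # 8192

# --sample_length_in_seconds=20 at 22050 Hz, rounded to the level-1 hop
n = round_down_to_hop(20, HOPS[1])
print(n, n / SR)  # 440832 19.99238095238095
```

Both training windows come out at 8192 tokens, and the rounded sampling length reproduces the 440832-sample figure in the log.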

I then added the following to the end of hparams.py:

my_small_vqvae = Hyperparams(
    restore_vqvae='/path/to/my/small_vqvae.tar'
)
my_small_vqvae.update(small_vqvae)
HPARAMS_REGISTRY["my_small_vqvae"] = my_small_vqvae

my_small_prior = Hyperparams(
    restore_prior='/path/to/my/small_prior.tar',
    level=1,
    labels=False,
    # TODO For the two lines below, if `--labels` was used and the model is
    # trained with lyrics, find and enter the layer, head pair that has learned
    # alignment.
    alignment_layer=47,
    alignment_head=0,
)
my_small_prior.update(small_prior)
HPARAMS_REGISTRY["my_small_prior"] = my_small_prior

my_small_upsampler = Hyperparams(
    restore_prior='/path/to/my/small_upsampler.tar',
    level=0,
    labels=False,
)
my_small_upsampler.update(small_upsampler)
HPARAMS_REGISTRY["my_small_upsampler"] = my_small_upsampler
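Each block above builds a small Hyperparams object and then calls .update() with the base set. Assuming Hyperparams behaves like a plain dict subclass with attribute access (the class below is a simplified stand-in, not the repo's code, and the base-set values are made up), a key present in both objects ends up with the value from the argument to update(), so the custom values only survive for keys the base set does not define.

```python
class Hyperparams(dict):
    """Simplified, hypothetical stand-in: a dict with attribute-style access."""

    def __getattr__(self, attr):
        try:
            return self[attr]
        except KeyError as e:
            raise AttributeError(attr) from e

    def __setattr__(self, attr, value):
        self[attr] = value


# Hypothetical base set (values made up for illustration).
base_prior = Hyperparams(n_ctx=8192, prior_width=128)

# Same pattern as in hparams.py above: custom keys first, then update() with the base.
my_prior = Hyperparams(restore_prior="/path/to/my/small_prior.tar", level=1, n_ctx=4096)
my_prior.update(base_prior)
print(my_prior.n_ctx)  # 8192: on a shared key, the base set's value wins
print(my_prior.level)  # 1: kept, since the base set has no "level" key

# Reversing the order makes the custom values win on shared keys.
merged = Hyperparams(base_prior)
merged.update(restore_prior="/path/to/my/small_prior.tar", level=1, n_ctx=4096)
print(merged.n_ctx)  # 4096
```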

and updated the MODELS dict in make_models.py as follows:

MODELS = {
    '5b': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b"),
    '5b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b_lyrics"),
    '1b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_1b_lyrics"),
    'my_model': ("small_vqvae", "my_small_upsampler", "small_prior")
    #'your_model': ("you_vqvae_here", "your_upsampler_here", ..., "you_top_level_prior_here")
}
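For reference, the traceback shows how make_model consumes this tuple: the first name is the VQ-VAE, and the remaining names are indexed by level (priors[level]) and each is passed to setup_hparams as a registry name. The sketch below mimics only that name lookup. The registry is a hypothetical toy reduced to the restore_*/level/labels keys added in hparams.py above, the repo's base sets ("small_vqvae", "small_prior") are assumed to define none of those keys, and building every level by default is also an assumption.

```python
# Toy registry: hypothetical, reduced to the keys added in hparams.py above.
HPARAMS_REGISTRY = {
    "small_vqvae": {},
    "my_small_vqvae": {"restore_vqvae": "/path/to/my/small_vqvae.tar"},
    "small_prior": {},
    "my_small_prior": {"restore_prior": "/path/to/my/small_prior.tar", "level": 1, "labels": False},
    "my_small_upsampler": {"restore_prior": "/path/to/my/small_upsampler.tar", "level": 0, "labels": False},
}

MODELS = {
    "my_model": ("small_vqvae", "my_small_upsampler", "small_prior"),
}


def resolve(model):
    """Mimic the lookup in make_model: VQ-VAE name first, then one prior name per level."""
    vqvae, *priors = MODELS[model]
    levels = range(len(priors))  # assumption: all levels are built by default
    return vqvae, {level: (priors[level], HPARAMS_REGISTRY[priors[level]]) for level in levels}


vqvae, priors = resolve("my_model")
print(vqvae, HPARAMS_REGISTRY[vqvae])
for level, (name, overrides) in priors.items():
    print(level, name, overrides)
```

With this tuple, level 1 is looked up as small_prior, so under the stated assumptions it is built without the level, labels and restore_prior values set in my_small_prior; likewise, the VQ-VAE entry does not pick up restore_vqvae from my_small_vqvae.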

Then, when running

python jukebox/sample.py --model=my_model --name=my_project --levels=2 --n_samples=6 --sample_length_in_seconds=20 --total_sample_length_in_seconds=20 --sr=22050 --n_samples=6 --hop_fraction=0.5,0.5,0.125

I get the following:

Using cuda True
{'name': 'my_model', 'levels': 2, 'n_samples': 6, 'sample_length_in_seconds': 20, 'total_sample_length_in_seconds': 20, 'sr': 22050, 'hop_fraction': (0.5, 0.5, 0.125)}
Setting sample length to 440832 (i.e. 19.99238095238095 seconds) to be multiple of 256
0: Loading vqvae in eval mode
Conditioning on 1 above level(s)
Checkpointing convs
Checkpointing convs
Checkpointing convs
Level:0, Cond downsample:8, Raw to tokens:32, Sample length:262144
Restored from /.../Jukebox/jukebox/logs/small_upsampler/checkpoint_latest.pth.tar
0: Loading prior in eval mode
Conditioning on 1 above level(s)
Traceback (most recent call last):
  File "jukebox/sample.py", line 279, in <module>
    fire.Fire(run)
  File "/home/jovyan/.conda_envs/jukebox2/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/jovyan/.conda_envs/jukebox2/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/jovyan/.conda_envs/jukebox2/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/sample.py", line 276, in run
    save_samples(model, device, hps, sample_hps)
  File "jukebox/sample.py", line 181, in save_samples
    vqvae, priors = make_model(model, device, hps)
  File "/.../Jukebox/jukebox/jukebox/make_models.py", line 196, in make_model
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
  File "/.../Jukebox/jukebox/jukebox/make_models.py", line 196, in <listcomp>
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
  File "/.../Jukebox/jukebox/jukebox/make_models.py", line 170, in make_prior
    single_enc_dec=hps.single_enc_dec)
  File "/.../Jukebox/jukebox/jukebox/prior/prior.py", line 78, in __init__
    self.y_emb = LabelConditioner(n_time=self.n_time,include_time_signal=not self.x_cond,**y_cond_kwargs)
  File "/.../Jukebox/jukebox/jukebox/prior/conditioners.py", line 118, in __init__
    assert len(y_bins) == 2, f"Expecting (genre, artist) bins, got {y_bins}"
TypeError: object of type 'int' has no len()
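The last frame fails while evaluating the assert itself: len(y_bins) runs before the comparison, so when y_bins is an int Python raises TypeError instead of an AssertionError carrying the "Expecting (genre, artist) bins" message. A minimal reproduction with a hypothetical stand-in function (not the repo's LabelConditioner), using arbitrary example bin sizes for the tuple case:

```python
def check_y_bins(y_bins):
    """Hypothetical stand-in for the assert at conditioners.py line 118."""
    assert len(y_bins) == 2, f"Expecting (genre, artist) bins, got {y_bins}"
    return y_bins


for value in [(4, 8), 0]:
    try:
        check_y_bins(value)
        print(value, "-> ok")
    except TypeError as e:
        print(value, "-> TypeError:", e)
# (4, 8) -> ok
# 0 -> TypeError: object of type 'int' has no len()
```

So the error suggests the prior was being constructed on the label-conditioning path (line 95 builds a LabelConditioner) while y_bins held a single integer rather than a (genre, artist) pair.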