openai / jukebox

Code for the paper "Jukebox: A Generative Model for Music"
https://openai.com/blog/jukebox/

Finetuning. AssertionError: Bins 7898, got label tensor([[7899]], device='cuda:0') #278

Closed Theehawau closed 1 year ago

Theehawau commented 1 year ago

Hello, I got this error while trying to finetune with the following command:

mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,prior_1b_lyrics,all_fp16,cpu_ema --name=finetuned \
    --sample_length=1048576 --bs=1 --aug_shift --aug_blend --audio_files_dir=/home/hawau.toyin/Documents/data/ \
    --labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000 --min_duration=23.78

Error:

Using CPU EMA
Logging to logs/finetuned
0/5283 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "jukebox/train.py", line 350, in <module>
    fire.Fire(run)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 366, in _Fire
    component, remaining_args)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/fire/core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "jukebox/train.py", line 332, in run
    train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps)
  File "jukebox/train.py", line 231, in train
    x_out, loss, _metrics = model(x, **forw_kwargs)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hawau.toyin/Documents/jukebox/jukebox/prior/prior.py", line 349, in forward
    loss, metrics = self.z_forward(z=z, z_conds=z_conds, y=y, fp16=fp16, get_preds=get_preds)
  File "/home/hawau.toyin/Documents/jukebox/jukebox/prior/prior.py", line 323, in z_forward
    x_cond, y_cond, prime = self.get_cond(z_conds, y)
  File "/home/hawau.toyin/Documents/jukebox/jukebox/prior/prior.py", line 241, in get_cond
    y_cond, y_pos = self.y_emb(y) if self.y_cond else (None, None)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hawau.toyin/Documents/jukebox/jukebox/prior/conditioners.py", line 142, in forward
    artist_emb = self.artist_emb(artist)
  File "/home/hawau.toyin/.conda/envs/jukebox/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hawau.toyin/Documents/jukebox/jukebox/prior/conditioners.py", line 67, in forward
    assert (0 <= y).all() and (y < self.bins).all(), f"Bins {self.bins}, got label {y}"
AssertionError: Bins 7898, got label tensor([[7899]], device='cuda:0')

Can anyone help with this?
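
For anyone reading the traceback: the assertion at conditioners.py line 67 requires every label to satisfy 0 <= y < bins, so with 7898 bins the valid ids are 0 through 7897 and the artist id 7899 is rejected. The traceback does not show where that id is produced in the data pipeline. The check can be reproduced in isolation with the sketch below. It uses plain Python lists instead of tensors, and `check_label` and `out_of_range` are hypothetical helper names for illustration, not part of Jukebox.

```python
def check_label(labels, bins):
    """Mirror the range check from the traceback: every label must satisfy 0 <= y < bins."""
    assert all(0 <= y < bins for y in labels), f"Bins {bins}, got label {labels}"
    return labels


def out_of_range(labels, bins):
    """Return the sorted, de-duplicated label ids that would trip the assertion."""
    return sorted({y for y in labels if not 0 <= y < bins})


if __name__ == "__main__":
    bins = 7898                      # value reported in the error message
    labels = [0, 42, 7897, 7899]     # made-up ids; only 7899 comes from the error
    print(out_of_range(labels, bins))
    try:
        check_label([7899], bins)
    except AssertionError as err:
        print(err)
```

Running a scan like `out_of_range` over the label ids of a dataset before training would show which ids fall outside the embedding table, without waiting for the first failing batch.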