lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

AssertionError: only one Trainer can be instantiated at a time for training #215

Open tiansiyuan opened 1 year ago

tiansiyuan commented 1 year ago

With version 1.2.16, when I run audiolm_pytorch_demo.ipynb, training SoundStream completes successfully, but as soon as I start training SemanticTransformer I get the error shown in the title.

The full error message is below.

Thanks,

Tian


AssertionError                            Traceback (most recent call last)
Cell In[5], line 24
     12 wav2vec = HubertWithKmeans(
     13     checkpoint_path = f'./{hubert_ckpt}',
     14     kmeans_path = f'./{hubert_quantizer}'
     15 )
     17 semantic_transformer = SemanticTransformer(
     18     num_semantic_tokens = wav2vec.codebook_size,
     19     dim = 1024,
     20     depth = 6
     21 )# .cuda()
---> 24 trainer = SemanticTransformerTrainer(
     25     transformer = semantic_transformer,
     26     wav2vec = wav2vec,
     27     folder = dataset_folder,
     28     batch_size = 1,
     29     data_max_length = 320 * 32,
     30     num_train_steps = 1
     31 )
     33 trainer.train()

File <@beartype(audiolm_pytorch.trainer.SemanticTransformerTrainer.__init__) at 0x7f2258c97670>:111, in __init__(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_139785495343936, __beartype_object_205596576, __beartype_object_139785495313344, __beartype_object_9748320, __beartype_object_115133280, *args, **kwargs)

File /home/venv/lib/python3.9/site-packages/audiolm_pytorch/trainer.py:579, in SemanticTransformerTrainer.__init__(self, wav2vec, transformer, num_train_steps, batch_size, audio_conditioner, dataset, data_max_length, data_max_length_seconds, folder, lr, grad_accum_every, wd, max_grad_norm, valid_frac, random_split_seed, save_results_every, save_model_every, results_folder, accelerate_kwargs, force_clear_prev_results)
    553 @beartype
    554 def __init__(
    555     self,
   (...)
    576     force_clear_prev_results = None
    577 ):
    578     super().__init__()
--> 579     check_one_trainer()
    581     self.accelerator = Accelerator(**accelerate_kwargs)
    583     self.wav2vec = wav2vec

File /home/venv/lib/python3.9/site-packages/audiolm_pytorch/trainer.py:59, in check_one_trainer()
     57 def check_one_trainer():
     58     global ONE_TRAINER_INSTANTIATED
---> 59     assert not ONE_TRAINER_INSTANTIATED, 'only one Trainer can be instantiated at a time for training'
     60     ONE_TRAINER_INSTANTIATED = True

AssertionError: only one Trainer can be instantiated at a time for training

LWprogramming commented 1 year ago

#211