zhvng / open-musiclm

Implementation of MusicLM, a text-to-music model published by Google Research, with a few modifications.
https://arxiv.org/abs/2301.11325
MIT License

Error when trying to train CLAP #2

Closed: Saltb0xApps closed this issue 1 year ago

Saltb0xApps commented 1 year ago

$ python train_clap_rvq.py

found ignored file, skipping
found ignored file, skipping
found ignored file, skipping
found ignored file, skipping
found ignored file, skipping
found ignored file, skipping
training with dataset of 7594 samples and validating with randomly splitted 400 samples
Traceback (most recent call last):
  File "/Users/akhiltolani/Desktop/open-musiclm-main/scripts/train_clap_rvq.py", line 37, in <module>
    trainer.train()
  File "/Users/akhiltolani/Desktop/open-musiclm-main/scripts/../open_musiclm/trainer.py", line 507, in train
    logs = self.train_step()
  File "/Users/akhiltolani/Desktop/open-musiclm-main/scripts/../open_musiclm/trainer.py", line 473, in train_step
    raw_wave_for_clap = next(self.dl_iter)[0]
IndexError: tuple index out of range
zhvng commented 1 year ago

Hey Akhil, thanks for testing out my code! I'm not sure what's happening here, but it looks like there might be a problem loading a sound file from the dataset. What is your batch size? Are you training on fma_small?
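One way to test the hypothesis that a sound file fails to load is to try loading every audio file in the dataset folder up front and list the ones that raise. The sketch below is a hypothetical helper, not part of open-musiclm. The extension list and the choice of loader are assumptions.

```python
import os
from typing import Callable, Iterable, List, Tuple

# Hypothetical defaults: adjust to whatever formats the dataset actually uses.
AUDIO_EXTS = (".mp3", ".wav", ".flac")


def find_unloadable_files(
    folder: str,
    load_fn: Callable[[str], object],
    exts: Iterable[str] = AUDIO_EXTS,
) -> Tuple[int, List[Tuple[str, str]]]:
    """Try to load every audio file under `folder`.

    Returns the number of audio files checked and a list of
    (path, error) pairs for files that `load_fn` could not load.
    """
    exts = tuple(e.lower() for e in exts)
    checked = 0
    failures: List[Tuple[str, str]] = []
    for root, _dirs, files in os.walk(folder):
        for name in sorted(files):
            if not name.lower().endswith(exts):
                continue  # skip non-audio files
            path = os.path.join(root, name)
            checked += 1
            try:
                load_fn(path)
            except Exception as exc:  # record any decoding or I/O failure
                failures.append((path, repr(exc)))
    return checked, failures
```

For example, after `import torchaudio`, calling `find_unloadable_files('./data/fma_large', torchaudio.load)` returns how many files were tried and which ones failed. Passing the loader in as a parameter keeps the scan independent of the audio backend and makes it easy to test with a stub.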

Saltb0xApps commented 1 year ago

@zhvng Yes, I'm trying to run train_clap_rvq.py on the fma_large dataset. The batch size and all other parameters are left at their default values:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

audio_folder = './data/fma_large'

clap_checkpoint = "./checkpoints/clap-laion-audioset-fusion.pt"
rvq_checkpoint = None   # './checkpoints/clap.rvq.950.pt'
with disable_print():
    clap = create_clap_quantized(device=device, learn_rvq=True, checkpoint_path=clap_checkpoint, rvq_checkpoint_path=rvq_checkpoint).to(device)

trainer = ClapRVQTrainer(
    num_train_steps=1000, 
    batch_size=64,
    accumulate_initial_batch=2,
    audio_conditioner=clap,
    folder=audio_folder,
    results_folder='./results/clap_rvq',
    save_model_every=50,
    save_results_every=25
).to(device)

trainer.train()
Saltb0xApps commented 1 year ago

When I print(next(self.dl_iter)) in trainer.py at line 468, the output is an empty tuple (), which is why the trainer fails. My guess is that, for some reason, the get_dataloader function in data.py isn't working with the fma_large dataset.

Any idea why this could be happening?
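One way an empty tuple can come out of a dataloader (an assumption; the thread does not confirm this is what happens in open-musiclm's `get_dataloader`) is a collate step that groups fields with `zip(*batch)`. If every sample in the batch is itself an empty tuple, `zip` yields nothing, the collated batch is `()`, and indexing it with `[0]` raises exactly `IndexError: tuple index out of range`. The hypothetical helpers below reproduce that in plain Python, and `first_field` shows a guard that fails with a more descriptive message.

```python
from typing import Any, Sequence, Tuple


def collate_fields(batch: Sequence[Tuple[Any, ...]]) -> Tuple[list, ...]:
    """Hypothetical collate: group the i-th field of every sample together.

    If every sample is an empty tuple, zip(*batch) yields nothing and the
    result is an empty tuple, so result[0] raises IndexError.
    """
    return tuple(list(field) for field in zip(*batch))


def first_field(batch_output: Tuple[Any, ...]) -> Any:
    """Guarded version of `next(dl_iter)[0]` with a clearer error message."""
    if len(batch_output) == 0:
        raise ValueError(
            "dataloader returned an empty batch tuple; "
            "check that the dataset's samples actually contain audio"
        )
    return batch_output[0]
```

With well-formed samples, `collate_fields([(1, "a"), (2, "b")])` gives `([1, 2], ["a", "b"])`. With empty samples it gives `()`, which is the value printed above.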

Saltb0xApps commented 1 year ago

Note: the same issue also happens when I try to train the HuBERT k-means model:

audio_folder = './data/fma_large'

print('loading hubert...')
hubert_kmeans = get_hubert_kmeans(model_name='m-a-p/MERT-v0', kmeans_path=None)
trainer = HfHubertKmeansTrainer(
    feature_extraction_num_steps=100,
    feature_extraction_batch_size=64,
    data_max_length=16000*8,
    hubert_kmeans=hubert_kmeans,
    folder=audio_folder,
    results_folder='./results/hubert_kmeans',
).to(device)

trainer.train()

Error:

(base) akhiltolani@Akhils-MacBook-Pro-2 scripts % python train_hubert_kmeans.py
loading hubert...
training on 6400 out of 106574 samples
step 1: extracting features. must wait for this to complete before training kmeans.
0 / 100 steps
Traceback (most recent call last):
  File "/Users/akhiltolani/Desktop/open-musiclm-main 2/scripts/train_hubert_kmeans.py", line 29, in <module>
    trainer.train()
  File "/Users/akhiltolani/Desktop/open-musiclm-main 2/scripts/../open_musiclm/trainer.py", line 591, in train
    features.append(self.extract_hubert_features())
  File "/Users/akhiltolani/Desktop/open-musiclm-main 2/scripts/../open_musiclm/trainer.py", line 575, in extract_hubert_features
    raw_wave = next(self.dl_iter)[0]
IndexError: tuple index out of range
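The `training on 6400 out of 106574 samples` line matches the HuBERT config above if feature extraction draws `feature_extraction_num_steps` batches of `feature_extraction_batch_size` clips each. This is an inference from the numbers, not something stated in the thread:

```latex
\text{samples used} = \underbrace{100}_{\texttt{feature\_extraction\_num\_steps}} \times \underbrace{64}_{\texttt{feature\_extraction\_batch\_size}} = 6400
```

Under that reading, only a subset of the 106574 fma_large files feeds k-means. The crash happens on the very first batch (`0 / 100 steps`), the same failure point as in the CLAP run.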
Saltb0xApps commented 1 year ago

The root cause of my issues was an incorrect version of torch!
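The thread does not say which torch version was wrong or which one fixed the problem. The sketch below therefore only shows a generic way to fail fast on a version mismatch. The minimum version is a hypothetical value you would take from the project's requirements. Comparing only the numeric release part (ignoring suffixes such as `+cu117` or `a0`) is a simplifying assumption.

```python
import re
from typing import Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Parse the numeric release part of a version string.

    '2.0.1+cu117' -> (2, 0, 1); '1.13.0a0' -> (1, 13, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if `installed` is at least `minimum` (numeric parts only)."""
    a, b = version_tuple(installed), version_tuple(minimum)
    # Pad the shorter tuple with zeros so that 2.0 compares equal to 2.0.0.
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b
```

At the top of a training script, you could call `meets_minimum(torch.__version__, REQUIRED_TORCH)` and raise if it returns `False`. Here `REQUIRED_TORCH` is a hypothetical constant. For full PEP 440 semantics, including pre-releases and local versions, `packaging.version.Version` is the more robust choice.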