ldzhangyx / instruct-MusicGen

The official implementation of our paper "Instruct-MusicGen: Unlocking Text-to-Music Editing for Music Language Models via Instruction Tuning".
Apache License 2.0

Distributed training doesn't seem to be working #6

Open Saltb0xApps opened 1 week ago

Saltb0xApps commented 1 week ago

Hey! I tried running the training with DDP over 2 GPUs but got this error:

Error executing job with overrides: ['trainer=gpu']
Traceback (most recent call last):
  File "/home/akhil/instruct-MusicGen/src/train.py", line 125, in main
    metric_dict, _ = train(cfg)
  File "/home/akhil/instruct-MusicGen/src/utils/utils.py", line 78, in wrap
    raise ex
  File "/home/akhil/instruct-MusicGen/src/utils/utils.py", line 68, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
  File "/home/akhil/instruct-MusicGen/src/train.py", line 92, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1028, in _run_stage
    self._run_sanity_check()
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1057, in _run_sanity_check
    val_loop.run()
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 410, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 640, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 633, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/home/akhil/instruct-MusicGen/src/models/instructmusicgenadapter_module.py", line 157, in validation_step
    loss, preds, targets = self.model_step(batch)
  File "/home/akhil/instruct-MusicGen/src/models/instructmusicgenadapter_module.py", line 120, in model_step
    loss, preds, y = self.forward(batch)
  File "/home/akhil/instruct-MusicGen/src/models/instructmusicgenadapter_module.py", line 79, in forward
    output = self.model(
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/akhil/instruct-MusicGen/src/models/components/model.py", line 388, in forward
    embed_fn = self.cp_transformer.forward(condition_audio_code=condition_audio_code,
  File "/home/akhil/instruct-MusicGen/src/models/components/model.py", line 289, in forward
    sum_code = sum([self.emb_fn["emb"][i](condition_audio_code[:, i]) for i in range(4)])
  File "/home/akhil/instruct-MusicGen/src/models/components/model.py", line 289, in <listcomp>
    sum_code = sum([self.emb_fn["emb"][i](condition_audio_code[:, i]) for i in range(4)])
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/akhil/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

The training does start, but this error comes up within the first 15-20 seconds, during the sanity-check validation step. I'd guess it's some minor issue with file handling across multiple devices?
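The call that fails is `F.embedding` on `model.py` line 289: the embedding weight and the index tensor `condition_audio_code[:, i]` live on different GPUs. A minimal stopgap sketch, assuming the four tables are `nn.Embedding` layers and the codes are shaped `(batch, n_codebooks, time)`: move each table to the device of the codes before the lookup. `sum_codebook_embeddings` is a hypothetical helper, not a function in the repo. This hides rather than fixes the root cause, and if those tables are trainable but not registered on the model, DDP would still not synchronize their gradients.

```python
import torch
import torch.nn as nn


def sum_codebook_embeddings(emb_list, codes):
    """Sum per-codebook embeddings of integer codes shaped (batch, n_codebooks, time).

    Hypothetical stopgap: each table is moved to codes.device before the lookup,
    so the embedding weight and the index tensor always share a device.
    """
    total = None
    for i, emb in enumerate(emb_list):
        emb.to(codes.device)  # nn.Module.to works in place; cheap once already on that device
        out = emb(codes[:, i])  # (batch, time, dim)
        total = out if total is None else total + out
    return total


if __name__ == "__main__":
    tables = [nn.Embedding(2048, 16) for _ in range(4)]
    codes = torch.randint(0, 2048, (2, 4, 10))
    print(sum_codebook_embeddings(tables, codes).shape)  # torch.Size([2, 10, 16])
```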

ldzhangyx commented 1 week ago

I use PyTorch Lightning, which should handle device placement automatically. In the meantime, could you check whether I hard-coded cuda:0 in any Python file?
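Checking for a hard-coded device can be scripted. A minimal sketch, assuming the code lives under `src/` as in the traceback; `find_device_pins` and the pattern list are hypothetical and not exhaustive. Besides literal `"cuda:0"`, it also flags `.cuda()` and `device="cuda"`, since both resolve to the *current* CUDA device, which is cuda:0 in every process unless it has been changed (e.g. via `torch.cuda.set_device`).

```python
import re
import sys
from pathlib import Path

# Illustrative patterns that can pin tensors or modules to one GPU regardless of rank.
PATTERNS = [
    re.compile(r"""["']cuda:\d+["']"""),           # e.g. "cuda:0"
    re.compile(r"\.cuda\(\s*\)"),                   # .cuda() uses the current device
    re.compile(r"""device\s*=\s*["']cuda["']"""),  # device="cuda" also uses the current device
]


def find_device_pins(root):
    """Return (path, line_number, stripped_line) for every matching line under root."""
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if any(p.search(line) for p in PATTERNS):
                hits.append((str(path), lineno, line.strip()))
    return hits


if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else "src"
    for path, lineno, line in find_device_pins(root):
        print(f"{path}:{lineno}: {line}")
```

A line like `x.cuda()` written in a model's `__init__` is worth a second look, because it can run before any per-rank device is selected.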

Saltb0xApps commented 1 week ago

@ldzhangyx There are no hard-coded cuda:0 calls anywhere in the Python code. I believe the issue is in how the forward function in instructmusicgenadapter_module.py handles tensors, not on the PyTorch Lightning side.
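One possible explanation, offered as a hypothesis rather than a confirmed diagnosis: the traceback reaches the embeddings through `self.emb_fn["emb"][i]`. If `emb_fn` is a plain Python `dict` (or holds layers in a plain `list`), those layers are not registered submodules, so neither Lightning's `.to(device)` nor DDP sees them. They stay wherever they were created (for example cuda:0) while each rank's batch is on its own GPU. The sketch below contrasts a plain `dict` with `nn.ModuleDict`; the class names and `find_unregistered_modules` are hypothetical, and `.double()` stands in for a device move so it runs on CPU.

```python
import torch
import torch.nn as nn


class PlainDictHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain dict: the embeddings inside are NOT registered as submodules.
        self.emb_fn = {"emb": nn.ModuleList(nn.Embedding(2048, 16) for _ in range(4))}


class ModuleDictHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleDict registers its values, so .to(), .parameters() and DDP see them.
        self.emb_fn = nn.ModuleDict({"emb": nn.ModuleList(nn.Embedding(2048, 16) for _ in range(4))})


def find_unregistered_modules(module):
    """List attributes that hide nn.Module objects one level deep in plain dicts/lists/tuples."""
    found = []
    for name, sub in module.named_modules():
        for attr, value in vars(sub).items():
            if attr.startswith("_"):  # skip internals such as _modules and _parameters
                continue
            if isinstance(value, dict):
                items = value.items()
            elif isinstance(value, (list, tuple)):
                items = enumerate(value)
            else:
                continue
            for key, item in items:
                if isinstance(item, nn.Module):
                    prefix = f"{name}." if name else ""
                    found.append(f"{prefix}{attr}[{key!r}]")
    return found


if __name__ == "__main__":
    for cls in (PlainDictHolder, ModuleDictHolder):
        m = cls().double()  # only registered parameters are converted
        n_params = sum(p.numel() for p in m.parameters())
        print(cls.__name__, n_params, m.emb_fn["emb"][0].weight.dtype, find_unregistered_modules(m))
```

If this turns out to be the cause, switching to `nn.ModuleDict` would let Lightning move the tables with the rest of the model. Registering them also adds them to `parameters()` and the checkpoint, so tables meant to stay frozen would need `requires_grad_(False)`.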