lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License
2.32k stars 249 forks

Question: Are there any workarounds for using DeepSpeed for multi-GPU training? #240

Open rgxb2807 opened 8 months ago

rgxb2807 commented 8 months ago

Hi, thanks so much for the incredible repo.

I've been able to train successfully on a multi-GPU setup using Accelerate. I'm wondering whether it's possible to use Microsoft's DeepSpeed through the Accelerate library.

When I enable ZeRO Stage 1 or Stage 2 in the Accelerate configuration, I get the following error. Any ideas for workarounds?

RuntimeError: Tensor must have a storage_offset divisible by 2

I'm training SoundStream on the LibriSpeech dev-clean data from https://us.openslr.org/resources/12/dev-clean.tar.gz

Traceback (most recent call last):
  File "/audio/soundstream_train.py", line 43, in <module>
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/audiolm_pytorch/trainer.py", line 572, in train
    logs = self.train_step()
  File "/usr/local/lib/python3.10/dist-packages/audiolm_pytorch/trainer.py", line 441, in train_step
    loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 1769, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/audiolm_pytorch/soundstream.py", line 852, in forward
    (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/audiolm_pytorch/soundstream.py", line 285, in forward
    x = layer(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/audiolm_pytorch/soundstream.py", line 308, in forward
    return self.fn(x, **kwargs) + x
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/audiolm_pytorch/soundstream.py", line 197, in forward
    weight, bias = map(torch.view_as_complex, (self.weight, self.bias))
RuntimeError: Tensor must have a storage_offset divisible by 2
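The failing frame is `torch.view_as_complex(self.weight)`. That function reinterprets a real tensor as complex without copying, so it only succeeds when the layout allows it: among other checks, the storage offset must be even, because each complex element spans two consecutive real values. The sketch below reproduces the same message in plain PyTorch with no DeepSpeed involved. The `flat` buffer is a hypothetical stand-in for ZeRO packing several parameters into one contiguous buffer; that this is what happens here is an assumption, not something confirmed in this thread.

```python
import torch

# Hypothetical stand-in for a flattened parameter buffer: one unrelated
# element followed by a 4 x 2 "complex" weight stored as (real, imag) pairs.
flat = torch.zeros(1 + 4 * 2)

# The weight starts at element 1 of the buffer, i.e. at an odd storage offset.
weight = flat[1:].view(4, 2)
print(weight.storage_offset())  # 1

message = None
try:
    torch.view_as_complex(weight)
except RuntimeError as err:
    message = str(err)  # the storage_offset check from the traceback
print(message)

# The same values copied into fresh storage (offset 0) convert without error.
ok = torch.view_as_complex(weight.clone())
print(ok.shape)  # torch.Size([4])
```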
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 440) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 926, in launch_command
    deepspeed_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 671, in deepspeed_launcher
    distrib_run.run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
soundstream_train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-10-18_22:04:28
  host      : c6e49f99abae
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 441)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-10-18_22:04:28
  host      : c6e49f99abae
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 440)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
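One possible local workaround, sketched here and untested with DeepSpeed, is to avoid the zero-copy view altogether. `torch.complex` builds a new complex tensor from separate real and imaginary parts, so it has no storage-offset requirement, and autograd still routes gradients back to the underlying real parameter. `as_complex` is a hypothetical helper name, and the code assumes the layer stores `weight` and `bias` as real tensors whose last dimension of size 2 holds (real, imag) pairs, which is what the `view_as_complex` call in the traceback implies.

```python
import torch

def as_complex(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical replacement for torch.view_as_complex on a parameter.

    Expects a real tensor whose last dimension has size 2 and holds
    (real, imag) pairs. torch.complex allocates a new tensor instead of
    reinterpreting the existing storage, so the input's storage offset does
    not matter, and gradients still flow back to `t`.
    """
    return torch.complex(t[..., 0], t[..., 1])

# Simulate a parameter that starts at an odd offset inside a flat buffer.
flat = torch.arange(7, dtype=torch.float32, requires_grad=True)
weight = flat[1:].view(3, 2)  # storage_offset() == 1

w = as_complex(weight)  # succeeds where view_as_complex would raise

# Toy loss: weight the real parts by 1 and the imaginary parts by 2,
# so the gradient shows which buffer elements fed which part.
loss = w.real.sum() + 2 * w.imag.sum()
loss.backward()
print(flat.grad)  # tensor([0., 1., 2., 1., 2., 1., 2.])
```

Applied to `soundstream.py`, the failing line would become `weight, bias = map(as_complex, (self.weight, self.bias))`. The trade-off is an extra copy of the weight and bias on every forward pass instead of a free view.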
lucidrains commented 8 months ago

@rgxb2807 i was hoping accelerate would just work lol

lucidrains commented 8 months ago

@rgxb2807 maybe you can raise an issue at the accelerate repository

rgxb2807 commented 8 months ago

> @rgxb2807 i was hoping accelerate would just work lol

It does! And I'm grateful for it.

> @rgxb2807 maybe you can raise an issue at the accelerate repository

I'll do some digging and raise an issue where appropriate.

rgxb2807 commented 8 months ago

Raised the issue with the Accelerate team here: #2106