lucidrains / BS-RoFormer

Implementation of Band Split Roformer, the SOTA attention network for music source separation from ByteDance AI Labs
MIT License

Problem with DataParallel #25

Closed · ZFTurbo closed this issue 9 months ago

ZFTurbo commented 9 months ago

There is a problem with the model if you try to use multiple GPUs via nn.DataParallel:

Traceback (most recent call last):
  File "train_local_run.py", line 44, in <module>
    train_model(args)
  File "train.py", line 239, in train_model
    loss = model(x, y)
  File "\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "\site-packages\torch\nn\parallel\data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "\site-packages\torch\nn\parallel\data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "\site-packages\torch\nn\parallel\parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "\site-packages\torch\_utils.py", line 644, in reraise
    raise exception
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "\site-packages\torch\nn\parallel\parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "\models\bs_roformer\mel_band_roformer.py", line 437, in forward
    stft_window = self.stft_window_fn(device=self.device)
  File "\models\bs_roformer\mel_band_roformer.py", line 403, in device
    return next(self.parameters()).device
StopIteration

What is the correct way to fix it?
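
For context: nn.DataParallel replicates the module onto each GPU, and on those replicas the parameters are no longer registered, so self.parameters() yields nothing inside forward() and next(...) raises StopIteration. A minimal repro of the failure mode (toy model, the names are illustrative, not from this repo):

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        @property
        def device(self):
            # on DataParallel replicas this iterator is empty,
            # so next() raises StopIteration
            return next(self.parameters()).device

        def forward(self, x):
            return self.proj(x.to(self.device))

    model = nn.DataParallel(Toy().cuda())
    out = model(torch.randn(8, 4))  # with 2+ GPUs: StopIteration in replica 0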

ZFTurbo commented 9 months ago

My current dirty fix is:

    @property
    def device(self):
        # return next(self.parameters()).device  # fails on DataParallel replicas
        return 'cuda'  # hard-code the device as a workaround
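
A somewhat cleaner variant of the same workaround (a sketch; the dummy buffer is my own addition, not from the repo) is to read the device off a registered buffer, since buffers, unlike parameters, stay visible on DataParallel replicas:

    import torch
    from torch import nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            # buffers survive DataParallel replication, so this lookup
            # works inside forward() where parameters() comes up empty
            self.register_buffer('dummy', torch.zeros(0), persistent=False)

        @property
        def device(self):
            return self.dummy.device
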
lucidrains commented 9 months ago

@ZFTurbo hmm that's weird, not sure why it can't find any parameters

could you retry with distributed data parallel and see if this issue persists?
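
for reference, a minimal DistributedDataParallel setup looks roughly like this (a sketch for a single node launched with torchrun; the linear layer is a stand-in for the actual model):

    import os
    import torch
    from torch import nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # launch with: torchrun --nproc_per_node=2 train_ddp.py
    torch.distributed.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(4, 4).cuda(local_rank)  # stand-in for the actual model
    model = DDP(model, device_ids=[local_rank])

    x = torch.randn(8, 4, device=local_rank)
    model(x).sum().backward()  # gradients are all-reduced across ranks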

lucidrains commented 9 months ago

@ZFTurbo try 0.3.9?

ZFTurbo commented 9 months ago

New code works. Thank you!

lucidrains commented 9 months ago

yup no problem

highly recommend huggingface accelerate btw! it will take away all the pain of doing distributed training
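
e.g. a minimal training step with accelerate looks roughly like this (a sketch with a stand-in model and random data, launched via accelerate launch train.py):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()  # handles device placement and DDP wrapping

    model = nn.Linear(4, 4)  # stand-in for the actual model
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 4))
    loader = DataLoader(dataset, batch_size=8)

    # prepare() moves everything to the right device(s) and wraps the model
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    for x, y in loader:
        loss = nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # replaces loss.backward()
        optimizer.step()
        optimizer.zero_grad()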