KinWaiCheuk / nnAudio

Audio processing using a PyTorch 1D convolution network
MIT License

Data Parallelism support #31

Closed JackFurby closed 4 years ago

JackFurby commented 4 years ago

I am trying to use this library with multiple GPUs, but I am getting the following error message:

RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/furby/.local/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/furby/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/furby/Documents/models/mobilenet_v1.py", line 129, in forward
    audioOut = self.forward_audio(audio)
  File "/home/furby/Documents/models/mobilenet_v1.py", line 124, in forward_audio
    return self.audioNet(x)
  File "/home/furby/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/furby/Documents/models/mobilenet_v1.py", line 96, in forward
    x = self.spec_layer(x)
  File "/home/furby/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/furby/.pyenv/versions/3.7.6/lib/python3.7/site-packages/nnAudio/Spectrogram.py", line 681, in forward
    spec = torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2) \
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

I have looked into where the code stops, which seems to be when the batch is split across multiple GPUs and passed through the model. I believe this is because, when initialising the model, I configure the MelSpectrogram with device=device. For a single GPU or a CPU this is fine, but with multiple GPUs the layer ends up pinned to just one of them. I am not sure whether the issue lies with my configuration or with the library itself, but I am after a way of having the device set on the fly.

My model implementation is as follows:


import torch
from nnAudio import Spectrogram


class Model(torch.nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        config = dict(
                sr=16000,
                n_fft=400,
                n_mels=64,
                hop_length=160,
                window="hann",
                center=False,
                pad_mode="reflect",
                htk=True,
                fmin=125,
                fmax=7500,
                device=device,  # pins the spectrogram kernels to a single device
        )
        self.spec_layer = Spectrogram.MelSpectrogram(**config)

    def forward(self, x):
        x = self.spec_layer(x)                          # (batch, n_mels, frames)
        x = x.view(x.size(0), 1, x.size(1), x.size(2))  # add a channel dimension
        # ... the rest of the network (the MobileNet layers) follows here ...
        return x
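
For context, the model is wrapped in the usual nn.DataParallel way, roughly like this (a simplified excerpt; the variable names are just illustrative):

import torch

device = torch.device("cuda:0")
model = Model(device=device)          # MelSpectrogram kernels are created on cuda:0
model = torch.nn.DataParallel(model)  # one replica per visible GPU
model.to(device)

# DataParallel splits the batch along dim 0 and runs each chunk on its own GPU,
# but the spectrogram kernels stay pinned to cuda:0, so replica 1 raises the
# device-mismatch error shown above.
waveforms = torch.randn(8, 16000, device=device)  # dummy batch of 1 s clips at 16 kHz
out = model(waveforms)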
KinWaiCheuk commented 4 years ago

Hi @JackFurby, thanks for raising this issue. Previously, my package did not consider the data parallelism case.

I think I have figured out how to add data parallelism support, but since I do not have much experience with it, I am uploading a beta package for you to test out first.
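
Roughly, the idea is to register the convolution kernels as buffers instead of keeping them as plain tensor attributes, so that .to() and nn.DataParallel move and replicate them together with the module. A simplified sketch of the pattern (not the exact beta code; wcos is my guess at the cosine kernel paired with the wsin from your traceback):

import torch
from torch.nn.functional import conv1d

class SpecLayerSketch(torch.nn.Module):
    def __init__(self, wsin, wcos, stride):
        super().__init__()
        self.stride = stride
        # Buffers follow the module across devices; a plain tensor attribute
        # stays on whatever device it was created on, which is what broke
        # nn.DataParallel before.
        self.register_buffer("wsin", wsin)
        self.register_buffer("wcos", wcos)

    def forward(self, x):
        # x: (batch, 1, samples)
        return torch.sqrt(conv1d(x, self.wsin, stride=self.stride).pow(2)
                          + conv1d(x, self.wcos, stride=self.stride).pow(2))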

You should remove the existing nnAudio installation:

pip uninstall nnAudio

Then install the beta package:

pip install nnAudio==0.0.11b

I have only modified Spectrogram.MelSpectrogram so far. Could you test whether it works for you before I do a complete update of the package?
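
Something like this should be enough as a smoke test (a minimal sketch assuming two visible GPUs; I have dropped the device argument here on the assumption that the layer now follows the module's device):

import torch
from nnAudio import Spectrogram

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.spec_layer = Spectrogram.MelSpectrogram(
            sr=16000, n_fft=400, n_mels=64, hop_length=160,
            window="hann", center=False, pad_mode="reflect",
            htk=True, fmin=125, fmax=7500)

    def forward(self, x):
        return self.spec_layer(x)

model = torch.nn.DataParallel(Wrapper()).cuda()
x = torch.randn(8, 16000).cuda()  # batch of 8 one-second clips at 16 kHz
print(model(x).shape)             # should run on both GPUs without the device error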