asteroid-team / torch-audiomentations

Fast audio data augmentation in PyTorch. Inspired by audiomentations. Useful for deep learning.
MIT License
969 stars, 88 forks

Device Conversion Bug #174

Closed (nuniz closed this issue 8 months ago)

nuniz commented 8 months ago

Hi, I found a bug that occurs when training on a machine with multiple GPUs.

The error:

File "/home/user/microservices/algorithms/ds/augmentations/ops.py", line 59, in forward_transform
  batch = transform(batch)
File "/home/user/virtualenvs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/home/user/microservices/algorithms/ds/augmentations/ops.py", line 33, in forward
  transformed = self.forward_transform(*to_transform, **kwargs)
File "/home/user/microservices/algorithms/ds/augmentations/ops.py", line 220, in forward_transform
  return {'data': self.transform(data)}
File "/home/user/virtualenvs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/home/user/microservices/algorithms/ds/augmentations/ops.py", line 82, in forward
  return super().forward(samples, sample_rate=sample_rate)
File "/home/user/virtualenvs/dev/lib/python3.8/site-packages/torch_audiomentations/core/composition.py", line 120, in forward
  inputs = self.transforms[i](**inputs)
File "/home/user/virtualenvs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/home/user/virtualenvs/dev/lib/python3.8/site-packages/torch_audiomentations/core/transforms_interface.py", line 322, in forward
  selected_samples = cloned_samples[
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

The bug: https://github.com/asteroid-team/torch-audiomentations/blob/9baf5c516a44651025bd7e8d8ead35888b58bbdc/torch_audiomentations/core/transforms_interface.py#L252

The line should be: ).to(dtype=torch.bool, device=samples.device)
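To illustrate the proposed change, here is a minimal sketch of the pattern: draw a boolean mask and move it to the device of the input before using it as an index. The helper name `sample_should_apply`, the tensor shapes, and the probability are hypothetical and not taken from the library. The sketch falls back to the CPU when fewer than two GPUs are available, so it shows the fix but does not reproduce the error itself.

```python
import torch
from torch.distributions import Bernoulli


def sample_should_apply(samples: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: draw one boolean flag per example in the batch.

    The mask is explicitly moved to samples.device, mirroring the
    `.to(dtype=torch.bool, device=samples.device)` change proposed above.
    """
    batch_size = samples.shape[0]
    distribution = Bernoulli(probs=torch.tensor(p))
    mask = distribution.sample(sample_shape=(batch_size,))
    # Without device=..., the mask stays on whatever device the distribution
    # lives on, which may differ from the device that holds `samples`.
    return mask.to(dtype=torch.bool, device=samples.device)


# Use the second GPU when present (as in the report), otherwise the CPU.
device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cpu")
samples = torch.randn(4, 1, 16000, device=device)  # (batch, channels, time)

should_apply = sample_should_apply(samples, p=0.5)
# Boolean indexing now uses a mask on the same device as the indexed tensor.
selected_samples = samples[should_apply]
```

Because the mask follows `samples.device`, the same code path works regardless of which GPU a given batch is placed on.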

I can open a PR to fix this.

Best, Asaf