acids-ircam / ddsp_pytorch

Implementation of Differentiable Digital Signal Processing (DDSP) in Pytorch
Apache License 2.0
451 stars, 56 forks

Training on cuda causes exception #5

Status: Closed (fromjupiter closed 3 years ago)

fromjupiter commented 4 years ago

I have encountered several issues so far while trying to train on CUDA. There are easy workarounds, but I want to report them in case others run into the same problems.

  1. fixed_batch is explicitly moved to the device, but plot_batch_detailed passes it to matplotlib, which cannot convert a CUDA tensor to a NumPy array. The workaround is to comment out train.py:line 123. Exception:

     Traceback (most recent call last):
       File "./train.py", line 125, in <module>
         plot_batch_detailed(fixed_batch)
       File "/datasets/home/home-00/07/907/k1feng/ReverbNN/utils/plot.py", line 25, in plot_batch_detailed
         axes[b, 0].plot(audio[b].squeeze(0)[:2000])
       File "/opt/conda/lib/python3.7/site-packages/matplotlib/axes/_axes.py", line 1668, in plot
         self.add_line(line)
       File "/opt/conda/lib/python3.7/site-packages/matplotlib/axes/_base.py", line 1902, in add_line
         self._update_line_limits(line)
       File "/opt/conda/lib/python3.7/site-packages/matplotlib/axes/_base.py", line 1924, in _update_line_limits
         path = line.get_path()
       File "/opt/conda/lib/python3.7/site-packages/matplotlib/lines.py", line 1027, in get_path
         self.recache()
       File "/opt/conda/lib/python3.7/site-packages/matplotlib/lines.py", line 675, in recache
         y = _to_unmasked_float_array(yconv).ravel()
       File "/opt/conda/lib/python3.7/site-packages/matplotlib/cbook/__init__.py", line 1390, in _to_unmasked_float_array
         return np.asarray(x, float)
       File "/opt/conda/lib/python3.7/site-packages/numpy/core/numeric.py", line 538, in asarray
         return array(a, dtype, copy=False, order=order)
       File "/datasets/home/07/907/k1feng/.local/lib/python3.7/site-packages/torch/tensor.py", line 488, in __array__
         return self.numpy().astype(dtype, copy=False)
     TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

  2. msstft Loss needs to put the Hann window on the device as well. (loss.py:37)

  3. msstft Loss needs to put cur_fft on the device (loss.py:47)
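
For anyone hitting the same errors, the three workarounds above could look roughly like the sketch below. It is only an illustration and not the repository's code: `to_plottable` and `multiscale_stft_loss` are hypothetical names, and the FFT sizes, hop length, and L1 magnitude distance are assumptions made for the example. The idea is to copy tensors to host memory before handing them to matplotlib, and to create the Hann window on the same device as the audio.

```python
import torch


def to_plottable(x: torch.Tensor):
    """Return a NumPy copy of x that matplotlib can plot (hypothetical helper)."""
    # detach() drops autograd history; cpu() copies a CUDA tensor to host memory.
    return x.detach().cpu().numpy()


def multiscale_stft_loss(pred, target, scales=(512, 256, 128)):
    """L1 distance between STFT magnitudes at several FFT sizes (assumed formula)."""
    loss = pred.new_zeros(())  # scalar zero on the same device as pred
    for n_fft in scales:
        # Create the Hann window directly on the input's device and dtype.
        window = torch.hann_window(n_fft, device=pred.device, dtype=pred.dtype)
        mags = []
        for signal in (pred, target):
            spec = torch.stft(
                signal,
                n_fft,
                hop_length=n_fft // 4,
                window=window,
                return_complex=True,
            )
            mags.append(spec.abs())
        loss = loss + (mags[0] - mags[1]).abs().mean()
    return loss


# Falls back to CPU when no GPU is available, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
audio = torch.randn(2, 4096, device=device)

curve = to_plottable(audio[0, :2000])  # safe to pass to axes.plot(...)
zero_loss = multiscale_stft_loss(audio, audio)  # identical inputs
```

With a helper like this, the plotting call can stay enabled instead of commenting out the line that moves fixed_batch to the device.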

shidephen commented 4 years ago

This may be caused by this window tensor:

https://github.com/acids-ircam/ddsp_pytorch/blob/a87b2010ed7024b05bcdc4160a6e5f3273cb0ea8/code/ddsp/loss.py#L37
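
If the window is indeed created on the CPU inside the loss, one common PyTorch pattern is to register it as a module buffer so that `.to(device)` moves it together with the module. A minimal sketch, assuming a hypothetical `SpectralDistance` module rather than the repository's actual loss class:

```python
import torch
import torch.nn as nn


class SpectralDistance(nn.Module):
    """Hypothetical single-scale spectral loss, for illustration only."""

    def __init__(self, n_fft=512):
        super().__init__()
        self.n_fft = n_fft
        # Buffers follow the module across devices but are not trained.
        self.register_buffer("window", torch.hann_window(n_fft))

    def forward(self, pred, target):
        def magnitude(x):
            spec = torch.stft(
                x, self.n_fft, window=self.window, return_complex=True
            )
            return spec.abs()

        return (magnitude(pred) - magnitude(target)).abs().mean()


# Falls back to CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = SpectralDistance().to(device)  # the window moves with the module
x = torch.randn(2, 4096, device=device)
loss = criterion(x, x)
```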

caillonantoine commented 3 years ago

The code has been entirely rewritten; everything works now!