fbcotter / pytorch_wavelets

PyTorch implementation of 2D Discrete Wavelet Transforms (DWT) and Dual-Tree Complex Wavelet Transforms (DTCWT), and a DTCWT-based ScatterNet

DataParallel on GPUs not supported? #2

Closed: HopLee6 closed this issue 5 years ago

HopLee6 commented 5 years ago

I tried your implementation on multiple GPUs and it failed.

The code is as follows:

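```python
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

xfm = DWTForward(J=3, wave='db3', mode='periodization')

# Wrap the transform in DataParallel when more than one GPU is available
if torch.cuda.device_count() > 1:
    xfm = nn.DataParallel(xfm)
xfm.to(device)

X = torch.randn(10, 3, 256, 256).to(device)
Yl, Yh = xfm(X)
```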

The following error is reported:

```
Traceback (most recent call last):
  File "test.py", line 16, in <module>
    Yl, Yh = xfm(X)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 123, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 133, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 77, in parallel_apply
    raise output
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 53, in _worker
    output = module(*input, **kwargs)
  File "/home/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/opt/lhp/DWT/pytorch_wavelets/dwt/transform2d.py", line 77, in forward
    y = lowlevel.afb2d(ll, self.h, self.mode)
  File "/home/opt/lhp/DWT/pytorch_wavelets/dwt/lowlevel.py", line 229, in afb2d
    lohi = afb1d(x, h0_row, h1_row, mode=mode, dim=3)
  File "/home/opt/lhp/DWT/pytorch_wavelets/dwt/lowlevel.py", line 107, in afb1d
    lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
RuntimeError: Assertion `THCTensor_(checkGPU)(state, 3, input, output, weight)' failed. Some of weight/gradient/input tensors are located on different GPUs. Please move them to a single one. at /pytorch/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu:16
```

fbcotter commented 5 years ago

I get the same error when using DataParallel; good find. I hadn't tested this case before. I'll try to fix it now.

HopLee6 commented 5 years ago

Actually, the following modification to `forward` in `pytorch_wavelets/dwt/transform2d.py` works:

```python
self.h = [t.to(ll.device) for t in self.h]
y = lowlevel.afb2d(ll, self.h, self.mode)
```
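A minimal, CPU-runnable sketch of the same workaround, assuming a hypothetical `RowFilter` module (not part of pytorch_wavelets) that keeps its filter taps in a plain Python list, as `self.h` appears to be here. Moving the filters to `x.device` at the start of `forward` means each DataParallel replica convolves with filters that sit on the same GPU as its slice of the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RowFilter(nn.Module):
    """Hypothetical stand-in for a module that stores its filters in a plain list."""

    def __init__(self, taps):
        super().__init__()
        # A plain Python list is not a registered parameter or buffer, so
        # DataParallel replicas can end up sharing tensors that live on another GPU.
        self.h = [torch.tensor(taps, dtype=torch.float32).reshape(1, 1, 1, -1)]

    def forward(self, x):
        # The workaround from this thread: move every filter to the input's device.
        self.h = [t.to(x.device) for t in self.h]
        C = x.shape[1]
        w = self.h[0].repeat(C, 1, 1, 1)  # one copy of the row filter per channel
        pad = (w.shape[-1] - 1) // 2
        # Depthwise (groups=C) convolution along the last dimension
        return F.conv2d(x, w, padding=(0, pad), groups=C)


m = RowFilter([0.25, 0.5, 0.25])
y = m(torch.ones(2, 3, 4, 6))
print(y.shape)
```

The cost is an extra `.to()` call per forward pass; when the filters are already on the right device, this returns the same tensor without copying.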

fbcotter commented 5 years ago

I didn't see your code, but I've done something very similar. Thanks!

JiaWang0704 commented 2 years ago

> Actually, the following modification works:
>
> ```python
> self.h = [t.to(ll.device) for t in self.h]
> y = lowlevel.afb2d(ll, self.h, self.mode)
> ```

I don't understand the meaning of these two lines of code. Could you give a more detailed explanation? Thank you very much!

fbcotter commented 2 years ago

I think this code was just moving all the filters in `self.h` to the same device as the input. Did you have an issue with DataParallel?
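Some background on why the move is needed, as I understand PyTorch's behaviour: `nn.Module.to()` and DataParallel's replication only handle tensors registered on the module as parameters or buffers. Other attributes are shallow-copied onto each replica, so a list of tensors on the replica for `cuda:1` still points at the originals on `cuda:0`. The sketch below uses dtype conversion as a CPU-only stand-in for a device move, and shows the common alternative of registering the filter with `register_buffer` so it follows the module automatically. `ListFilters`, `BufferFilters` and `h0` are hypothetical names, not the library's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ListFilters(nn.Module):
    """Hypothetical: filter kept in a plain Python list (not registered)."""

    def __init__(self):
        super().__init__()
        self.h = [torch.ones(1, 1, 1, 3) / 3.0]


class BufferFilters(nn.Module):
    """Hypothetical: the same filter registered as a buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("h0", torch.ones(1, 1, 1, 3) / 3.0)

    def forward(self, x):
        # A buffer is converted/replicated together with the module,
        # so no per-call .to() is needed here.
        C = x.shape[1]
        w = self.h0.repeat(C, 1, 1, 1)
        return F.conv2d(x, w, padding=(0, 1), groups=C)


list_mod = ListFilters().to(torch.float64)
buf_mod = BufferFilters().to(torch.float64)

print(list_mod.h[0].dtype)  # torch.float32: the module conversion never saw it
print(buf_mod.h0.dtype)  # torch.float64: the buffer was converted
print([name for name, _ in buf_mod.named_buffers()])  # ['h0']
print(list(list_mod.buffers()))  # []
```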