Hi,
I'm trying to use torch.nn.DataParallel on the LPIPS network, but it raises an error.
When I change line 100 of stylegan2-pytorch.lpips.dist_model.py from
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
to
self.net = torch.nn.DataParallel(self.net)
the error goes away.
Is this the right solution?
The error I get with the original code is below:
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/mnt/nas125/SeungjunLee/Projects/stylegan2/lpips/networks_basic.py", line 78, in forward
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
File "/mnt/nas125/SeungjunLee/Projects/stylegan2/lpips/networks_basic.py", line 78, in <listcomp>
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
input = module(input)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 419, in forward
return self._conv_forward(input, self.weight)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 415, in _conv_forward
return F.conv2d(input, weight, self.bias, self.stride,
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
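
For reference, here is a minimal sketch of how the two calls differ, as far as I understand the documented behaviour of torch.nn.DataParallel: when device_ids is omitted it replicates onto all visible CUDA devices, and the wrapped module is expected to have its parameters on device_ids[0] beforehand. The helper names effective_device_ids and wrap_for_gpus are hypothetical and are not part of the lpips code; this does not claim to identify the actual cause of the traceback above.

```python
import torch
import torch.nn as nn


def effective_device_ids(gpu_ids=None):
    """Return the device ids DataParallel would replicate onto.

    Assumption: this mirrors the documented default, where
    device_ids=None means "all visible CUDA devices".
    """
    if gpu_ids is None:
        return list(range(torch.cuda.device_count()))
    return list(gpu_ids)


def wrap_for_gpus(net, gpu_ids=None):
    """Hypothetical helper: wrap `net` in DataParallel when CUDA is usable."""
    ids = effective_device_ids(gpu_ids)
    if not torch.cuda.is_available() or not ids:
        # No usable GPU: return the plain module so the code still runs on CPU.
        return net
    # DataParallel expects parameters and buffers on device_ids[0]
    # before the wrapped module is called.
    net = net.to(torch.device("cuda", ids[0]))
    return nn.DataParallel(net, device_ids=ids)
```

With this helper, wrap_for_gpus(net, gpu_ids) corresponds to the original line and wrap_for_gpus(net) to the modified one. Comparing effective_device_ids(gpu_ids) with the list of visible GPUs shows whether the two variants actually target the same devices.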