fxia22 / stn.pytorch

PyTorch version of spatial transformer networks

compatibility with DataParallel #14

Open sniklaus opened 7 years ago

sniklaus commented 7 years ago

Thank you for this implementation! Have you tried using it within a network that is wrapped in torch.nn.DataParallel in order to make use of multiple graphics cards? I am getting an "illegal memory access was encountered" error when replacing

with torch.cuda.device(3):
    input1 = input1.cuda()
    input2 = input2.cuda()
    start = time.time()
    out = s(input1, input2)
    print(out.size(), 'time:', time.time() - start)
    start = time.time()
    out.backward(input1.data.cuda())
    print('time:', time.time() - start)

in test.py with

s = torch.nn.DataParallel(s)  # wrap the STN module so the batch is split across all visible GPUs
if True:  # placeholder that preserves the indentation of the original with-block
    input1 = input1.cuda()
    input2 = input2.cuda()
    start = time.time()
    out = s(input1, input2)
    print(out.size(), 'time:', time.time() - start)
    start = time.time()
    out.backward(input1.data.cuda())
    print('time:', time.time() - start)

Interestingly, the code works with

export CUDA_VISIBLE_DEVICES="0"

but fails with

export CUDA_VISIBLE_DEVICES="0,1"
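
To confirm that the inputs are indeed being scattered, one could wrap the module in a small probe before handing it to DataParallel; DeviceProbe below is a hypothetical helper for illustration, not part of this repository:

import torch.nn as nn

class DeviceProbe(nn.Module):
    # Hypothetical wrapper: reports which GPU each DataParallel replica
    # receives its inputs on.
    def __init__(self, inner):
        super(DeviceProbe, self).__init__()
        self.inner = inner

    def forward(self, input1, input2):
        print('replica inputs on device', input1.get_device())
        return self.inner(input1, input2)

With s = torch.nn.DataParallel(DeviceProbe(s)) and two visible devices, this should print device 0 for one half of the batch and device 1 for the other.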

I see that you are explicitly setting the CUDA device before executing the kernel, which might be the reason for the illegal memory access: with a single visible device every tensor lives on that device anyway, but with multiple devices DataParallel scatters the inputs across GPUs while the kernel is still launched on the explicitly selected one. Any ideas? Thank you!
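
If the kernel launch does select a fixed device, a device-safe pattern would be to derive the launch device from the input tensor itself. The following is only a minimal sketch with a made-up function name, not the actual stn.pytorch code:

import torch

def device_safe_forward(input1):
    # Switch to the device that this DataParallel replica's input actually
    # lives on, instead of a hard-coded device index.
    with torch.cuda.device_of(input1):
        # anything launched inside this context runs on input1's device
        output = input1.new(input1.size()).zero_()
        output.add_(input1)
    return output

Following the input's device keeps every replica on its own GPU and would avoid the cross-device access.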