moscow25 closed this issue 6 years ago.
I believe it’s a CuPy thing. You’d probably have more success running with DistributedDataParallel since each process has a totally separate CUDA handle, but maybe there’s a way to fix it here? What's the error?
Yep, CuPy thing for sure. Code is very simple: torch.nn.Embedding on a list of IDs, which get fed into a single-layer QRNN. It works fine in CPU mode, and in GPU mode on a single machine with DataParallel off (just calling .cuda() on the model).
Will try DistributedDataParallel in the meantime. Thanks! Here's my stack trace:
  File "main.py", line 510, in train
    output = model(...)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 60, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 70, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
    raise output
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 42, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/code/model_rnn.py", line 171, in forward
    output, hidden = self.rnn_jr(j_vec)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torchqrnn/qrnn.py", line 160, in forward
    input, hn = layer(input, None if hidden is None else hidden[i])
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torchqrnn/qrnn.py", line 95, in forward
    C = ForgetMult()(F, Z, hidden, use_cuda=self.use_cuda)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torchqrnn/forget_mult.py", line 172, in forward
    if hidden_init is None: return GPUForgetMult()(f, x) if use_cuda else CPUForgetMult()(f, x)
  File "/opt/conda/envs/pytorch-py35/lib/python3.5/site-packages/torchqrnn/forget_mult.py", line 125, in forward
    self.forget_mult(grid=grid, block=(grid_hidden_size, 1), args=[result.data_ptr(), f.data_ptr(), x.data_ptr(), seq_size, batch_size, hidden_size], stream=self.stream)
  File "cupy/cuda/function.pyx", line 141, in cupy.cuda.function.Function.__call__
  File "cupy/cuda/function.pyx", line 123, in cupy.cuda.function._launch
  File "cupy/cuda/driver.pyx", line 169, in cupy.cuda.driver.launchKernel
  File "cupy/cuda/driver.pyx", line 69, in cupy.cuda.driver.check_status
cupy.cuda.driver.CUDADriverError: CUDA_ERROR_INVALID_HANDLE: invalid resource handle
srun: error: hsw224: task 0: Exited with exit code 1
+1 for @jekbradbury's DistributedDataParallel suggestion for now. Unfortunately I'm swamped at the moment, but I'll look at this after the ICLR deadline.
My guess as to the line that needs to be updated for such support is: https://github.com/salesforce/pytorch-qrnn/blob/3aa5e72fe263b7c1b6b2194b6af078f8eb4efb92/torchqrnn/forget_mult.py#L112
For now, ForgetMult assumes that the CUDA stream it sees when it first compiles the CUDA kernel is the correct one, and that stream is tied to a class variable. I did this because I wasn't sure how long setting up the stream would take performance-wise, and I was only using it in a single-GPU setting.
Setting self.stream to the correct value, either when constructing ForgetMult or in the forward pass (where the current CUDA stream is presumably "correct"), could be the fix? Wait... Hmm... We might need to consider what's set on the GPU during the compile() step too.
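A toy, pure-Python model of the failure mode described above may help. No real CUDA is involved, and every name here (KernelHandle, NaiveForgetMult, InvalidHandleError) is hypothetical. It assumes, as the stack trace suggests, that a kernel or stream handle loaded under one GPU's context is rejected when launched from another GPU, which is exactly what a single class-level cache runs into once DataParallel puts a replica on a second device.

```python
# Toy simulation (hypothetical names, no real CUDA): a handle is only valid on
# the device that loaded it, but the cache is one class variable that gets
# filled once, on whichever device happens to run first.


class InvalidHandleError(RuntimeError):
    """Stands in for CUDA_ERROR_INVALID_HANDLE."""


class KernelHandle:
    def __init__(self, device):
        self.device = device  # device whose context "loaded" this kernel

    def launch(self, current_device):
        # Model the driver rejecting a handle that belongs to another device.
        if current_device != self.device:
            raise InvalidHandleError(
                f"handle loaded on device {self.device}, launched on {current_device}")
        return f"ran on device {current_device}"


class NaiveForgetMult:
    handle = None  # single class-level cache shared by every replica

    def forward(self, current_device):
        if NaiveForgetMult.handle is None:  # "compile" only on the first call
            NaiveForgetMult.handle = KernelHandle(current_device)
        return NaiveForgetMult.handle.launch(current_device)


replica0, replica1 = NaiveForgetMult(), NaiveForgetMult()
print(replica0.forward(current_device=0))  # ran on device 0
try:
    replica1.forward(current_device=1)  # reuses the handle cached for device 0
except InvalidHandleError as err:
    print("error:", err)  # error: handle loaded on device 0, launched on 1
```

On a single GPU the cached handle always matches, which would explain why the bug only shows up with DataParallel.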
Actually, scratch that, I have a version that appears to be working for DataParallel. @moscow25, interested in playing with it to see if I've missed anything?
It's in master of https://github.com/Smerity/pytorch-qrnn, where I've updated ForgetMult and also included, under examples, a test of multiple GPUs using DataParallel.
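One way to picture a per-device fix (a sketch of the idea, not the actual code in the linked repo): compile the kernel source once, but load the kernel and look up the stream once per device, keyed by the current device id. This is a stdlib-only sketch with hypothetical names (PerDeviceForgetMult, configured_devices, _configure); tuples stand in for real CuPy function and stream handles.

```python
# Sketch (hypothetical names, no real CUDA): compile once, but keep one
# (kernel, stream) pair per device and look it up by the current device id.


class PerDeviceForgetMult:
    ptx = None               # compiled once and shared across devices
    configured_devices = {}  # device id -> (kernel handle, stream handle)
    compile_calls = 0        # bookkeeping for the example only

    @classmethod
    def _compile_ptx(cls):
        cls.compile_calls += 1
        return "<compiled PTX>"  # placeholder for real compiler output

    def _configure(self, device):
        cls = PerDeviceForgetMult
        if cls.ptx is None:
            cls.ptx = cls._compile_ptx()
        if device not in cls.configured_devices:
            # Real code would load cls.ptx into this device's context and read
            # that device's current stream here; tuples stand in for both.
            cls.configured_devices[device] = (("kernel", device), ("stream", device))
        return cls.configured_devices[device]

    def forward(self, current_device):
        kernel, stream = self._configure(current_device)
        # Both handles now belong to the device doing the launch.
        return f"ran on device {kernel[1]} with stream {stream[1]}"


for device in (0, 1, 0, 1):
    print(PerDeviceForgetMult().forward(device))
print(PerDeviceForgetMult.compile_calls, sorted(PerDeviceForgetMult.configured_devices))  # 1 [0, 1]
```

Since DataParallel runs replicas in worker threads (see parallel_apply / _worker in the trace), a real implementation might also want a lock around first-time setup; the sketch ignores that.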
For quite large matrices and sequences (I didn't want to go much larger, as the single GPU runs out of memory):
Single
Time: 47.81051325798035
Two GPUs
Time: 31.48588228225708
Difference:
Variable containing:
0
[torch.cuda.FloatTensor of size 1 (GPU 0)]
where "Difference" is the total summed difference between the results of the single-GPU and two-GPU runs.
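Why a difference of exactly 0 is plausible: assuming ForgetMult computes h_t = f_t * x_t + (1 - f_t) * h_{t-1} independently for every batch element (starting from h = 0), splitting the batch across devices and concatenating the outputs should reproduce the single-device result. The pure-Python reference below uses hypothetical helpers (forget_mult, split_batch, concat_batch) on nested lists shaped [seq][batch][hidden] and checks that on random data.

```python
import random


def forget_mult(f, x, h0=None):
    """Assumed recurrence h_t = f_t * x_t + (1 - f_t) * h_{t-1}, elementwise.

    f and x are nested lists shaped [seq][batch][hidden]; h0 is [batch][hidden].
    """
    batch, hidden = len(f[0]), len(f[0][0])
    prev = h0 if h0 is not None else [[0.0] * hidden for _ in range(batch)]
    out = []
    for f_t, x_t in zip(f, x):
        prev = [[fv * xv + (1.0 - fv) * hv for fv, xv, hv in zip(f_row, x_row, h_row)]
                for f_row, x_row, h_row in zip(f_t, x_t, prev)]
        out.append(prev)
    return out


def split_batch(t, n):
    """Split [seq][batch][hidden] into n chunks along the batch axis."""
    size = -(-len(t[0]) // n)  # ceiling division
    return [[step[i:i + size] for step in t] for i in range(0, len(t[0]), size)]


def concat_batch(chunks):
    """Rejoin chunks along the batch axis (inverse of split_batch)."""
    return [sum((c[s] for c in chunks), []) for s in range(len(chunks[0]))]


random.seed(0)
seq, batch, hidden = 5, 4, 3
f = [[[random.random() for _ in range(hidden)] for _ in range(batch)] for _ in range(seq)]
x = [[[random.uniform(-1, 1) for _ in range(hidden)] for _ in range(batch)] for _ in range(seq)]

single = forget_mult(f, x)  # whole batch on one "device"
halves = [forget_mult(fc, xc) for fc, xc in zip(split_batch(f, 2), split_batch(x, 2))]
two = concat_batch(halves)  # half the batch per "device", then gathered
diff = sum(abs(a - b) for s1, s2 in zip(single, two)
           for r1, r2 in zip(s1, s2) for a, b in zip(r1, r2))
print(diff)  # 0.0
```

Each batch element goes through the same arithmetic in the same order either way, so the split changes only where the work runs, which is consistent with the reported difference of 0.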
Note: the speed-up could be even better with a larger workload (the single GPU sits at 100% utilization, but the two GPUs sit at ~70% utilization when the batch is split), but then the experiment would take forever on a single GPU or run out of memory.
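Putting the two reported timings into ratios, taking the printed numbers as-is (the ratio does not depend on the time unit):

```python
single_gpu_time = 47.81051325798035  # "Single" timing reported above
two_gpu_time = 31.48588228225708     # "Two GPUs" timing reported above
num_gpus = 2

speedup = single_gpu_time / two_gpu_time  # how many times faster two GPUs were
efficiency = speedup / num_gpus           # fraction of ideal linear scaling
print(f"speedup {speedup:.2f}x, efficiency {efficiency:.0%}")  # speedup 1.52x, efficiency 76%
```

That is roughly a 1.5x speed-up, or about three quarters of ideal two-GPU scaling, consistent with the two GPUs not being fully busy at this problem size.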
Sorry for the late reply, @Smerity. I pulled the changes in your branch and it works on my code! Running on multi-GPU now, with DataParallel on. Thanks for the tutorial on how to handle multi-GPU with the .compile(). Very fast.
Glad you could test it, and huzzah that it's working for you! ^_^ Any rough estimate of how much faster a 4x GPU QRNN is than a 4x GPU LSTM? Eh, I'll settle for "very fast" anyway. Really glad it's working for you ^_^
I'll merge this in now, but I'll need to update the README to mark this as complete, either if I get ahead of my paper deadline this Friday (lol) or early next week.
I'll post numbers here when I have them for the 4x GPU comparison with LSTM; I had trouble getting a P100 scheduled. I also need to get QRNN working with 2+ layers. There's some size mismatch there, so it's not yet a "drop-in replacement" for LSTM, but I'm sure I can fix it.
Could you guys get it to work with torch.nn.DataParallel(model).cuda()? I could not, but perhaps I didn't try hard enough. I can't tell if it's a wrong-GPU problem or if CuPy won't support it. It runs pretty fast on a single GPU, though: a bit faster than 4x GPUs for vanilla LSTM, but not by much without scaling to multiple GPUs...