pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

JIT does not batch linear layers in an ensemble #18157

Open omry opened 5 years ago

omry commented 5 years ago

I am trying to speed up a case where I have an ensemble of identically structured models. Tracing my model and running the jitted version gives some speedup, but since the layers are small it still does not achieve good utilization of the GPU.

If torch.bmm were used to batch-compute the matmul for layer[0] across the K models, then layer[1] across the K models, and so on, it would help a lot.

See code sample.

Code sample: http://paste.ubuntu.com/p/kSdJwDsQqB/

bmm is mentioned as a future optimization here. If torch.bmm were used, I would expect a significant speedup.
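For illustration, here is a minimal sketch of the kind of per-layer batching I mean (made-up shapes, not the pasted code): stacking the K models' weights for one layer lets a single torch.bmm replace K small matmuls.

```python
import torch

# Minimal sketch (made-up shapes, not the pasted code): one layer of an
# ensemble of K identically shaped linear layers, computed either as K
# separate matmuls or as a single torch.bmm over stacked weights.
K, batch, in_f, out_f = 8, 32, 64, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

weights = torch.randn(K, out_f, in_f, device=device)  # one weight per model
biases = torch.randn(K, out_f, device=device)         # one bias per model
x = torch.randn(K, batch, in_f, device=device)        # per-model inputs

# Unbatched: K small kernel launches, poor GPU utilization for small layers.
unbatched = torch.stack([x[k] @ weights[k].t() + biases[k] for k in range(K)])

# Batched: one bmm covers the whole ensemble for this layer.
batched = torch.bmm(x, weights.transpose(1, 2)) + biases.unsqueeze(1)

assert torch.allclose(unbatched, batched, atol=1e-5)
```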

Environment

PyTorch version: 1.0.0.dev20190318
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Quadro P5000
GPU 1: Quadro P5000

Nvidia driver version: 418.39
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] torch==1.0.0.dev20190318
[pip] torchfile==0.1.0
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl_fft 1.0.6 py36hd81dba3_0
[conda] mkl_random 1.0.2 py36hd81dba3_0
[conda] pytorch-nightly 1.0.0.dev20190318 py3.6_cuda10.0.130_cudnn7.4.2_0 pytorch
[conda] torchfile 0.1.0

cc @suo

omry commented 5 years ago

cc @soumith

suo commented 5 years ago

Passing to @wanchaol to take a look at the batchmm pass

omry commented 5 years ago

@wanchaol , did you have a chance to take a look?

wanchaol commented 5 years ago

@omry yes, I was looking into that. It seems that even if we do the bmm optimization, it might not be beneficial for your use case, since it requires the LHS/RHS to be the same. In your case the RHSs are different across the models. I am looking into alternative ways to batch your case, if possible.

omry commented 5 years ago

@wanchaol, the low-level torch.bmm() itself does not require that, right?

I think a speedup can come from reducing the number of kernel launches and making the jobs bigger.

soumith commented 5 years ago

The low-level torch.bmm requires all matmuls in the batched matmul to be of the same size (the even lower-level cuBLAS API doesn't require this, though).

omry commented 5 years ago

@soumith, I think by "different" @wanchaol means that they have different parameters in my case, not that they have different sizes.

soumith commented 5 years ago

yea, this is also a limitation of torch.bmm. The cuBLAS bmm directly takes an array of pointers.

One way to write out the optimization is torch.cat + torch.bmm + torch.narrow, but I think the cat / narrow calls will eat up all the speedup.
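Roughly this shape, just to illustrate (names and sizes are made up):

```python
import torch

# Illustrative sketch of the cat + bmm + narrow idea (made-up names/shapes).
# The concern above: the cat and narrow copies may cost as much as the
# kernel launches they save.
K, batch, in_f, out_f = 8, 32, 64, 64
xs = [torch.randn(batch, in_f) for _ in range(K)]   # per-model activations
ws = [torch.randn(in_f, out_f) for _ in range(K)]   # per-model weights

# cat: materialize (K, batch, in_f) and (K, in_f, out_f) batches.
x_b = torch.cat([x.unsqueeze(0) for x in xs])
w_b = torch.cat([w.unsqueeze(0) for w in ws])

# bmm: one launch instead of K.
y_b = torch.bmm(x_b, w_b)                            # (K, batch, out_f)

# narrow: slice each model's output back out.
ys = [y_b.narrow(0, k, 1).squeeze(0) for k in range(K)]
```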

omry commented 5 years ago

Would converting the K MxN matrices at a given layer into one KxMxN matrix be possible at jit time, instead of delaying it to forward time?
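Something like this sketch, where the stacking happens once at module construction rather than on every forward (names are made up):

```python
import torch
from torch import nn

# Sketch of the pre-stacking idea (names made up): fold K existing MxN
# linear layers into one KxMxN parameter once, instead of stacking them
# on every forward call.
class StackedLinear(nn.Module):
    def __init__(self, linears):
        super().__init__()
        # One-time conversion: K separate (out, in) weights -> (K, out, in).
        self.weight = nn.Parameter(torch.stack([l.weight for l in linears]).detach())
        self.bias = nn.Parameter(torch.stack([l.bias for l in linears]).detach())

    def forward(self, x):  # x: (K, batch, in)
        return torch.bmm(x, self.weight.transpose(1, 2)) + self.bias.unsqueeze(1)

layers = [nn.Linear(64, 64) for _ in range(8)]
stacked = StackedLinear(layers)
out = stacked(torch.randn(8, 32, 64))  # (K=8, batch=32, out=64)
```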

wanchaol commented 5 years ago

So one thing I noticed is that the JIT does batch your code when the ensemble size is large enough. When you change the ensemble size to something like 10, you can see in the graph a big prim::MMBatchSide op that batches all the nodes sharing the same LHS (you can see a graph here); this is because we have a threshold for triggering it here.

As for why we don't batch the mms where the LHS/RHS differ: I think it is mainly because batching in that form might corrupt the graph's DAG structure, so we only do the cases that definitely preserve the semantics.
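If you want to poke at the pass directly, here's a rough sketch using the internal torch._C._jit_pass_batch_mm binding (not a public API, so it may change between versions; whether the batching fires also depends on the internal threshold):

```python
import torch

# Run the BatchMM pass on a small hand-written graph: 10 mms that all share
# the same LHS `x`, mimicking an ensemble layer. torch._C._jit_pass_batch_mm
# is an internal entry point and may differ across versions.
@torch.jit.script
def ensemble_mm(x, w0, w1, w2, w3, w4, w5, w6, w7, w8, w9):
    return (torch.mm(x, w0), torch.mm(x, w1), torch.mm(x, w2), torch.mm(x, w3),
            torch.mm(x, w4), torch.mm(x, w5), torch.mm(x, w6), torch.mm(x, w7),
            torch.mm(x, w8), torch.mm(x, w9))

graph = ensemble_mm.graph
torch._C._jit_pass_batch_mm(graph)  # run the pass on the graph in place
print(graph)  # look for prim::MMBatchSide in the printed IR
```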

omry commented 5 years ago

It's unclear to me why setting the threshold for this decision at 10 somehow saves us from corrupting the DAG structure (and if that's the case, something here is fragile).

A few questions:

  1. With the DAG structure in my example, I can imagine why it would make sense to have one bmm per layer across the ensemble; you can't do it all at once because of the ordering dependencies. With this in mind, the batching that you found still does not seem to be the right kind of batching: I would expect one bmm per layer (4 bmms in this case), but I only see a single prim::MMBatchSide in the graph.

  2. What is the speedup for the jitted model when the ensemble size is above the threshold (say, 10)? A rough way I'd measure it is sketched below.
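A quick sketch for 2. (not a rigorous benchmark; the toy ensemble below stands in for my pasted model and assumes a GPU is available):

```python
import time
import torch
from torch import nn

# Toy stand-in for the pasted ensemble: K small identical MLPs.
class ToyEnsemble(nn.Module):
    def __init__(self, k=10, width=64):
        super().__init__()
        self.models = nn.ModuleList([
            nn.Sequential(nn.Linear(width, width), nn.ReLU(),
                          nn.Linear(width, width))
            for _ in range(k)])

    def forward(self, x):
        return torch.stack([m(x) for m in self.models])

def bench(fn, x, iters=100):
    for _ in range(10):                     # warm-up; also lets the JIT optimize
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

model = ToyEnsemble().cuda().eval()         # assumes a CUDA device is present
x = torch.randn(32, 64, device="cuda")
traced = torch.jit.trace(model, x)

print("eager : %.6fs/iter" % bench(model, x))
print("jitted: %.6fs/iter" % bench(traced, x))
```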