omry opened this issue 5 years ago
cc @soumith
Passing to @wanchaol to take a look at the batchmm pass
@wanchaol , did you have a chance to take a look?
@omry yes, I was looking into that. It seems that even if we do the bmm optimization, it might not be beneficial for your use case, since it requires the lhs/rhs to be the same. In your case the RHSs are different across each model. I am looking into alternative ways to batch your case, if possible.
@wanchaol, the low-level torch.bmm() itself does not require that, right?
I think a speedup can come from reducing the number of kernel launches and making the jobs bigger.
low-level torch.bmm requires all matmuls in the batch matmul to be of the same size (the lower-level cuBLAS API doesn't require this, though)
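For concreteness, a minimal illustration of that shape constraint (the shapes here are arbitrary):

```python
import torch

a = torch.randn(4, 8, 16)    # batch of 4 matrices, each 8x16
b = torch.randn(4, 16, 32)   # batch of 4 matrices, each 16x32
c = torch.bmm(a, b)          # (4, 8, 32): all 4 matmuls in one call

# torch.bmm has no way to express 4 differently sized matmuls in one call;
# the lower-level cuBLAS pointer-array interface mentioned above is not exposed here.
```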
@soumith, I think by "different" @wanchaol means that in my case they have different parameters, not different sizes.
yea, this is also a limitation of torch.bmm. The cublas bmm directly takes an array of pointers.
One way to write out the optimization is torch.cat + torch.bmm + torch.narrow, but I think the cat / narrow calls will eat all the speedup.
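A minimal sketch of that cat + bmm + narrow idea, assuming all K matmuls share the same shapes (the shapes and names here are illustrative only):

```python
import torch

K, M, N, P = 4, 8, 16, 32
xs = [torch.randn(M, N) for _ in range(K)]   # K different left-hand sides
ws = [torch.randn(N, P) for _ in range(K)]   # K different right-hand sides

# cat the operands into (K, M, N) and (K, N, P) batch tensors
lhs = torch.cat([x.unsqueeze(0) for x in xs], dim=0)
rhs = torch.cat([w.unsqueeze(0) for w in ws], dim=0)

out = torch.bmm(lhs, rhs)                    # one launch instead of K separate mms

# narrow the batched result back into the K individual outputs
outs = [out.narrow(0, i, 1).squeeze(0) for i in range(K)]

for x, w, o in zip(xs, ws, outs):
    assert torch.allclose(x @ w, o, atol=1e-5)  # matches the unbatched results
```

Whether this wins depends on whether the one big bmm saves more than the extra cat/narrow traffic costs.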
Would converting the K MxN matrices at a given layer into one KxMxN matrix be possible when jitting, instead of delaying it to forward time?
So one thing I noticed is that the JIT does batch your code when your ensemble size is large enough. When you change the ensemble size to something like 10, you can look into the graph and see a big prim::MMBatchSide op which batches all your nodes that share the same LHS (you can see a graph here); this is because we have a threshold on triggering it here.
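For reference, one way to check whether the pass fired on a given build; `model` and `example_input` are placeholders for the objects in the linked code sample, and `graph_for` is assumed to be available for retrieving the optimized graph (this may vary by version):

```python
import torch

# model / example_input stand in for the objects from the linked paste
traced = torch.jit.trace(model, example_input)
traced(example_input)  # run once so the executor's optimization passes (incl. BatchMM) apply

# graph_for returns the graph optimized for these inputs, if the build exposes it
opt_graph = str(traced.graph_for(example_input))
print(opt_graph)
print("prim::MMBatchSide nodes:", opt_graph.count("prim::MMBatchSide"))
```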
The reason why we don't batch-mm the case where the lhs/rhs are different is, I think, mainly that batching in this form might corrupt the graph's DAG structure, so we only do the ones that definitely preserve the semantics.
It's unclear to me why having the decision threshold at 10 somehow saves us from corrupting the DAG structure (and if that's the case, something here is fragile).
A few questions:
With the DAG structure in my example, I can imagine why it would make sense to have one bmm per layer across the ensemble: you can't do it all at once because of order dependencies. With this in mind, the batching that you found still does not seem to be the right kind of batching. I would expect to see more than one bmm, one per layer (in this case, 4 bmms), but I only see a single prim::MMBatchSide in the graph.
What is the speedup for the jitted model if the ensemble size is above the threshold (say, for 10)?
I am trying to speed up a case where I have an ensemble of identically structured models. Tracing my model and running a jitted version results in some speedup, but since the layers are small it still does not achieve good utilization of the GPU.
If torch.bmm were used to batch-compute the matmul for layer[0] across the K models, then layer[1] across the K models, and so on, it would help a lot.
See code sample.
Code sample: http://paste.ubuntu.com/p/kSdJwDsQqB/
bmm is mentioned as a future optimization here. If torch.bmm was used, I would expect to get a significant speedup.
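For illustration, a hedged sketch of the kind of per-layer batching described above: K identically structured 2-layer models whose weights are stacked up front, so each layer of the ensemble becomes a single torch.bmm. The class name, shapes, and layer count are illustrative and not taken from the linked paste:

```python
import torch
import torch.nn as nn

class BatchedEnsemble(nn.Module):
    def __init__(self, models):
        super().__init__()
        # stack each layer's weights/biases across the K models once, up front
        self.weights = nn.ParameterList()
        self.biases = nn.ParameterList()
        for layer_idx in range(len(models[0])):
            w = torch.stack([m[layer_idx].weight.detach() for m in models])  # (K, out, in)
            b = torch.stack([m[layer_idx].bias.detach() for m in models])    # (K, out)
            self.weights.append(nn.Parameter(w))
            self.biases.append(nn.Parameter(b))

    def forward(self, x):
        # x: (batch, in); broadcast the same input to all K ensemble members
        K = self.weights[0].shape[0]
        h = x.unsqueeze(0).expand(K, -1, -1)                          # (K, batch, in)
        for w, b in zip(self.weights, self.biases):
            h = torch.bmm(h, w.transpose(1, 2)) + b.unsqueeze(1)      # one bmm per layer
            h = torch.relu(h)
        return h                                                      # (K, batch, out)

K, in_dim, hidden = 4, 32, 64
models = [nn.Sequential(nn.Linear(in_dim, hidden), nn.Linear(hidden, hidden))
          for _ in range(K)]
ensemble = BatchedEnsemble(models)
print(ensemble(torch.randn(8, in_dim)).shape)  # torch.Size([4, 8, 64])
```

With this layout the forward pass launches one bmm per layer regardless of the ensemble size, which is the batching pattern described above.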
Environment
PyTorch version: 1.0.0.dev20190318
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Quadro P5000
GPU 1: Quadro P5000

Nvidia driver version: 418.39
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] torch==1.0.0.dev20190318
[pip] torchfile==0.1.0
[conda] blas               1.0                  mkl
[conda] mkl                2019.1               144
[conda] mkl_fft            1.0.6                py36hd81dba3_0
[conda] mkl_random         1.0.2                py36hd81dba3_0
[conda] pytorch-nightly    1.0.0.dev20190318    py3.6_cuda10.0.130_cudnn7.4.2_0    pytorch
[conda] torchfile          0.1.0
cc @suo