CNugteren / CLBlast

Tuned OpenCL BLAS
Apache License 2.0

Implement batched BLAS routines #95

Closed CNugteren closed 6 years ago

CNugteren commented 8 years ago

Batched operations involve performing many small linear-algebra operations, such as GEMV or GEMM, in one call. Batched GEMM in particular has become increasingly popular due to deep learning. More parallelism can be exploited with a single batched BLAS call than with multiple regular BLAS calls on small matrices. NVIDIA's cuBLAS, for example, has a batched GEMM interface.

Two potentially related papers are:

blueberry commented 8 years ago

Unfortunately, I cannot help by implementing this, but I just want to give thumbs up for this functionality. Currently, I resort to writing my own special-case kernels for such cases.

bhack commented 8 years ago

We could also add this: https://github.com/andravin/wincnn/blob/master/README.md. The paper is linked at the end of the page.

gcp commented 8 years ago

It's a bit outside the scope of a BLAS library, but I'll be the first to admit that a fused/implicit GEMM for convolutions (one that does not materialize the actual matrix via im2col), or even better, Winograd-style convolutions, would be enormously beneficial for my use case. More so than batched GEMM, because I'm doing inference and latency matters.

I'm not using CLBlast at the moment because a direct convolution outperforms im2col + CLBlast/clBLAS.

bhack commented 8 years ago

Yes, but strictly speaking batched GEMM is also a bit outside the BLAS API. There are some standardization efforts, but nothing is final yet. See https://github.com/sdrelton/bblas_api_test/blob/master/README.md

bhack commented 8 years ago

And http://www.netlib.org/utk/people/JackDongarra/WEB-PAGES/Batched-BLAS-2016/

bhack commented 8 years ago

/cc @naibaf7

naibaf7 commented 8 years ago

@gcp LibDNN has these direct (fused/implicit GEMM) kernels, if you want to check it out: https://github.com/naibaf7/libdnn There are a few kernel variants (atomic and deterministic) and tuning parameters to try out.

CNugteren commented 8 years ago

Thanks all for the feedback and links! I opened this issue mostly to collect some relevant work and discuss what else could be useful to add. Indeed, it is a bit outside of BLAS, but so are some other functions currently already in CLBlast.

I'll work on batched GEMM when some of the other open issues are resolved. Afterwards we can have a discussion about other related functionality such as Winograd.

blueberry commented 8 years ago

...or maybe some parts of LAPACK would be good candidates to support? ATLAS implements some parts of LAPACK directly, if I remember correctly...

naibaf7 commented 8 years ago

@gcp @CNugteren At least for DNN applications, we have to keep in mind that batched GEMM is not a sustainable approach due to the large intermediate column buffers (after im2col) that a batched im2col would need to generate in order to use a batched GEMM afterwards.

For non-DNN applications that involve convolutions, direct convolution or FFT may still be the best approach. Such applications often have 3 (RGB) input channels and 3 (RGB) output channels, frequently processed separately (as 3 groups of 1 channel), so their kernels can often be cached completely, as opposed to DNN convolution kernels with k_w * k_h * f_in * f_out weights.
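To make the size difference concrete, here is a rough weight count. The layer shapes used in the usage note (3x3 kernels, a 256-to-256-channel DNN layer) are illustrative assumptions, and `ConvWeights` is a hypothetical helper, not part of any library.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: number of weights in a convolution with a k_h x k_w
// kernel, f_in input and f_out output channels, split into `groups`
// independent groups (groups == 1 is an ordinary DNN convolution layer).
std::size_t ConvWeights(std::size_t k_h, std::size_t k_w,
                        std::size_t f_in, std::size_t f_out,
                        std::size_t groups) {
  assert(groups > 0 && f_in % groups == 0 && f_out % groups == 0);
  return groups * k_h * k_w * (f_in / groups) * (f_out / groups);
}
```

With these assumptions, RGB processed as 3 groups of 1 channel gives `ConvWeights(3, 3, 3, 3, 3)` = 27 weights (108 bytes in single precision), which fits easily in local or constant memory, whereas `ConvWeights(3, 3, 256, 256, 1)` = 589,824 weights (2.25 MiB in single precision) does not.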

I've tried batched im2col + batched GEMM before, and it quickly fills up the GPU memory. Another limitation that will hit hard here is the maximum buffer size on certain OpenCL devices, such as AMD cards. These have segmented memory, and at most 1/4 of the total capacity can be allocated in a single buffer. With a typical 4 GB card, this limit is 1 GB, and an unrolled im2col buffer fills this very quickly with larger batch sizes (see CL_DEVICE_MAX_MEM_ALLOC_SIZE). It also means more total memory transfers from/to global memory, whereas a fused im2col+GEMM kernel can make better use of local memory and use swizzle instructions (AMD Polaris, Intel Broadwell/Skylake) to exchange data within workgroups, subgroups or wavefronts. (ref: https://www.khronos.org/registry/cl/extensions/intel/cl_intel_subgroups.txt, http://gpuopen.com/amd-gcn-assembly-cross-lane-operations/)
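The arithmetic behind the buffer-size concern can be sketched as follows, assuming single-precision data and an im2col layout in which each image becomes a (channels * k_h * k_w) x (out_h * out_w) matrix. The layer shape in the usage note (256 channels, 3x3 kernel, 56x56 output) and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Bytes of the unrolled im2col column buffer for a whole batch: every image
// becomes a (channels * k_h * k_w) x (out_h * out_w) matrix of elem_size bytes.
std::uint64_t Im2colBytes(std::uint64_t batch, std::uint64_t channels,
                          std::uint64_t k_h, std::uint64_t k_w,
                          std::uint64_t out_h, std::uint64_t out_w,
                          std::uint64_t elem_size = 4) {
  return batch * channels * k_h * k_w * out_h * out_w * elem_size;
}

// Largest batch whose column buffer still fits in one allocation of
// max_alloc bytes (e.g. the value a device reports as
// CL_DEVICE_MAX_MEM_ALLOC_SIZE).
std::uint64_t MaxBatchForAlloc(std::uint64_t max_alloc, std::uint64_t channels,
                               std::uint64_t k_h, std::uint64_t k_w,
                               std::uint64_t out_h, std::uint64_t out_w) {
  const std::uint64_t per_image = Im2colBytes(1, channels, k_h, k_w, out_h, out_w);
  assert(per_image > 0);
  return max_alloc / per_image;
}
```

With those assumed sizes, one image needs 28,901,376 bytes (about 27.6 MiB) of column buffer, so a 1 GiB single-allocation limit is exceeded beyond a batch of 37 images.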

Of course, batched GEMM has many useful applications other than that, especially when a structure of matrices is already stored in a single buffer.

blueberry commented 8 years ago

Hmmm, I think I created buffers larger than 2GB without problems on my AMD 290X 4GB. I think I hit the limit at 3+ GB...

naibaf7 commented 8 years ago

@blueberry Yes, some GPUs and drivers allow it, also depending on what GPU_MAX_ALLOC_PERCENT is (an environment variable) which will in turn modify CL_DEVICE_MAX_MEM_ALLOC_SIZE. It's still not pretty and can degrade performance though. What's your CL_DEVICE_MAX_MEM_ALLOC_SIZE for the R9 290X?

bhack commented 8 years ago

@naibaf7 Have you benchmarked https://bitbucket.org/multicoreware/hccaffe/src/0b7963d45e7d552695ddd38e4f847ac5794e4d07/src/caffe/layers/conv_layer_hcc.cpp?at=HCC&fileviewer=file-view-default?

blueberry commented 8 years ago

@naibaf7 2595886080

naibaf7 commented 8 years ago

@bhack Yes this is basically the same as standard Caffe convolution, just a different compiler backend. Does not affect performance.

bhack commented 8 years ago

Yes, I know that it is not a direct convolution. But is the GPU performance decent compared to the same OpenCL kernel?

naibaf7 commented 8 years ago

@bhack As far as I can tell, yes. It only runs on Fiji and Polaris hardware though (the latter is experimental, and I've only been able to get it working on Ubuntu). It depends on HcBLAS, see here: http://gpuopen.com/compute-product/hcblas/ Performance is similar to clBLAS.

bhack commented 8 years ago

Yes, hcBLAS has batched GEMM. But I suppose we have the same memory problems with im2col_hcc_kernal.

naibaf7 commented 8 years ago

@bhack Yup. The same approach has been tried by AMD before, in OpenCL; see here: https://github.com/amd/OpenCL-caffe/blob/stable/src/caffe/layers/conv_layer.cpp

bhack commented 8 years ago

So it is a habit of theirs, since hccaffe started after the old AMD fork.

gcp commented 8 years ago

@naibaf7 I'm aware of libDNN, but the ViennaCL dependency makes it unattractive for me. I've looked at the code and it was honestly not clear to me why it was required, as there's low level OpenCL code in the OpenCL paths anyway.

I agree that for DCNN use cases fusing is much more important than batching. If you expand the data you're just not going to win on performance; GPUs are too memory-limited for that.

naibaf7 commented 8 years ago

@gcp Yes, the ViennaCL part is there because I use it extensively in Caffe, and LibDNN was "pulled out" of Caffe on short notice for integration with tiny-cnn (@bhack). If you wish, I can remove this dependency quite easily. I just don't have that much time for it at the moment, so you'd have to be a bit patient... ViennaCL is usually not a problematic library to handle, though, and it's header-only. (ViennaCL is in Caffe for the fallback BLAS and because it fixes a lot of "difficult" corner cases that some OpenCL implementations pose, such as device initialization on computers with NVIDIA Optimus.)

bhack commented 8 years ago

Everything that needs tuning has to compile programs and launch kernels for timing. With a large crowd-tuning initiative we could probably "download" the right kernel for our device and problem dimensions directly.

naibaf7 commented 8 years ago

@bhack This is true. @dividiti (http://www.dividiti.com/) is working on such infrastructures I believe, called CK (collective knowledge), https://github.com/dividiti/ck-caffe, https://github.com/ctuning/ck

bhack commented 8 years ago

I already know it and I like it. But it needs more integrated (ideally default) support from frameworks and convolution developers.

bhack commented 8 years ago

@naibaf7 Are they adding support for Greentea https://github.com/dividiti/ck-caffe/commit/862ec9d920cc2daf41b23d36ae1324b8ad9b6e77?

naibaf7 commented 8 years ago

@bhack Yes, it's now based on OpenCL Caffe :)

blueberry commented 7 years ago

@CNugteren any news regarding batched operations? I understand that you're still working on more important things (we all owe you a huge thanks for those!), but do you at least have it on the radar for this year, or this summer even?

CNugteren commented 7 years ago

Yes, I hope to work on it rather sooner than later. Right now I'm still working on the last missing routines (TRSM and TRSV) to make integration with ArrayFire possible. The first implementations are done, but they are still a bit slow. However, optimising those routines is perhaps not too urgent for regular CLBlast users now (haven't seen any TRSM/TRSV feature requests even), although they seem to be quite important for ArrayFire.

Next on the list is adding support for #130 I guess, but hopefully that will be rather simple to implement. Afterwards it is time for batched GEMM I believe.

Any opinions about the API? Should it just be like regular GEMM with a batch-count? Should it be like cuBLAS? Or like Intel's CBLAS extension? Or like a Netlib proposal?

Here is a starting point for discussion:

template <typename T>
StatusCode
GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
            const size_t m, const size_t n, const size_t k,
            const T *alphas,
            const cl_mem *a_buffers, const size_t a_offset, const size_t a_ld,
            const cl_mem *b_buffers, const size_t b_offset, const size_t b_ld,
            const T *betas,
            cl_mem *c_buffers, const size_t c_offset, const size_t c_ld,
            const size_t batch_count,
            cl_command_queue* queue, cl_event* event = nullptr);

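For discussion, the intended semantics of this proposed signature can be pinned down with a plain CPU reference: every batch b computes C_b = alphas[b] * A_b * B_b + betas[b] * C_b on its own buffers. This is a sketch under simplifying assumptions (column-major layout, no transposes, single precision, a `std::vector` standing in for each `cl_mem`); `GemmBatchedReference` is a hypothetical name, not a CLBlast function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: C_b = alphas[b] * A_b * B_b + betas[b] * C_b
// for every batch b, column-major, no transposes, one buffer per matrix.
void GemmBatchedReference(std::size_t m, std::size_t n, std::size_t k,
                          const std::vector<float>& alphas,
                          const std::vector<std::vector<float>>& a_buffers,
                          std::size_t a_offset, std::size_t a_ld,
                          const std::vector<std::vector<float>>& b_buffers,
                          std::size_t b_offset, std::size_t b_ld,
                          const std::vector<float>& betas,
                          std::vector<std::vector<float>>& c_buffers,
                          std::size_t c_offset, std::size_t c_ld,
                          std::size_t batch_count) {
  assert(alphas.size() == batch_count && betas.size() == batch_count);
  assert(a_buffers.size() == batch_count && b_buffers.size() == batch_count &&
         c_buffers.size() == batch_count);
  assert(a_ld >= m && b_ld >= k && c_ld >= m);
  for (std::size_t batch = 0; batch < batch_count; ++batch) {
    const std::vector<float>& a = a_buffers[batch];
    const std::vector<float>& b = b_buffers[batch];
    std::vector<float>& c = c_buffers[batch];
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t i = 0; i < m; ++i) {
        float acc = 0.0f;
        for (std::size_t p = 0; p < k; ++p) {
          acc += a[a_offset + i + p * a_ld] * b[b_offset + p + j * b_ld];
        }
        float& cij = c[c_offset + i + j * c_ld];
        cij = alphas[batch] * acc + betas[batch] * cij;
      }
    }
  }
}
```

Note that the per-batch `alphas` and `betas` arrays mean each batch can scale its result differently, which is the main difference from simply calling a regular GEMM in a loop with one alpha and beta.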
blueberry commented 7 years ago

@CNugteren Hey, that's really great news! Thank you.

As for the actual API, I cannot suggest anything regarding GEMM, since I haven't used any of the existing solutions. I am planning to implement tensors for neanderthal at some point, so I am just starting to casually gather information, and I understand that batched operations are important in that case. I believe that cuBLAS is an obvious target, since most people using matrices on the GPU probably already adapt what they do to it to a large extent, but I believe that you have far more information about that than I can give you.

I may use an opportunity to suggest something not mentioned before: I think that even more than batched GEMM, we need batched versions of some BLAS-1 (and possibly BLAS-2) operations. For example, a common use case is copying a matrix, or adding two matrices. Currently, I do this by calling a straightforward COPY or AXPY when the matrix is not strided, but when it is, I have to call COPY or AXPY for each column or row. That way, I have 2 overheads: kernel launch, and also kernel inefficiency (since smallish arrays cannot always make all cores busy). With batched BLAS-1 operations, I would be able to call those operations with one batched call, and avoid both overheads.
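The per-column workaround described above looks roughly like this on the host. `Axpy` and `MatrixAxpyPerColumn` are hypothetical CPU stand-ins; on a GPU, each iteration of the loop would be a separate kernel launch on a vector of only `rows` elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a regular (non-batched) AXPY: y += alpha * x on
// n elements. On a GPU, every call would be a separate kernel launch.
void Axpy(std::size_t n, float alpha,
          const std::vector<float>& x, std::size_t x_offset, std::size_t x_inc,
          std::vector<float>& y, std::size_t y_offset, std::size_t y_inc) {
  for (std::size_t i = 0; i < n; ++i) {
    y[y_offset + i * y_inc] += alpha * x[x_offset + i * x_inc];
  }
}

// Y(rows x cols) += alpha * X(rows x cols) for column-major (sub)matrices whose
// leading dimensions may exceed `rows`, so the data is not contiguous.
// Without a batched routine this takes one AXPY call per column.
// Returns the number of AXPY calls made.
std::size_t MatrixAxpyPerColumn(std::size_t rows, std::size_t cols, float alpha,
                                const std::vector<float>& x, std::size_t x_offset, std::size_t x_ld,
                                std::vector<float>& y, std::size_t y_offset, std::size_t y_ld) {
  assert(x_ld >= rows && y_ld >= rows);
  std::size_t calls = 0;
  for (std::size_t j = 0; j < cols; ++j) {
    Axpy(rows, alpha, x, x_offset + j * x_ld, 1, y, y_offset + j * y_ld, 1);
    ++calls;
  }
  return calls;
}
```

When both leading dimensions equal the number of rows, the matrix is contiguous and a single AXPY over rows * cols elements suffices; the per-column loop is only needed for strided submatrices, which is exactly where a batched call would remove the launch overhead.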

I guess that implementing batched BLAS-1 would also be a good testing ground for the (I guess) more demanding task of batched BLAS-3.

Is this a good suggestion, or am I missing something?

CNugteren commented 7 years ago

OK, that makes sense. Perhaps I'll start with AxpyBatched.

CNugteren commented 7 years ago

I started today by implementing the required infrastructure on the library side and on the test side. That is complete now: I can successfully run and test a naive version of AxpyBatched (which simply calls Axpy in a loop). Next up is implementing an optimised AxpyBatched routine. The API is currently like this:

template <typename T>
StatusCode AxpyBatched(const size_t n,
                       const T *alphas,
                       const cl_mem *x_buffers, const size_t x_inc,
                       cl_mem *y_buffers, const size_t y_inc,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event);

As you can see I dropped the offset arguments, similar to the CUDA API. Let me know if there are remarks on the API.

blueberry commented 7 years ago

@CNugteren I believe that the offset is important, at least when those arrays are really columns of matrices, taking into account submatrices etc. If compatibility with the CUDA API is preferred, then maybe the best option is to provide the version with offsets, and an additional CUDA-compatible version without them that simply delegates to the first with all offsets set to zero.

CNugteren commented 7 years ago

@blueberry No, there is no need for compatibility with the CUDA API. But in your use-case you would then need to provide an array of offsets, right? Because just one offset doesn't make a lot of sense, I guess? So in your use-case you could have a single buffer (pass on multiple copies of the same cl_mem argument) and then provide an array of offsets to index sub-buffers within that array? Yes I guess that could make sense as well. So both an array of cl_mem and an array of offsets.

blueberry commented 7 years ago

@CNugteren I don't believe so. Remember, the typical use case is that those mini-arrays have some regular structure. So, maybe the first offset is different, but all others should be the same. Depending on what we want to support, let's say that there is a need for:

Option A (one cl_mem and many offsets):

  1. The first offset into the array (to know where to start).
  2. The offset increment, to calculate where each sub-array starts.
  3. The element increment (x_inc).
  4. Maybe another argument to control how many times items 2 and 3 are applied (sometimes we do not want to process the whole uber-array, as in the case of submatrices).

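Option B is a cl_mem for each array. That is more difficult, since OpenCL does not allow sub-buffers at arbitrary offsets: the origin passed to clCreateSubBuffer must be aligned to CL_DEVICE_MEM_BASE_ADDR_ALIGN!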

Basically, the only universal way of selecting a subarray to work with is having an offset... Those "batched" routines often operate on parts of the same array, but the selection of parts is a bit more complex than one offset and one x_inc...

CNugteren commented 7 years ago

OK, so what about an array of cl_mem objects and an array of offsets? That way we can satisfy both use-cases:

  1. Supply an array of individual cl_mem objects and an array of zeros for the offsets.
  2. Supply an array filled with a copy of the same cl_mem object and an array with offsets indicating the starting positions (e.g. [0, 100, 200, 300]).

Any remarks on this proposal?

blueberry commented 7 years ago

I am concerned that preparing such arrays (and perhaps later consuming them) would incur significant performance overhead.

For example, when there are hundreds of those pieces, for each batched call, I'd have to construct (and forward to the method) that 0-array. Also, I'd have to copy the cl_mem object multiple times, and then call the kernel. That's a considerable number of expensive calls that go through the OpenCL driver...

On the other hand, I am not familiar with the general use case. Theoretically, there might be a need for computing hundreds of heterogeneous arrays, but is there a practical case that requires that?

The only case that I practically encountered myself was the one that I described, when there are multiple pieces, but with a regular structure.

CNugteren commented 7 years ago

Thanks for your feedback!

Actually, I also don't know what the typical use-case is, but I based my first API design on the CUDA API. But there things are actually a bit different, so indeed for OpenCL we might want to just pass a single cl_mem object and then an array of offsets. One very practical reason I just came across: you can't simply copy an array of cl_mem objects to the GPU and then retrieve the individual pointers, see also here.

So the new API will then be:

template <typename T>
StatusCode AxpyBatched(const size_t n,
                       const T *alphas,
                       const cl_mem x_buffer, const size_t *x_offsets, const size_t x_inc,
                       cl_mem y_buffer, const size_t *y_offsets, const size_t y_inc,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event);

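A CPU reference of what this signature computes may help when reviewing it: each batch b reads x starting at `x_offsets[b]` and updates y starting at `y_offsets[b]`, both inside a single buffer. This is a sketch under assumptions (single precision, `std::vector` in place of `cl_mem`, hypothetical name `AxpyBatchedReference`), not CLBlast's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the proposed AxpyBatched with one x buffer,
// one y buffer and per-batch offsets:
//   y[y_offsets[b] + i*y_inc] += alphas[b] * x[x_offsets[b] + i*x_inc]
// for every batch b < batch_count and element i < n.
void AxpyBatchedReference(std::size_t n, const std::vector<float>& alphas,
                          const std::vector<float>& x_buffer,
                          const std::vector<std::size_t>& x_offsets, std::size_t x_inc,
                          std::vector<float>& y_buffer,
                          const std::vector<std::size_t>& y_offsets, std::size_t y_inc,
                          std::size_t batch_count) {
  assert(alphas.size() == batch_count);
  assert(x_offsets.size() == batch_count && y_offsets.size() == batch_count);
  for (std::size_t b = 0; b < batch_count; ++b) {
    for (std::size_t i = 0; i < n; ++i) {
      y_buffer[y_offsets[b] + i * y_inc] += alphas[b] * x_buffer[x_offsets[b] + i * x_inc];
    }
  }
}
```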
blueberry commented 7 years ago

Can you also support something like this as a special case?

template <typename T>
StatusCode AxpyBatched(const size_t n,
                       const T alpha,
                       const cl_mem x_buffer, const size_t x_offset, const size_t x_offset_inc, const size_t x_inc,
                       cl_mem y_buffer, const size_t y_offset, const size_t y_offset_inc, const size_t y_inc,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event);

On the other hand, that could also be covered by providing special GE_Axpy, TR_Axpy, etc. methods. I am not sure which is the better approach...

CNugteren commented 7 years ago

But that would be supported by my proposal, right? You would just have to create an array with [x_offset, x_offset + x_offset_inc, x_offset + 2 * x_offset_inc, x_offset + 3 * x_offset_inc, ...]. I'll leave that exercise to the user; the API shouldn't grow unnecessarily big :-)
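Building such an offset array on the host is a short loop, sketched below with a hypothetical `MakeOffsets` helper. Its length grows linearly with `batch_count`, and it has to be rebuilt (and transferred to the device) for every call.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: offsets of batch_count regularly spaced sub-arrays,
// i.e. [first, first + inc, first + 2*inc, ...].
std::vector<std::size_t> MakeOffsets(std::size_t first_offset, std::size_t offset_inc,
                                     std::size_t batch_count) {
  std::vector<std::size_t> offsets(batch_count);
  for (std::size_t b = 0; b < batch_count; ++b) {
    offsets[b] = first_offset + b * offset_inc;
  }
  return offsets;
}
```

For example, `MakeOffsets(5, 100, 4)` yields `{5, 105, 205, 305}`.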

blueberry commented 7 years ago

That is supported, and I am not complaining about that.

I am concerned that that array could be several thousands of elements long, and it would have to be created for each method call. That could be a significant overhead, both in array creation and in data transfer to the kernels. I guess that's in microseconds, but I could be wrong.

What I'm wondering is whether there should be a method version that would just get 2 numbers, and then compute j * x_offset_inc inside the kernel itself?
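The alternative being asked about can be sketched as follows: instead of reading a start position from an offsets array, each batch b computes it as `x_offset + b * x_offset_inc`. This CPU version (hypothetical name `AxpyStridedBatchedReference`, scalar alpha as in the special-case signature proposed above, single precision) only illustrates the indexing; no offset array is ever built or transferred.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical strided variant: batch b starts at x_offset + b * x_offset_inc
// (and likewise for y), so start positions are computed on the fly instead of
// being read from an offsets array. One alpha is shared by all batches.
void AxpyStridedBatchedReference(std::size_t n, float alpha,
                                 const std::vector<float>& x, std::size_t x_offset,
                                 std::size_t x_offset_inc, std::size_t x_inc,
                                 std::vector<float>& y, std::size_t y_offset,
                                 std::size_t y_offset_inc, std::size_t y_inc,
                                 std::size_t batch_count) {
  for (std::size_t b = 0; b < batch_count; ++b) {
    const std::size_t x_start = x_offset + b * x_offset_inc;
    const std::size_t y_start = y_offset + b * y_offset_inc;
    for (std::size_t i = 0; i < n; ++i) {
      y[y_start + i * y_inc] += alpha * x[x_start + i * x_inc];
    }
  }
}
```

The trade-off is flexibility: the strided form only covers regularly spaced sub-arrays, while an offsets array can describe arbitrary start positions.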

blueberry commented 7 years ago

Just a short note that Intel MKL's extensions look like a good source for finding out what is needed and useful in practice regarding those matrix BLAS-1 operations (I don't know about the batched ones):

https://software.intel.com/en-us/node/520857

CNugteren commented 7 years ago

A first naive (but fast!) version of batched AXPY is implemented. Here is an example benchmark:

./clblast_client_xaxpybatched -n 8192 -num_steps 16 -step 8192 -batch_num 200
                                          | <--  CLBlast  --> | <--  clBLAS    --> |
        n;     incx;     incy;batch_num;      ms_1;    GBs_1;     ms_2;    GBs_2;   
       8K;        1;        1;      200;      0.81;     24.2;    15.62;      1.3;   
      16K;        1;        1;      200;      1.30;     30.2;    16.69;      2.4;   
      24K;        1;        1;      200;      2.25;     26.2;    16.43;      3.6;   
      32K;        1;        1;      200;      2.73;     28.8;    17.22;      4.6;   
      40K;        1;        1;      200;      3.25;     30.2;    16.97;      5.8;   
      48K;        1;        1;      200;      2.98;     39.6;    17.77;      6.6;   
      56K;        1;        1;      200;      4.04;     34.0;    17.94;      7.7;   
      64K;        1;        1;      200;      4.33;     36.3;    18.78;      8.4;   
      72K;        1;        1;      200;      4.64;     38.1;    19.20;      9.2;   
      80K;        1;        1;      200;      5.22;     37.7;    18.87;     10.4;   
      88K;        1;        1;      200;      5.12;     42.2;    18.70;     11.6;   
      96K;        1;        1;      200;      5.94;     39.7;    18.68;     12.6;   
     104K;        1;        1;      200;      5.89;     43.4;    19.96;     12.8;   
     112K;        1;        1;      200;      6.56;     42.0;    19.87;     13.9;   
     120K;        1;        1;      200;      6.81;     43.3;    20.10;     14.7;   
     128K;        1;        1;      200;      7.43;     42.4;    20.60;     15.3;   

As you can see, a non-batched routine (like CLBlast's AXPY or clBLAS in this case) needs very large values of n to reach a decent memory bandwidth, whereas the new routine already reaches 24 GB/s in the first experiment.
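The GBs columns appear consistent with counting 3 * n single-precision values per batch (read x, read y, write y) and 1 GB = 1e9 bytes. That is inferred from the numbers in the table rather than stated in the thread; the hypothetical helper below reproduces them under that assumption.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Effective bandwidth in GB/s (1e9 bytes/s) of a batched AXPY, assuming each
// batch reads x and y and writes y once: 3 * n * elem_size bytes per batch.
double AxpyBatchedGBs(std::size_t n, std::size_t batch_count, double ms,
                      std::size_t elem_size = sizeof(float)) {
  const double bytes = 3.0 * static_cast<double>(n) * elem_size * batch_count;
  return bytes / (ms * 1.0e6);  // bytes / (ms * 1e-3 s) / 1e9 bytes per GB
}
```

For example, n = 8K with 200 batches in 0.81 ms gives about 24.27 GB/s, in line with the 24.2 reported in the first row.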

Quoting @blueberry: "I am concerned that that array could be several thousands of elements long, and it would have to be created for each method call. That could be a significant overhead, both in array creation and in data transfer to the kernels. I guess that's in microseconds, but I could be wrong. What I'm wondering is whether there should be a method version that would just get 2 numbers, and then compute j * x_offset_inc inside the kernel itself?"

OK, I understand your concern, but I'm not sure it is a practical issue. Yes, we'll have to upload the whole array of offsets to the device, so that might be costly. But you'll only need 2 offsets per batch, which is relatively small compared to the 3 * n values an AXPY reads and writes, unless you go to really small values of n. Another costly thing might be the global memory accesses in the kernel, but hopefully the offsets can remain in some sort of local constant memory on the device or in a cache. How efficient that is will depend on the device. In any case, we'll also have to do the same for the array of alpha values...
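A back-of-the-envelope version of this argument, assuming 8-byte offsets and single-precision data (both assumptions): per batch, the offset and alpha arrays add 2 offsets plus 1 alpha, while the AXPY itself moves 3 * n values. The ratio does not depend on `batch_count`; `MetadataOverhead` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>

// Ratio of per-batch metadata (x offset, y offset, alpha) to the data an AXPY
// of length n moves (read x, read y, write y). Independent of batch_count.
double MetadataOverhead(std::size_t n, std::size_t offset_bytes = 8,
                        std::size_t scalar_bytes = sizeof(float)) {
  assert(n > 0);
  const double metadata = 2.0 * offset_bytes + scalar_bytes;
  const double data = 3.0 * static_cast<double>(n) * scalar_bytes;
  return metadata / data;
}
```

For n = 8192 this is about 0.02%, while for n = 64 it grows to about 2.6%, which is the "really small values of n" regime mentioned above.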

So I guess for now I'll leave it like this. Feel free to already run some benchmarks (like the one I did above) and see if you can come up with realistic use-cases for which performance is sub-optimal. But remember that the interface you suggest really means a completely different kernel, so it would make the library significantly bigger.

blueberry commented 7 years ago

Thank you very much, Cedric!

In the meantime, I took a more detailed look at LAPACK, and I "discovered" something unexpected: many of the Blas 1-like routines for matrices that I found missing from Blas, are really present as auxiliary Lapack routines!

Here are some examples:

Now, the question is whether you are interested in including this part of LAPACK in CLBlast. On the one hand, it is strictly speaking not part of BLAS levels 1-3; on the other, these routines are obviously missing from BLAS and fit well into everyday BLAS work.

In case you decide to support some of these, I can help with discovery and testing, since I am currently integrating the CPU versions into my Clojure library. Also, if you decide they are a good fit for CLBlast, the good thing is that you do not have to invent the interface, but can use LAPACK's.

EDIT: maybe it is better to move this part of the discussion to a new issue?

CNugteren commented 7 years ago

I agree. Could you move your message to the #136 issue perhaps? We could change that subject if needed to broaden it a bit.

blueberry commented 7 years ago

@CNugteren I copy/pasted the message, as I couldn't find a way to move it. I also edited the title. Thanks for considering that.

blueberry commented 7 years ago

@CNugteren I also stumbled into the new (as of CUDA 8.0) gemmStridedBatched, which looks similar to my earlier proposal about computing batched offsets instead of constructing and transferring the whole array. There is an article on Nvidia's developer website, with rationale, measurements etc. https://devblogs.nvidia.com/parallelforall/cublas-strided-batched-matrix-multiply/

CNugteren commented 7 years ago

@blueberry Thanks for the link. I will consider implementing such a batched routine in CLBlast as well. It shouldn't take too much effort, as it will be quite similar to what is already there.

For completeness: CLBlast now supports both a batched GEMM and a batched AXPY routine. Here are some example benchmarks from GEMM on an NVIDIA GPU.

blueberry commented 7 years ago

Thanks! It would be great to have it, since they achieved huge speedup over "ordinary" batched GEMM.

CNugteren commented 6 years ago

It's been a while, but there have recently been more requests for a batched GEMM without the overheads of creating the whole offset array. I started working on a strided-batched GEMM, like in cuBLAS. This is the interface I envisage. Any comments?

template <typename T>
StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                              const size_t m, const size_t n, const size_t k,
                              const T alpha,
                              const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
                              const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
                              const T beta,
                              cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
                              const size_t batch_count,
                              cl_command_queue* queue, cl_event* event);

Note that the offsets array is gone, and so are the alpha/beta arrays: they are scalars again. The added argument is a fixed stride for each matrix, determining where the next batch starts.
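As a CPU reference under simplifying assumptions (column-major layout, no transposes, single precision, a `std::vector` in place of each `cl_mem`), the indexing implied by this interface is that batch b of A starts at `a_offset + b * a_stride`, and likewise for B and C, with one alpha and beta shared by all batches. `GemmStridedBatchedReference` is a hypothetical name for this sketch, not CLBlast's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: for every batch b,
//   C_b = alpha * A_b * B_b + beta * C_b,
// where A_b starts at a_offset + b * a_stride (same pattern for B and C).
void GemmStridedBatchedReference(std::size_t m, std::size_t n, std::size_t k, float alpha,
                                 const std::vector<float>& a, std::size_t a_offset,
                                 std::size_t a_ld, std::size_t a_stride,
                                 const std::vector<float>& b, std::size_t b_offset,
                                 std::size_t b_ld, std::size_t b_stride,
                                 float beta,
                                 std::vector<float>& c, std::size_t c_offset,
                                 std::size_t c_ld, std::size_t c_stride,
                                 std::size_t batch_count) {
  assert(a_ld >= m && b_ld >= k && c_ld >= m);
  for (std::size_t batch = 0; batch < batch_count; ++batch) {
    const std::size_t a0 = a_offset + batch * a_stride;
    const std::size_t b0 = b_offset + batch * b_stride;
    const std::size_t c0 = c_offset + batch * c_stride;
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t i = 0; i < m; ++i) {
        float acc = 0.0f;
        for (std::size_t p = 0; p < k; ++p) {
          acc += a[a0 + i + p * a_ld] * b[b0 + p + j * b_ld];
        }
        float& cij = c[c0 + i + j * c_ld];
        cij = alpha * acc + beta * cij;
      }
    }
  }
}
```

Because every start position follows from the stride, the host passes only a handful of scalars per call, regardless of `batch_count`.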