NervanaSystems / neon

Intel® Nervana™ reference deep learning framework committed to best performance on all hardware
http://neon.nervanasys.com/docs/latest
Apache License 2.0

'magic' is not very descriptive :-D #258

Closed hughperkins closed 8 years ago

hughperkins commented 8 years ago

You have this cool function for dividing by integers using a bitshift, after first multiplying by another number, so you're not limited to dividing by powers of 2, as described in https://gmplib.org/~tege/divcnst-pldi94.pdf

At the moment, this function is called 'magic', but I'm not sure it's very descriptive? I've renamed it to get_div_mul_shift in my own branch: https://github.com/hughperkins/winogradCl/blob/api/winogradcl/util/math_helper.py#L33

def get_div_mul_shift_32(nmax, d)

scott-gray commented 8 years ago

The magic number code comes from here:

http://www.hackersdelight.org/hdcodetxt/magicgu.py.txt

I wasn't too creative about renaming it.
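For readers who haven't seen the trick, here is a minimal sketch of the multiply-and-shift computation, modelled on the structure of the Hacker's Delight `magicgu` routine linked above. The names `find_mul_shift` and `check`, and the early return for `nmax < d`, are illustrative assumptions and not neon's actual code.

```python
def find_mul_shift(nmax, d):
    """Return (m, p) such that (n * m) >> p == n // d for all 0 <= n <= nmax.

    Illustrative sketch in the style of Hacker's Delight 'magicgu';
    the function name is hypothetical, not neon's API.
    """
    if d < 1 or nmax < 1:
        raise ValueError("need d >= 1 and nmax >= 1")
    if nmax < d:
        # Every quotient is zero, so multiplying by zero is enough.
        return 0, 0
    # Largest n <= nmax whose remainder mod d is d - 1 (the worst case).
    nc = ((nmax + 1) // d) * d - 1
    nbits = nmax.bit_length()
    for p in range(2 * nbits + 1):
        two_p = 1 << p
        if two_p > nc * (d - 1 - (two_p - 1) % d):
            # Smallest multiplier m with m * d >= 2**p.
            m = (two_p + d - 1 - (two_p - 1) % d) // d
            return m, p
    raise ArithmeticError("no shift found for these inputs")


def check(nmax, d):
    """Brute-force check that multiply+shift matches integer division."""
    m, p = find_mul_shift(nmax, d)
    return all((n * m) >> p == n // d for n in range(nmax + 1))
```

On the GPU side the point is that `(n * m) >> p` replaces an integer divide by a value that is only known at launch time, with `m` and `p` computed once on the host.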

Oh, and you know the kepler cuda code is just direct conv right? It's not clear why you named your repo winogradcl.. but I guess you could be intending to add implementations of that?

hughperkins commented 8 years ago

> Oh, and you know the kepler cuda code is just direct conv right? It's not clear why you named your repo winogradcl.. but I guess you could be intending to add implementations of that?

I didn't know that. That's new information :D That's good actually: it means I should be able to make it a ton faster, without needing to use anything other than OpenCL?

hughperkins commented 8 years ago

(also good because then I can do something more creative than simply search and replace threadIdx.x with get_local_id(0) :-D Though obviously incredibly embarrassing that I didn't actually spot this earlier :-D )

hughperkins commented 8 years ago

I've renamed it to neonCl for now :-D I think I might as well get the direct ones working, since they're already super fast.... and getting them working will help me learn kind of how your mind works and stuff (eg will know what 'magic' does), which will help with dealing with winograd somehow, however that will be.

scott-gray commented 8 years ago

Actually, Stewart wrote those kernels. And again, they're meant to be run on sm_30 devices (Amazon Cloud mainly). So they're not nearly as fast as they could be if designed for >=sm_35 (lots more registers and shared memory).

But for winograd, one approach is just write external transforms and use a BLAS lib for the batched gemm. Some of the external transforms are already done for you in the winograd_conv.py file. You would just need the output transforms. The 4x4 transform will work much better externally because it only expands the input/output/delta data by 2.25x.
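One way to read the 2.25x figure, assuming 3x3 filters: F(4x4, 3x3) turns each 4x4 output tile into a 6x6 transformed tile, so the data written by an external transform grows by 36/16. The helper name below is hypothetical and is only there to show the arithmetic.

```python
def expansion_factor(out_tile, kernel=3):
    """Elements per transformed tile divided by elements per output tile.

    Assumes square tiles and a square kernel, i.e. F(out_tile x out_tile,
    kernel x kernel); the transformed tile side is out_tile + kernel - 1.
    """
    in_tile = out_tile + kernel - 1
    return (in_tile * in_tile) / (out_tile * out_tile)


print(expansion_factor(4))  # F(4x4, 3x3): 36 / 16 = 2.25
print(expansion_factor(2))  # F(2x2, 3x3): 16 / 4  = 4.0
```

Under this reading, the smaller 2x2 transform would expand the data by 4x, which is why the larger tile is the better candidate for external transforms.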

hughperkins commented 8 years ago

> And again, they're meant to be run on sm_30 devices (Amazon Cloud mainly). So they're not nearly as fast as they could be if designed for >=sm_35 (lots more registers and shared memory).

Ok. I'm in two minds about this. On the one hand, my current semi-stated objective is to get OpenCL convolution to run as fast as possible on Titan X / NVIDIA devices. So, at least no-one can say 'well opencl is slow, and nvidia nerfed it'. And then it's down to AMD or whoever to do whatever they need to do to make them run fast on their own devices.

On the other hand, realistically no-one is going to use OpenCL on NVIDIA devices. AMD and Intel device users are the only realistic clients, and I heard AMD devices are not overwhelmingly endowed with registers and shared memory (hearsay, not having an AMD device myself...), so in this sense the sm_30 kernels sound not unreasonable.

> But for winograd, one approach is just write external transforms and use a BLAS lib for the batched gemm. Some of the external transforms are already done for you in the winograd_conv.py file. You would just need the output transforms. The 4x4 transform will work much better externally because it only expands the input/output/delta data by 2.25x.

Ok, sounds good. I will probably dabble in using Cedric's CLBlast, as the underlying BLAS implementation, and see how that goes.

hughperkins commented 8 years ago

Hmmm, your idea seems awesome. Reading...

hughperkins commented 8 years ago

Don't think this should be 'open', so closing it. Then will continue the conversation below.

hughperkins commented 8 years ago

Hi Scott,

Started looking at this approach, and got as far as calculating U and V using the existing CUDA kernels, ported to OpenCL (https://github.com/hughperkins/neonCl-underconstruction/blob/play-winograd/winograd_cl.py). Now pondering how to get M. It looks like the output of the U and V transforms is:

    # U                           Co // 32,       Ci,    6,   6, Co % 32
                         # bytes:           eg 150KB, 4.6K, 768,     128
    # V            # tiles, tiles, N // 32,       Ci,    6,   6,  N % 32
            # bytes                         eg 150KB, 4.6K, 768,     128

Seems like these are already split into conveniently sized blocks of [32, 6, 6, 32], for each of [Ci % 32, 6, 6, Co % 32] and [Ci % 32, 6, 6, N % 32]. Seems like I could directly utilize this blocking: write a kernel to pull down each of these blocks, loop over Ci, and then at the end add the results of each block together? I.e., instead of calling out to a third-party GEMM implementation? The blocks look too big to fit into __shared__ memory, but it seems like you are predominantly using L2 cache to handle this, which is I think 3MB on a Titan X, and so should be largely sufficient?
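The byte counts in the layout comment above can be reproduced with a few lines, assuming float32 data and Ci = 32 (both assumptions, chosen because they match the "eg 150KB" figure). The helper name is illustrative, and the 48 KB per-block shared memory limit used in the last line is likewise an assumption about the target device.

```python
def cumulative_bytes(shape, itemsize=4):
    """Bytes spanned by each trailing sub-block of a C-contiguous array.

    Entry i is the size of the block formed by dimensions i..end.
    itemsize=4 assumes float32 data.
    """
    sizes = []
    running = itemsize
    for dim in reversed(shape):
        running *= dim
        sizes.append(running)
    return list(reversed(sizes))


# One superblock of U: [Ci, 6, 6, Co % 32], with Ci = 32 assumed.
sizes = cumulative_bytes((32, 6, 6, 32))
print(sizes)                   # [147456, 4608, 768, 128]
print(sizes[0] > 48 * 1024)    # True: larger than an assumed 48 KB per-block limit
```

So a full [32, 6, 6, 32] block is roughly 147 KB, which is consistent with the guess that it cannot sit in shared memory as a whole but fits comfortably in a 3MB L2.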

hughperkins commented 8 years ago

(Ah, I guess I should have two levels of blocking? Hence the name 'superblock'? But it seems like the effort involved in writing sub-blocking within such regularly-sized superblocks should be relatively low compared to the overkill of somehow plugging in a third-party gemm??? )

hughperkins commented 8 years ago

But ... I can't quite figure out why Ci, ie your C, is not in the innermost dimension? Since we are reducing over this, we'd want it to be contiguous?

(Edit: oh, unless it's something like:

? )

(Edit2: hmmm, maybe xi and nu should be the outermost dimensions? Seems like all the values are conditionally independent on xi and nu, and no values are shared across different xi/nu pairs?)
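To make the Edit2 idea concrete: in the usual Winograd formulation the transformed-domain product reduces over Ci separately for each tile position (xi, nu), so with xi and nu outermost the computation of M becomes 36 independent small matrix products. The sketch below uses plain nested lists and a hypothetical [xi][nu][k][c] layout purely for illustration; it is not neon's layout or kernel.

```python
def batched_gemm(U, V):
    """Compute M[xi][nu][k][n] = sum over c of U[xi][nu][k][c] * V[xi][nu][c][n].

    U and V are nested lists indexed [xi][nu][row][col]; each (xi, nu)
    pair is handled independently, with no data shared between pairs.
    """
    M = []
    for U_row, V_row in zip(U, V):
        M_row = []
        for u, v in zip(U_row, V_row):          # one (xi, nu) pair
            K, C, N = len(u), len(v), len(v[0])
            M_row.append([[sum(u[k][c] * v[c][n] for c in range(C))
                           for n in range(N)]
                          for k in range(K)])
        M.append(M_row)
    return M


# Tiny example: 6x6 tile positions, K = C = N = 2, same matrices everywhere.
U = [[[[1, 2], [3, 4]] for _ in range(6)] for _ in range(6)]
V = [[[[5, 6], [7, 8]] for _ in range(6)] for _ in range(6)]
M = batched_gemm(U, V)
print(M[0][0])  # [[19, 22], [43, 50]]
```

Each inner product here is exactly what a batched GEMM call would do per batch entry, which is why the external-transform approach maps onto a BLAS library so directly.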