soumith / convnet-benchmarks

Easy benchmarking of all publicly accessible implementations of convnets
MIT License

Benchmark TensorFlow #66

Closed · soumith closed this issue 8 years ago

soumith commented 8 years ago

Google's TensorFlow benchmarks are here!

I've run the benchmarks on the ImageNet winners. When I saw issues with the numbers, memory usage, etc., I emailed @Yangqing to confirm that what I'm seeing is expected.

With that disclaimer out of the way, here are some things that you should know about TensorFlow (as of the pip version that I installed today):

Coming to the benchmarks:

AlexNet (One Weird Trick paper) - Input 128x3x224x224

| Library | Total (ms) | forward (ms) | backward (ms) |
| --- | --- | --- | --- |
| CuDNN-R3 (Torch) | 96 | 32 | 64 |
| Nervana (Neon) | 101 | 32 | 69 |
| CuDNN-R2 (Torch) | 231 | 70 | 161 |
| TensorFlow | 326 | 96 | 230 |

Overfeat [fast] - Input 128x3x231x231

| Library | Total (ms) | forward (ms) | backward (ms) |
| --- | --- | --- | --- |
| CuDNN-R3 (Torch) | 326 | 113 | 213 |
| fbfft (Torch) | 342 | 114 | 227 |
| CuDNN-R2 (Torch) | 810 | 234 | 576 |
| TensorFlow | 1084 | 316 | 768 |

OxfordNet [Model-A] - Input 64x3x224x224

| Library | Total (ms) | forward (ms) | backward (ms) |
| --- | --- | --- | --- |
| Nervana | 590 | 180 | 410 |
| CuDNN-R3 (Torch) | 615 | 196 | 418 |
| CuDNN-R2 (Torch) | 1099 | 342 | 757 |
| TensorFlow | 1840 | 545 | 1295 |

GoogleNet V1 - Input 16x3x224x224

| Library | Total (ms) | forward (ms) | backward (ms) |
| --- | --- | --- | --- |
| CuDNN-R2 (Torch) | 564 | 174 | 390 |
| TensorFlow | 590 | 54 | 536 |

Note that at a batch size of 16, GoogLeNet with CuDNN-R2 + Torch likely runs into dispatching overhead, so it's an exotic comparison that isn't practically very interesting or encouraging.

There you go.

I'm assuming that the first release of TensorFlow is still quite unpolished, and that they will improve it over time with various memory and time optimizations baked in.

soumith commented 8 years ago

The benchmark scripts and raw outputs are located here: https://github.com/soumith/convnet-benchmarks/tree/master/tensorflow

scott-gray commented 8 years ago

The lack of in-place operations is rather surprising. Once you have the full DAG, it should be rather easy to apply a liveness algorithm to it to optimize tensor allocations. For an example, see this: http://www.diku.dk/hjemmesider/ansatte/torbenm/ICD/Register.pdf (just replace register with tensor).
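
To make the idea concrete, here is a minimal sketch (hypothetical graph representation, equal-sized tensors assumed for simplicity) of liveness-based buffer reuse over a topologically ordered DAG:

```python
def plan_buffers(ops):
    """ops: list of (output_name, input_names) pairs in topological order."""
    # Last point in the schedule where each tensor is read.
    last_use = {}
    for t, (out, ins) in enumerate(ops):
        for name in ins:
            last_use[name] = t

    free, assignment, n_buffers = [], {}, 0
    for t, (out, ins) in enumerate(ops):
        # Give the output a recycled buffer if one is free, else a new one.
        if free:
            assignment[out] = free.pop()
        else:
            assignment[out] = n_buffers
            n_buffers += 1
        # Any input whose last use is this op is now dead; recycle its buffer.
        for name in ins:
            if last_use.get(name) == t and name in assignment:
                free.append(assignment[name])
    return assignment, n_buffers  # tensor name -> buffer id, buffers needed
```

For a simple chain where b, c, and d are each produced from the previous tensor, `plan_buffers([("b", ["a"]), ("c", ["b"]), ("d", ["c"])])` gives d the buffer previously used by b, so three intermediates fit in two allocations.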

I'm kind of curious whether there's any support for automatically compounding operations together, or for leveraging kernels that have some compounding built in (like the alpha/beta params of gemm). I'm pretty close to maximizing the amount of compounding that's possible in my benchmark networks. And because I write all my own kernels, I can further compound things that aren't possible with closed-source libraries like cuDNN. For example, I'm now able to compute the mean along the PQN dimension inside the conv and gemm kernels at no cost. This cuts the bandwidth required by batch norm in fprop by a third.
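
As a simplified illustration of the gemm alpha/beta style of compounding mentioned above, here is a numpy stand-in; the function name and shapes are made up for the example:

```python
import numpy as np

def gemm(A, B, C, alpha=1.0, beta=0.0):
    """BLAS-style compound update: C = alpha * (A @ B) + beta * C, in place."""
    C[...] = alpha * (A @ B) + beta * C
    return C

# Folding an accumulation into the matrix multiply itself, e.g. adding a
# weight-gradient contribution into an existing gradient buffer, avoids a
# separate elementwise add kernel and the extra pass over memory it costs.
dW = np.zeros((256, 512), dtype=np.float32)
dY = np.random.rand(256, 128).astype(np.float32)
X = np.random.rand(128, 512).astype(np.float32)
gemm(dY, X, dW, alpha=1.0, beta=1.0)   # dW += dY @ X
```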

Though on the whole I think TensorFlow seems like a great platform to build on. I'd say there's a good chance my kernels will make their way there sooner rather than later. You can find new benchmarks of my latest winograd kernels in the updated paper here: http://arxiv.org/abs/1509.09308

What I'll be working on next is basically going to be taking a lot of what I learned implementing winograd and refreshing all of my conv/pooling/gemm kernels to support very small minibatches at near full utilization. This should have a big impact on the level at which you can scale these networks and the speed at which they converge. Here's a great paper exploring this: http://arxiv.org/abs/1509.04210

yuzcccc commented 8 years ago

Hi, I strongly recommend adding mxnet (https://github.com/dmlc/mxnet) to the comparison; in my opinion it may be the fastest DL library :)

mavenlin commented 8 years ago

+1 for benchmarking mxnet, the fastest now.

strongbanker commented 8 years ago

+1 for benchmarking mxnet

fvisin commented 8 years ago

I would also love to see a comparison with Theano http://deeplearning.net/software/theano/ as it is another widely adopted deep learning library.

nkoumchatzky commented 8 years ago

Thanks for benchmarking!

aaronwro commented 8 years ago

+1 would love to see tensorflow benchmarked against mxnet, Theano, Autograd for Torch, and Caffe.

vincentvanhoucke commented 8 years ago

Thanks @soumith! Yes, our only launch criterion for convnets was 'GoogLeNet within distance from CuDNN[R2]', and we've punted on a lot of performance work, including upgrading CuDNN, until after the initial release. Expect a lot of movement on that front in the coming weeks.

soumith commented 8 years ago

@aaronwro @fvisin it's already benchmarked against Torch, Theano, Caffe. Look at the readme on the main page ( https://github.com/soumith/convnet-benchmarks/blob/master/README.md ). I definitely need to pull my socks up and benchmark MXNet and Chainer.

@vincentvanhoucke thanks for your response. I assumed that you'd fix it over the coming weeks / months :)

vincentvanhoucke commented 8 years ago

@scott-gray let us know if you need help with compounding / graph rewriting. The graph representation is meant to make these kinds of operations possible, and the common subexpression elimination that TF currently uses is also meant as a demonstration of that. I suspect we might need to do a bit more to provide good APIs to make it easier to bake in compound kernels.

soumith commented 8 years ago

There seems to be some misinterpretation by random people on social media that, because I work for Facebook, I'm attacking TensorFlow. That seems super weird, because I love the vision of TensorFlow, and there's no competition (one can write a XXX frontend for a TensorFlow backend).

My benchmarks have always been independently run and completely neutral. I've been running them forever now; it's sad that people misinterpret the slightest of things. cc: @vincentvanhoucke

clementfarabet commented 8 years ago

I will defend Soumith on this one – he has indeed been running these benchmarks for quite some time, and with complete neutrality.

fvisin commented 8 years ago

@soumith Excellent, thank you!!

vincentvanhoucke commented 8 years ago

@soumith no good deed goes unpunished ;) Please don't let this deter you from providing this valuable service to the community!

Yangqing commented 8 years ago

@soumith , I am sorry that some people interpreted things that way. I've always appreciated your benchmark, which creates a great atmosphere for us to look at bottlenecks and push forward the field as a whole community. We all owe you a big debt of gratitude.

aaronwro commented 8 years ago

@soumith thanks!

jdemouth commented 8 years ago

As always, that's super interesting. Thanks for pushing all of us toward more performance.

tqchen commented 8 years ago

For memory optimization, what we have found is that in-place optimization does not matter that much if the allocator is smart enough to do a static allocation before running the graph (as opposed to relying on a dynamic allocator). We have detailed what can be done here:

https://mxnet.readthedocs.org/en/latest/developer-guide/note_memory.html

I assume this applies to computation graph frameworks such as TF, Caffe2, and CGT. @vincentvanhoucke @Yangqing

tqchen commented 8 years ago

The general idea is to share memory not only between tensors of the same shape (i.e. in-place), but also across different shapes and sizes, as in the sketch below.
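
A small sketch of that size-based sharing (interface invented for illustration): freed buffers go into a pool keyed by byte size, and a later tensor of any shape can take over the smallest block that fits.

```python
import bisect

class StaticPool:
    """Best-fit reuse of buffers by byte size, independent of tensor shape."""
    def __init__(self):
        self.free = []          # sorted list of (nbytes, buffer_id)
        self.sizes = {}         # buffer_id -> allocated bytes
        self.next_id = 0

    def alloc(self, nbytes):
        # Smallest free block that is large enough, else a brand new buffer.
        i = bisect.bisect_left(self.free, (nbytes, -1))
        if i < len(self.free):
            return self.free.pop(i)[1]
        buf = self.next_id
        self.next_id += 1
        self.sizes[buf] = nbytes
        return buf

    def release(self, buf):
        bisect.insort(self.free, (self.sizes[buf], buf))
```

Running a planner like this over the whole graph ahead of time means the per-iteration cost of a dynamic allocator disappears, which is the point made above.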

rajatmonga commented 8 years ago

@soumith Thanks for running the benchmarks! As @vincentvanhoucke noted in this thread, our goal was to get an early release out so users can start playing with it and provide feedback on what they care about. We are committed to making TensorFlow fast and are actively working on the performance issues you highlight here.

alexbw commented 8 years ago

@soumith You're doing a good deed! Haters gonna hate.

piiswrong commented 8 years ago

I'm a little confused by the numbers. 1300 samples/sec seems too fast even for AlexNet on a single Titan X. Is this standard training, i.e. io + forward + backward + update, or just forward + backward?

kyieldmark commented 8 years ago

Nice work.

antinucleon commented 8 years ago

@piiswrong I will help @soumith with the benchmark script.

Anyway, we have kept everything open since the beginning. The main purpose is to learn from each other, not to advertise boring numbers.

koraykv commented 8 years ago

I will also add my support to Soumith. He has been running these benchmarks for some time with complete transparency and neutrality.

sermanet commented 8 years ago

@koraykv +1, thanks Soumith!

soumith commented 8 years ago

Someone on reddit suggested that I build TensorFlow from source to fix the speed issues. That did not help; it produces the same numbers as the pip version on my AlexNet script:

https://gist.github.com/soumith/11acc2f0dbc5212ea372

soumith commented 8 years ago

FWIW, Yangqing's fix to avoid CPU-GPU transfers improved results across the board by ~20%. (I've updated the tables above). The memory issues are unchanged.

XericZephyr commented 8 years ago

+1 for mxnet! Thanks.

yeqinglee commented 8 years ago

+1 for mxnet.

gujunli commented 8 years ago

@soumith I have a naive question: is TensorFlow's result based on its own C++ code or on cuDNN v2? I would guess that if you run on a Titan X, TensorFlow relies on some CUDA library?

soumith commented 8 years ago

@gujunli it's based on CuDNN V2.

mattjj commented 8 years ago

@soumith thanks for running and maintaining these benchmarks; they're always thorough and informative!

gujunli commented 8 years ago

@soumith Then I don't understand why TensorFlow with cuDNN v2 ends up being so slow. Can you share some of your understanding? I'd guess TF still calls cuDNN v2 for the conv/pool/relu/FC layers. Remember from your earlier AlexNet results: cuDNN v2 is 231 ms (70 forward + 161 backward) and Caffe's native ConvolutionLayer is 324 ms (121 + 203), yet TensorFlow is 326 ms (96 + 230).

scott-gray commented 8 years ago

Running the network under nvvp (the NVIDIA Visual Profiler) should be pretty informative. A well-tuned network's timeline should just be a solid block of kernel calls with no gaps.

gujunli commented 8 years ago

@scott-gray So you think TF's scheduling may not be efficient? I need to read the TF whitepaper to understand how it works. Does anyone understand it?

scott-gray commented 8 years ago

@gujunli I'm just saying that if they're using stock cuDNN v2, the only reason it would be slower is gaps in the timeline. Seeing where those gaps occur, and any extra host/device memcpy traffic, would give you a clearer picture of what's going wrong.

Andy-P commented 8 years ago

@soumith Thanks for this and all the other benchmarks you have taken the time to create.

+1 for MxNet

shengwa commented 8 years ago

+1 for mxnet! Thank you so much!!!

Yangqing commented 8 years ago

@gujunli @scott-gray To provide some historical perspective: this is mostly due to legacy choices. Historically, Google Brain has been using the NHWC storage order and a slightly different padding scheme ("SAME"/"VALID" instead of an explicit padding number), while cuDNN, like Caffe, uses NCHW order. Note that cuDNN supports NHWC interface-wise, but some underlying paths, such as NHWC convolution backward, are not implemented.

As a result, when calling cuDNN, there is some code that generates padded and order-switched intermediate tensors. That code was written with Eigen and does not interact very well with nvcc, causing nontrivial overhead (you can observe it by running the benchmark in an nvvp session, as Scott suggested).

We have people looking into this, and the performance should be brought up to cuDNN level.
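
A rough numpy illustration of the extra work being described (not TensorFlow's actual code): converting an NHWC activation with implicit "SAME"-style padding into the explicitly padded NCHW tensor that cuDNN's implemented paths expect costs an extra transpose and pad, and the reverse conversion on the way back.

```python
import numpy as np

def nhwc_to_padded_nchw(x, pad_h, pad_w):
    """x: (N, H, W, C) -> explicitly zero-padded (N, C, H + 2*pad_h, W + 2*pad_w)."""
    x = np.transpose(x, (0, 3, 1, 2))     # order switch: NHWC -> NCHW
    return np.pad(x, ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)),
                  mode='constant')
```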

scott-gray commented 8 years ago

Gah, everyone's still using different tensor layouts. You all need to turn from the dark side and see the speed benefits of using CHWN. Though NHWC is probably better than NCHW, at least. You want that inner dimension to be a nice even number to facilitate cleaner aligned memory access, leading to less over-fetch. CHWN gets you better contiguous memory access overall. In recurrent networks with model parallelism, having N as the outer dim definitely helps, but most distributed convnets are data parallel, where it doesn't matter.

I have some very fast shared memory dimshuffle code if you want it. I use it to make this operation on the filters:

```python
import numpy as np

# Swap the first and last axes (C <=> K) and mirror the R,S axes of the filters.
F = np.transpose(F[:, ::-1, ::-1, :], (3, 1, 2, 0))
```

It turns out a kernel for fprop_conv can serve, with very little change (or no change if padding and striding are symmetric), as a kernel for bprop_conv. There's almost no overhead in the dimshuffle since the filters are so small, and you completely avoid any atomic adds.

apark263 commented 8 years ago

Krizhevsky first demonstrated the benefits of the CHWN layout in cuda-convnet. In addition to being advantageous for convolution kernels, it's very beneficial for models like GoogLeNet, where inception modules concatenate activations along the feature-map depth. Using CHWN allows you to write directly into an output buffer in the layout that the subsequent layer will consume: (C1 + C2 + C3)HWN.
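
A small numpy sketch of that point (shapes invented for the example): with C as the slowest-varying dimension, each branch's output region is a contiguous view of the concatenated buffer, so no separate concat kernel or copy is needed.

```python
import numpy as np

C1, C2, C3, H, W, N = 64, 128, 32, 28, 28, 128
out = np.empty((C1 + C2 + C3, H, W, N), dtype=np.float32)

# Each inception branch writes straight into its slice of the shared buffer.
branch_views = out[:C1], out[C1:C1 + C2], out[C1 + C2:]
```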

Yangqing commented 8 years ago

Thanks @scott-gray - having the dim shuffle kernels to improve performance will be great.

One potential issue with CHWN is that at inference time N is often small, so there are two different sets of optimizations to carry out for large N and small N. NCHW/NHWC usually makes things a bit more batch-agnostic, but that's not always true, of course.

Yangqing commented 8 years ago

@soumith Regarding the memory issue: we found that if one turns on the best-fit GPU allocator, you can run VGG at a batch size of 64. I made a quick change if you would like to build from source and try it:

```
git clone https://github.com/Yangqing/tensorflow.git
cd tensorflow
git checkout bfc
```

There will be more fixes from @vrv down the road to enable this more easily (such as at session creation time).

scott-gray commented 8 years ago

@Yangqing The shuffle code is here (note that this does not do the RS mirror operation): https://github.com/NervanaSystems/neon/blob/master/neon/backends/float_ew.py#L1481

It uses magic numbers for fast integer division. Here's the code that sets up the kernel params: https://github.com/NervanaSystems/neon/blob/master/neon/backends/layer_gpu.py#L504

The code is adapted from here (the diagram will be helpful): http://devblogs.nvidia.com/parallelforall/efficient-matrix-transpose-cuda-cc/

It's on my list to generalize it and make it available as a flexible backend operation, but I haven't gotten to it yet. Theano may also have some good dimshuffle code you can borrow.

Also, most of the code in that float_ew file is devoted to automatically generating extremely efficient compound elementwise/reduction/broadcast/transpose/take operations. It allows you to write complex numpy expressions and have them compile to a single CUDA kernel. It even does common sub-expression removal, though it sounds like you already have that. This all works off of little optrees that live in layer code, but I've been meaning to find a clean way to collect the full program DAG. It seems like you've solved that, and that's why I'm interested in TensorFlow. There's so much burden you can shift off the programmer and optimize automatically via graph traversals.

hjk41 commented 8 years ago

+1 for mxnet. Dynamic GPU memory allocation does have a big impact on performance. A simple memory allocator can significantly reduce the overhead, and a smarter allocator that reuses blocks with best-fit can eliminate it almost completely.

vrv commented 8 years ago

@soumith I just pushed https://github.com/tensorflow/tensorflow/commit/1d76583411038767f673a0c96174c80eaf9ff42f, which should allow you to use our best-fit-with-coalescing allocator via the ConfigProto.

Example usage here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/alexnet/alexnet_benchmark.py#L201

We were able to get some of the larger batch sizes working with the BFC Allocator, so probably worth a try.

(We plan to make the BFC allocator the default soon, but it's not fully ready yet.)
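
For reference, usage along the lines of the linked example looks roughly like the sketch below; the exact field name (allocator_type) is taken from that example as I recall it and may differ between versions.

```python
import tensorflow as tf

# Opt into the best-fit-with-coalescing allocator via the session config.
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
sess = tf.Session(config=config)
```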

stencilman commented 8 years ago

Thanks a lot @soumith for the numbers, super useful!

futurely commented 8 years ago

The creators of cuDNN [1] may help with the performance optimization. @BryanCatanzaro

[1] S. Chetlur, C. Woolley, P. Vandermersch, J. Cohen, J. Tran, B. Catanzaro, and E. Shelhamer. cuDNN: Efficient Primitives for Deep Learning. arXiv preprint arXiv:1410.0759, 2014.