kblomdahl closed this issue 6 years ago.
Looking at the compute capabilities table, I suggest the following code paths:

- f16 (with tensor-core support for 7.0)
- q8
- f32
This seems to negate a large part of the performance advantage that the p3.2xlarge has shown, since tensor cores have only shown about a 3.44x speed-up over the GTX 1080 Ti (for batch size 64; it might be larger for batch size 128).
Some questions remain on how to implement this:

- … q8's initially and between layers, or do we use the same scale from start to end? If we use dynamic ranges, how do we compute and store them?
- … f32, or can we cast the weights? (For f16 we had to keep the weights as f32.)

Some articles that may have suggestions:

- Fixed Point Quantization of Deep Convolutional Networks
- 8-bit Inference with TensorRT
Some notes and observations after reading the cited q8 article and some TensorFlow code that deals with this:

- … alpha, beta, and alpha2 parameters to convert between different dynamic ranges for the input, weights, and output.
- … 0.011 +/- 0.002 and a maximum standard deviation of 0.19788 +/- 0.034.
- … NHWC data format for q8 inference. This is different from what we are using otherwise, and I am unsure how we want to transform the input tensor.
- … cublasGemmEx we can almost do it.

After doing some profiling with nvprof we can observe that the convolutions take up about 99% of the CUDA runtime. This means that we could just convert the tensors to f32 at the last convolution step and perform the rest in single precision.
Alternatively, we need to find an efficient way to convert the q32 output from cublasGemmEx back into q8, since there is no way to get q8 output directly from cublasGemmEx.
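A minimal host-side sketch of that conversion, assuming symmetric per-tensor scales where a real value v is stored as round(127 · v / s). The function name `requantize` and its scale parameters are hypothetical; on the GPU this would be a small element-wise kernel run after cublasGemmEx, but the arithmetic would be the same.

```rust
/// Requantize i32 GEMM accumulators to i8, assuming symmetric per-tensor scales
/// where a real value `v` is stored as `round(127 * v / s)` (hypothetical helper).
fn requantize(acc: &[i32], scale_w: f32, scale_x: f32, scale_y: f32) -> Vec<i8> {
    // acc = sum(127 w / sW * 127 x / sx) = 127^2 * sum(w x) / (sW sx), so multiplying
    // by (sW sx) / (127 sy) gives 127 * sum(w x) / sy, i.e. the output in q8 units.
    let multiplier = (scale_w * scale_x) / (127.0 * scale_y);

    acc.iter()
        .map(|&a| {
            // f32::round rounds half away from zero; then saturate to [-127, +127]
            (a as f32 * multiplier).round().clamp(-127.0, 127.0) as i8
        })
        .collect()
}
```

Saturating instead of wrapping matters here: an accumulator that overflows the i8 range should clip to ±127 rather than flip sign.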
There is an undocumented function cublasUint8gemmBias that does this, but it uses u8 instead of i8. See this blog post for some unofficial documentation.
Complications have arisen. From the cuDNN documentation on using the q8 data type in cross-correlations:
‣ Input and output features maps must be multiple of 4
This is an issue because our input tensor has 34 features, which is not a multiple of 4. We could remove, or add, two features to fulfill this criterion. There are padding solutions, but they feel slow since we need to pad each element in the batch.
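A host-side sketch of the padding option, assuming an NCHW batch of `n` elements with `hw` values per channel (361 for a 19x19 board). The helper names are hypothetical. Each batch element is copied once into a buffer whose channel count is rounded up to a multiple of 4; the extra channels are zero, so, assuming the matching filter channels are zero as well, they do not change the convolution result.

```rust
/// Round a channel count up to the next multiple of 4 (e.g. 34 -> 36).
fn padded_channels(c: usize) -> usize {
    (c + 3) / 4 * 4
}

/// Zero-pad each element of an NCHW batch from `c` to `padded_channels(c)` channels.
/// `hw` is the number of values per channel (h * w).
fn pad_channels_nchw(input: &[f32], n: usize, c: usize, hw: usize) -> Vec<f32> {
    assert_eq!(input.len(), n * c * hw, "input does not match n * c * hw");
    let c_pad = padded_channels(c);
    let mut output = vec![0.0f32; n * c_pad * hw];

    for b in 0..n {
        // In NCHW the c * hw values of one batch element are contiguous, so each
        // element is a single copy; the trailing (c_pad - c) channels stay zero.
        let src = &input[b * c * hw..(b + 1) * c * hw];
        let dst = b * c_pad * hw;
        output[dst..dst + c * hw].copy_from_slice(src);
    }

    output
}
```

The per-element copy is the cost mentioned above: it touches every batch element, which is why removing two features may be the cheaper option.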
Alternatively, we could just not do the first layer in q8 and do it in f32 instead, but then we will have issues converting from f32 to q8 on the GPU.
There are also issues with the tensor format, since we store everything in NCHW because it is the recommended (and fastest) format for f32 and f16 code. However, for q8 we need to store everything in NHWC. This is not much of a problem for the input features, but it is for the weights, because we do not know their dimensions in the loader.
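A sketch of the NCHW to NHWC transpose, showing why the loader needs the dimensions: every destination index depends on n, c, h, and w. The function is hypothetical; for filters stored as [k, c, h, w], the same loop with n = k produces [k, h, w, c].

```rust
/// Transpose a tensor from NCHW to NHWC (hypothetical helper).
fn nchw_to_nhwc(input: &[f32], n: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    assert_eq!(input.len(), n * c * h * w, "input does not match n * c * h * w");
    let mut output = vec![0.0f32; input.len()];

    for b in 0..n {
        for ch in 0..c {
            for y in 0..h {
                for x in 0..w {
                    let src = ((b * c + ch) * h + y) * w + x; // NCHW offset
                    let dst = ((b * h + y) * w + x) * c + ch; // NHWC offset
                    output[dst] = input[src];
                }
            }
        }
    }

    output
}
```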
A mock-up implementation gives the following benchmark results. I had to rip out some of the features to accomplish it, and decided on the two territory features (leaving 32 input features, a multiple of 4), since I have long disliked them for giving the network the wrong idea anyway.
test batch_size_01 ... bench: 1,635,884 ns/iter (+/- 8,217)
test batch_size_02 ... bench: 1,660,229 ns/iter (+/- 10,207)
test batch_size_04 ... bench: 1,673,208 ns/iter (+/- 11,483)
test batch_size_08 ... bench: 1,708,661 ns/iter (+/- 17,706)
test batch_size_16 ... bench: 2,522,723 ns/iter (+/- 10,377)
test batch_size_32 ... bench: 4,780,344 ns/iter (+/- 17,684)
test batch_size_64 ... bench: 8,372,624 ns/iter (+/- 319,169)
The speed-up is between 2.27x and 3.20x. This is in line with what I expected from the initially posted benchmarks.
- … q8.
- … q8, and the rest of the value and policy heads in f32, because of GEMM data type complications.

This was run with a batch size and iteration count of one, so it is unclear how representative this nvprof output is. But the profile suggests the hybrid implementation is sound, as 89% of the run-time is still spent within convolutions, and larger batch sizes should need more time for the convolutions than for the matrix multiplications.
==22973== Profiling application: ./target/release/dream_go --self-play 1
==22973== Profiling result:
Time(%) Time Calls Avg Min Max Name
89.30% 1.01655s 28197 36.051us 14.924us 43.452us implicit_gemm_int8x4_icudnn_int8x4_128x128_relu_small_tn
5.67% 64.583ms 29643 2.1780us 1.3190us 7.4620us cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams)
1.72% 19.619ms 1446 13.567us 7.9130us 14.507us implicit_gemm_fp32_icudnn_int8x4_128x128_sliced1x2_ldg2_ldg1_relu_interior_tn
1.29% 14.646ms 2169 6.7520us 2.5680us 12.911us void gemv2N_kernel_val<float, float, float, int=128, int=32, int=4, int=4, int=1>(float, float, cublasGemv2Params_v2<float, float, float>)
0.73% 8.3153ms 723 11.501us 9.4740us 116.79us void cudnn::detail::softmax_fw_kernel<int=2, float, float, int=256, int=1, int=0, int=1>(cudnnTensorStruct, float const *, cudnn::detail::softmax_fw_kernel<int=2, float, float, int=256, int=1, int=0, int=1>, cudnnTensorStruct*, int, float, cudnnTensorStruct*, int, int)
0.32% 3.6767ms 1446 2.5420us 2.0130us 4.1300us void add_tensor_kernel_v3<int=2, float, float, int=16, int=16, int=1, int=1, int=3>(cudnnTensorStruct, float*, cudnnTensorStruct, float const *, float, float)
0.29% 3.3267ms 723 4.6010us 4.5460us 4.7900us void cudnn::detail::activation_fw_4d_kernel<float, float, int=16, int=16, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=16, int=16, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
0.22% 2.4697ms 723 3.4150us 3.1930us 3.8180us void add_tensor_kernel_v3<int=2, float, float, int=16, int=16, int=1, int=16, int=4>(cudnnTensorStruct, float*, cudnnTensorStruct, float const *, float, float)
0.19% 2.1340ms 815 2.6180us 659ns 94.053us [CUDA memcpy HtoD]
0.16% 1.8035ms 723 2.4940us 2.3940us 3.5400us void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::tanh_func<float>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::tanh_func<float>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
0.11% 1.2257ms 1446 847ns 763ns 2.7070us [CUDA memcpy DtoH]
==22973== API calls:
Time(%) Time Calls Avg Min Max Name
54.26% 5.96646s 16 372.90ms 9.6330us 5.96627s cudaStreamCreateWithFlags
27.65% 3.03994s 2169 1.4015ms 5.6510us 3.9771ms cudaMemcpyAsync
11.51% 1.26598s 115 11.009ms 374ns 1.26452s cudaFree
3.04% 334.69ms 109 3.0706ms 2.2550us 332.62ms cudaMalloc
2.79% 306.93ms 65793 4.6650us 3.1640us 678.16us cudaLaunch
0.27% 30.201ms 92 328.28us 5.7290us 2.1523ms cudaMemcpy
0.14% 15.803ms 96882 163ns 91ns 591.48us cudaSetupArgument
0.14% 15.568ms 95436 163ns 88ns 692.97us cudaGetLastError
0.10% 11.470ms 65793 174ns 96ns 587.94us cudaConfigureCall
0.02% 2.6155ms 528 4.9530us 114ns 585.19us cuDeviceGetAttribute
0.02% 2.1362ms 723 2.9540us 2.0610us 4.9870us cudaEventRecord
0.02% 1.6892ms 1446 1.1680us 633ns 6.7710us cudaStreamSynchronize
0.01% 1.3439ms 1446 929ns 420ns 6.7890us cudaStreamWaitEvent
0.00% 273.20us 6 45.533us 41.064us 55.928us cuDeviceTotalMem
0.00% 183.77us 6 30.628us 24.009us 45.883us cuDeviceGetName
0.00% 55.266us 81 682ns 527ns 3.5010us cudaEventDestroy
0.00% 41.976us 80 524ns 364ns 2.5630us cudaEventCreateWithFlags
0.00% 41.565us 19 2.1870us 1.6280us 6.7060us cudaStreamDestroy
0.00% 32.209us 3 10.736us 9.9320us 11.198us cudaStreamCreate
0.00% 22.770us 68 334ns 239ns 975ns cudaDeviceGetAttribute
0.00% 20.379us 2 10.189us 9.1550us 11.224us cudaStreamCreateWithPriority
0.00% 14.402us 8 1.8000us 1.7230us 1.9760us cudaThreadSynchronize
0.00% 5.9910us 6 998ns 306ns 1.6020us cudaGetDevice
0.00% 5.0920us 2 2.5460us 1.7670us 3.3250us cudaDeviceSynchronize
0.00% 1.9970us 8 249ns 113ns 598ns cuDeviceGet
0.00% 1.8740us 4 468ns 141ns 1.2780us cuDeviceGetCount
0.00% 1.6090us 1 1.6090us 1.6090us 1.6090us cudaEventCreate
0.00% 1.5900us 2 795ns 696ns 894ns cudaDeviceGetStreamPriorityRange
0.00% 1.3310us 2 665ns 487ns 844ns cuInit
0.00% 643ns 2 321ns 205ns 438ns cuDriverGetVersion
After doing some sensitivity analysis on using quantized data types, it has become pretty clear that in order to avoid an explosion of errors we must scale both the input and output correctly, in addition to the weights.
This is annoying because it means we must store the appropriate scales for each layer. In addition, the scales will change every time we train a new network, which makes plug-and-play support awkward. This leaves us with two options:
quantized_types.rs
As can be seen in Table 1, the maximum absolute error differs by as much as 6x just from minimal changes in the input and output scales. The final entry in the table represents the optimal scales.
Table 1: maximum absolute error for different filter, input, and output scales.

FILTER_SCALE | INPUT_SCALE | OUTPUT_SCALE | Maximum absolute error
---|---|---|---
0.1 | 1.0 | 1.3 | 0.011574924
0.1 | 0.5 | 1.3 | 0.011574924
0.1 | 0.5 | 3.0 | 0.012755904
0.1 | 3.0 | 3.0 | -
0.1 | 2.0 | 3.0 | 0.030472517
0.1 | 1.0 | 3.0 | 0.020314991
0.1 | 0.5 | 2.0 | 0.011259839
0.09 | 0.4 | 2.0 | 0.007244095
0.09 | 0.4 | 0.81 | 0.0063780546
To examine this, we added operators that measure the mean and variance of the output of each convolutional layer (after batch normalization), and, as per the definition of batch normalization, the output had a mean of zero and a variance of one. This suggests that we can just pick the z-value that represents the fraction of activations we want to keep, for example 1.96 to keep the central 95%, and use it over the entire network.
There are only two exceptions to this so far during training (step 6,000): the policy and value down-sampling layers. Both of these layers have a clearly non-zero bias (0.9 and 0.4) that causes the mean to drift away from zero; however, the variance remains one (as expected).
Some unofficial documentation on how cudnnConvolutionBiasActivationForward works with quantized types: it computes the following formula, where C is the convolution operator and act is the activation operator:
y = act(α · C(W, x) + β · z + b)
When computing using quantized inputs and outputs the variables will have the following types and ranges:
α ∈ f32
β ∈ f32
W ∈ [-127,+127] ∈ i8
x ∈ [-127,+127] ∈ i8
y ∈ [-127,+127] ∈ i8
b ∈ [-127,+127] ∈ f32
For the most part this is not very surprising, with the exception of b, which is scaled to the quantized range but stored as an f32. In addition, the function performs no automatic re-scaling of any of the values; all re-scaling has to be done using α and β.
To determine what α and β should be, one can substitute the scaling from f32 to i8 into the formula, where sW, sx, sz, sb, and sy are the scaling factors of the respective variables, each in the range (0, ∞), and solve for the output 127 y / sy:
act(α · C(127 W / sW, 127 x / sx) + β · 127 z / sz + 127 b / sb)
= act(α · 127² C(W / sW, x / sx) + β · 127 z / sz + 127 b / sb)
= 127 act(α · 127 C(W / sW, x / sx) + β · z / sz + b / sb)
After some simplification (above) we can see that the 127 terms on both sides cancel each other out, but we still need to sort out the scaling factors. Since act is relu, which is positively homogeneous (act(c · u) = c · act(u) for any c > 0), we can break this into three sub-problems where each term must end up with the unit 1 / sy:
- α · 127 C(W / sW, x / sx) implies that α = (sW sx) / (127 sy).
- β · z / sz implies that β = sz / sy.
- b / sb implies that sb = sy.

If the scaling factors for x, y, and z are all the same (as one would expect if every layer is batch normalized), then this simplifies to the following, where s is the universal scaling factor:

- α = sW / 127
- β = 1
- sb = s
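The following sketch checks the derivation numerically under stated assumptions: a single dot product stands in for C, relu is the activation, x, y, and z share one scale s, and the bias stays an f32 scaled by 127 / sb. The quantized output, dequantized with s / 127, should land within one quantization step of the f32 result. All names are hypothetical, and the final rounding and saturation back to i8 are assumptions rather than documented cuDNN behaviour.

```rust
/// Quantize a real value to [-127, +127] with scaling factor `s` (q = round(127 v / s)).
fn quantize(v: f32, s: f32) -> i8 {
    (127.0 * v / s).round().clamp(-127.0, 127.0) as i8
}

/// (α, β, sb) for filter scale `s_w` and a shared activation scale `s`.
fn blend_factors(s_w: f32, s: f32) -> (f32, f32, f32) {
    let alpha = (s_w * s) / (127.0 * s); // α = (sW sx) / (127 sy) with sx = sy = s
    let beta = s / s; // β = sz / sy = 1
    let s_b = s; // sb = sy = s
    (alpha, beta, s_b)
}

/// y = relu(α · C(W, x) + β · z + b) on quantized operands, with C reduced to a
/// single dot product. Returns y in q8 units, i.e. 127 y / s.
fn quantized_layer(w: &[f32], x: &[f32], z: f32, b: f32, s_w: f32, s: f32) -> i8 {
    let (alpha, beta, s_b) = blend_factors(s_w, s);

    // i8 x i8 products accumulated in i32 (the DP4A-style accumulation)
    let acc: i32 = w
        .iter()
        .zip(x)
        .map(|(&wi, &xi)| quantize(wi, s_w) as i32 * quantize(xi, s) as i32)
        .sum();
    let z_q = quantize(z, s) as f32;
    let b_q = 127.0 * b / s_b; // the bias is scaled to the q8 range but stays an f32

    // relu, then (assumed) rounding and saturation back to i8
    (alpha * acc as f32 + beta * z_q + b_q).max(0.0).round().min(127.0) as i8
}
```

With the example in the test below, the f32 reference is relu(0.025 + 0.3 + 0.1) = 0.425, and the quantized path returns 14, which dequantizes to about 0.441.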
I have been experimenting with computing the scale using a static analysis of the mean and variance of the input tensors, by making the following assumptions about the operands of the convolution operator function (with all elements independent):
y = act(α · C(W, x) + β · z + b)

- W is a normal distribution with mean μ and variance σ², with shape [k, c, h, w]
- x is a normal distribution with mean μ' and variance σ'²
- z is a normal distribution with mean μ'' and variance σ''²
- b is a normal distribution with mean μ''' and variance σ'''²

The convolution C(W, x) can be re-written as a sum of product-normal distributions [1] of W and x, where each element in the convolution is the sum of c·h·w product-normal distributions. Since c·h·w is typically very large, this can be re-written using the central limit theorem into another normal distribution:
C(W, x)
= c h w · N(μ, σ²) · N(μ', σ'²)
= c h w · Nₚ(μ μ', σ² σ'² + σ² μ'² + σ'² μ²)
= N(c h w · μ μ', c h w · (σ² σ'² + σ² μ'² + σ'² μ²))
This distribution will henceforth be referred to using the shorthand notation N(μₓ, σₓ²).
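A small Monte Carlo sketch for sanity-checking the distribution of a single output element of C(W, x), assuming all elements of W and x are independent. Since each element is a sum of c·h·w independent products, both its mean and its variance grow linearly with c·h·w. The hand-rolled xorshift64* generator and Box-Muller transform are only there to keep the example free of dependencies; all names are hypothetical.

```rust
/// Minimal xorshift64* generator; only meant for this sanity check.
struct Rng(u64);

impl Rng {
    /// Uniform sample in (0, 1].
    fn uniform(&mut self) -> f64 {
        self.0 ^= self.0 >> 12;
        self.0 ^= self.0 << 25;
        self.0 ^= self.0 >> 27;
        let r = self.0.wrapping_mul(0x2545_F491_4F6C_DD1D);
        ((r >> 11) as f64 + 1.0) / (1u64 << 53) as f64
    }

    /// Normal sample with the given mean and standard deviation (Box-Muller).
    fn normal(&mut self, mean: f64, sd: f64) -> f64 {
        let (u1, u2) = (self.uniform(), self.uniform());
        mean + sd * (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

/// Predicted (mean, variance) of one output element of C(W, x): the sum of `chw`
/// independent products W_i x_i, so the mean and variance both scale by chw.
fn conv_moments(mu: f64, var: f64, mu_x: f64, var_x: f64, chw: usize) -> (f64, f64) {
    let n = chw as f64;
    // variance of the product of two independent normals
    let product_var = var * var_x + var * mu_x * mu_x + var_x * mu * mu;
    (n * mu * mu_x, n * product_var)
}

/// Empirical (mean, variance) over `samples` simulated output elements.
fn simulate_moments(mu: f64, var: f64, mu_x: f64, var_x: f64, chw: usize, samples: usize) -> (f64, f64) {
    let mut rng = Rng(0x9E37_79B9_7F4A_7C15); // fixed non-zero seed
    let (sd, sd_x) = (var.sqrt(), var_x.sqrt());
    let values: Vec<f64> = (0..samples)
        .map(|_| (0..chw).map(|_| rng.normal(mu, sd) * rng.normal(mu_x, sd_x)).sum::<f64>())
        .collect();
    let mean = values.iter().sum::<f64>() / samples as f64;
    let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (samples - 1) as f64;
    (mean, variance)
}
```

For μ = 0.1, σ² = 0.25, μ' = 0.5, σ'² = 1 and c·h·w = 64, the prediction is N(3.2, 20.64), and the simulation agrees to within a few percent.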
Adding the blend and bias using the normal sum distribution [2] is straightforward: for independent operands the means add, and the variances add after each is scaled by the square of its coefficient:
α · C(W, x) + β · z + b
= α · N(μₓ, σₓ²) + β · N(μ'', σ''²) + N(μ''', σ'''²)
= N(α μₓ + β μ'' + μ''', α² σₓ² + β² σ''² + σ'''²)
The activation function in our network is always a rectified linear unit (relu), which produces a truncated normal distribution [3] as output, whose mean μₜ and variance σₜ² can be estimated using moments [4].
Notice that we expect normal distributions as input, but output a truncated normal distribution. Since the output of each activation is fed to the next layer as either x or z, this method is not theoretically sound, but it might be good enough in practice because the central limit theorem we apply in the first step will smooth the truncated normal distribution out.
act(α · C(W, x) + β · z + b)
= act(N(α μₓ + β μ'' + μ''', α² σₓ² + β² σ''² + σ'''²))
= N(μₜ, σₜ²)
The scale of the output variable is then calculated as the 97.5th percentile of this distribution (treating it as normal), which can be found using the z-score:
scale(y)
= μₜ + 1.96 √σₜ²
[1] http://mathworld.wolfram.com/NormalProductDistribution.html
[2] http://mathworld.wolfram.com/NormalSumDistribution.html
[3] https://en.wikipedia.org/wiki/Truncated_normal_distribution
[4] https://github.com/cossio/TruncatedNormal.jl/blob/master/notes/normal.pdf
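Putting the steps together, a sketch of the static scale estimate: blend the moments of the convolution, residual input, and bias (assuming independence, so each variance is scaled by the square of its coefficient), model the relu output as a normal truncated to [0, ∞) as in the text, and take μₜ + z √σₜ². The `erf` approximation (Abramowitz and Stegun 7.1.26) is only there because the stable Rust standard library does not provide one; all names are hypothetical.

```rust
/// Mean and variance of a (modelled) distribution.
#[derive(Clone, Copy, Debug)]
struct Moments {
    mean: f64,
    var: f64,
}

/// erf via Abramowitz and Stegun 7.1.26 (absolute error below 1.5e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * x);
    let poly = t * (0.254_829_592
        + t * (-0.284_496_736 + t * (1.421_413_741 + t * (-1.453_152_027 + t * 1.061_405_429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Standard normal density.
fn pdf(x: f64) -> f64 {
    (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt()
}

/// Standard normal cumulative distribution.
fn cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// α · N(μₓ, σₓ²) + β · N(μ'', σ''²) + N(μ''', σ'''²) for independent operands.
fn blend(conv: Moments, alpha: f64, z: Moments, beta: f64, b: Moments) -> Moments {
    Moments {
        mean: alpha * conv.mean + beta * z.mean + b.mean,
        var: alpha * alpha * conv.var + beta * beta * z.var + b.var,
    }
}

/// Moments of N(mean, var) truncated to [0, ∞), the model used for the relu output.
/// Assumes var > 0; not numerically robust for strongly negative means.
fn truncate_at_zero(n: Moments) -> Moments {
    let sd = n.var.sqrt();
    let a = -n.mean / sd; // standardized truncation point
    let lambda = pdf(a) / (1.0 - cdf(a)); // inverse Mills ratio
    Moments {
        mean: n.mean + sd * lambda,
        var: n.var * (1.0 + a * lambda - lambda * lambda),
    }
}

/// scale(y) = μₜ + z √σₜ², with z = 1.96 in the text.
fn output_scale(pre_activation: Moments, z: f64) -> f64 {
    let t = truncate_at_zero(pre_activation);
    t.mean + z * t.var.sqrt()
}
```

For a batch-normalized pre-activation N(0, 1) this gives a scale of about 1.979.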
The introduction of dilated convolutions to each residual block has shown great success when it comes to global thinking. Unfortunately, there is no support for dilation > 1 when using the q8 data type, so if we want to use dilated convolutions then we cannot use quantized types.
The main workaround for this would be to write our own custom kernels that operate on q8, which would be good for several other reasons:

- … ImplicitPrecompGemm algorithm instead of Winograd. It should be possible to write an implementation of Winograd for dilation, but presumably NVIDIA just did not bother.

I have been playing around with tf.nn.relu6 instead of tf.nn.relu in the neural network to side-step the scale problem that we have been trying to solve above. So far it has proven very successful: the change in loss is within the margin of error, and it solves the scale problem by fixing the scale of all output variables to 6.0.
To strengthen this case, we investigated the optimal scale as determined by the minimal Jensen-Shannon divergence. When using tf.nn.relu we usually end up with a scale of 8.0, and with tf.nn.relu6 we always end up with a scale of 6.0. This is good since we end up with a smaller ε (quantization step), which helps with the precision.
Because of the above, I am going to put the DP4A project on hold until we can rewrite the neural network code using either newer cuDNN primitives or a completely custom kernel.
After a few busy weeks in San Francisco I am back with some inspiration. I started an experiment with different data types during inference to see the performance difference again, and these are my results so far when using a tiny neural network (no residual blocks to make debugging easier).
The main contribution of this attempt is that the implementation actually works (it predicts the same moves as the float and pseudo-half versions), so this network can be used as a valid performance comparison. The technical differences between this and previous implementations are:

- A fixed scale of 6.0 for all output layers (using tf.nn.relu6 during training); this simplifies the code dealing with scaling factors significantly.
- Using CUDNN_DATA_INT8x4 and CUDNN_TENSOR_NCHW_VECT_C for the input, output, and filter tensors. This is because of a comment in TensorFlow suggesting that it is the only working configuration.

We can observe roughly a 2x speed-up when using DP4A for this small network, and this may grow as the network becomes more compute bound than host bound. The reason there is not much of a difference between the data types for smaller batch sizes is that the cost of copying from the host (unpinned memory) to the device dominates the runtime.
test batch_size_001 ... bench: 100,570 ns/iter (+/- 6,931)
test batch_size_002 ... bench: 102,021 ns/iter (+/- 996)
test batch_size_004 ... bench: 109,425 ns/iter (+/- 1,137)
test batch_size_008 ... bench: 137,083 ns/iter (+/- 1,785)
test batch_size_016 ... bench: 198,625 ns/iter (+/- 3,661)
test batch_size_032 ... bench: 323,130 ns/iter (+/- 3,735)
test batch_size_064 ... bench: 523,549 ns/iter (+/- 7,902)
test batch_size_128 ... bench: 879,324 ns/iter (+/- 10,828)
test batch_size_256 ... bench: 1,591,066 ns/iter (+/- 32,915)
test batch_size_001 ... bench: 241,942 ns/iter (+/- 2,780)
test batch_size_002 ... bench: 221,768 ns/iter (+/- 1,990)
test batch_size_004 ... bench: 243,368 ns/iter (+/- 2,819)
test batch_size_008 ... bench: 287,454 ns/iter (+/- 2,473)
test batch_size_016 ... bench: 364,674 ns/iter (+/- 26,745)
test batch_size_032 ... bench: 562,866 ns/iter (+/- 4,068)
test batch_size_064 ... bench: 849,575 ns/iter (+/- 5,168)
test batch_size_128 ... bench: 1,333,253 ns/iter (+/- 8,943)
test batch_size_256 ... bench: 2,321,762 ns/iter (+/- 7,660)
test batch_size_001 ... bench: 112,893 ns/iter (+/- 1,711)
test batch_size_002 ... bench: 109,769 ns/iter (+/- 806)
test batch_size_004 ... bench: 147,600 ns/iter (+/- 1,473)
test batch_size_008 ... bench: 212,604 ns/iter (+/- 2,262)
test batch_size_016 ... bench: 319,803 ns/iter (+/- 2,940)
test batch_size_032 ... bench: 539,914 ns/iter (+/- 6,299)
test batch_size_064 ... bench: 866,846 ns/iter (+/- 7,521)
test batch_size_128 ... bench: 1,541,471 ns/iter (+/- 16,759)
test batch_size_256 ... bench: 2,921,560 ns/iter (+/- 15,598)
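For reference, a sketch of the NCHW_VECT_C packing used together with CUDNN_DATA_INT8x4 above, assuming the layout stores each group of four consecutive channels as one 32-bit vector, i.e. memory order [n][c / 4][h][w][4]. This is also why the channel count must be a multiple of 4. The function name is hypothetical.

```rust
/// Pack an NCHW i8 tensor into NCHW_VECT_C with 4-channel vectors, i.e. memory
/// order [n][c / 4][h][w][4] (hypothetical helper).
fn nchw_to_nchw_vect_c(input: &[i8], n: usize, c: usize, h: usize, w: usize) -> Vec<i8> {
    assert_eq!(c % 4, 0, "INT8x4 requires the channel count to be a multiple of 4");
    assert_eq!(input.len(), n * c * h * w, "input does not match n * c * h * w");
    let mut output = vec![0i8; input.len()];

    for b in 0..n {
        for ch in 0..c {
            let (group, lane) = (ch / 4, ch % 4); // which vector, and which byte in it
            for y in 0..h {
                for x in 0..w {
                    let src = ((b * c + ch) * h + y) * w + x;
                    let dst = (((b * (c / 4) + group) * h + y) * w + x) * 4 + lane;
                    output[dst] = input[src];
                }
            }
        }
    }

    output
}
```

Each group of four bytes then holds the four channel values that a single DP4A instruction multiplies and accumulates.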
Further testing suggests that while the network sometimes works when running with 9 residual blocks, it is unstable. This suggests that the cuDNN implementation is correct, but that the quantization of the weights and intermediate layers is not accurate enough. We are currently investigating the use of the following functions during training of the model, to address this issue directly during training instead of trying to fine-tune the weights afterwards:
- tf.fake_quant_with_min_max_args for quantization of convolution results.
- tf.fake_quant_with_min_max_vars for quantization of convolution weights.

This approach ignores batch normalization, but should be a much closer approximation than the previous one.
Conclusion: Training the network with fake quantization does not seem to help. I am unsure what further conclusions to draw from this result.
Changing the quantization strategy for some weights w from round(127 w / max(abs(w))) to tf.quantize with mode SCALED and rounding mode HALF_AWAY_FROM_ZERO seems to have solved the issues for now. I will need to do further testing before I believe it is stable.
The formula used by tf.quantize should be the same as the one we used previously; however, it also adjusts the scale factor to account for the quantization and uses a different rounding mode. These sound like small differences, but they can have a significant impact on the results when using low precision.
There seem to be some issues with this implementation and the value head. The observed effect is that occasionally a random move (such as playing in the middle during the opening) will get a high win rate (a few percentage points higher than the others) and then dominate the rest of the search.
It is unclear whether this is due to a bug in the quantization of the value head, or an anomaly in the trained network that does not know how to evaluate certain positions. If it is the latter, then the most likely reason we have not observed it before is that the huge performance benefit of quantization now makes it viable for moves that are not recommended by the policy to get promoted.
This behaviour also seems erratic, since when going back to a position with a weird value it is not always possible to reproduce the issue. This may be because of the use of random symmetries.
In order to determine whether the erratic behaviour was caused by training using tf.float16 (or by the training data), or by the quantization, we tweaked the implementation to use pseudo-half during inference. The result is about a 4.3x drop in performance, but the erratic behaviour of the value head is gone.
This is decent progress, since we can now investigate the differences between the two versions, and maybe determine where the value head diverges.
The problem seems to be caused by a slight miscalculation of the lower and upper bounds on the tower input and output.
When we quantize a tensor into the closed interval [-127,+127] we quantize into a total of 255 possible values (including zero), so for the input range [-6,+6] we get a step size of 12 / 255. Note that this gives us lower and upper bounds of ±127 · 12 / 255, which is ±5.976470588. This is very close to six, so for the most part it will be fine, but it introduces a small error that will sometimes spill over and cause misplays (which is the observed effect).
The correct bound for up and down scaling can be derived from this formula, where we set the bound to 6.0 and solve for the scale:
127 · (2 · x / 255) = 6
=> x = 6 / 127 * 255 / 2 = 6.023622047
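The bound arithmetic above as a small sketch, assuming 255 steps over the full range and codes clamped to [-127, +127]. The function names are hypothetical; `round_trip` quantizes and dequantizes a value to show that 6.0 comes back as about 5.976 with the old half-range, and as 6.0 with the corrected one.

```rust
/// Largest magnitude representable when a symmetric `range` (12.0 for [-6, +6])
/// is split into 255 steps and the codes are clamped to [-127, +127].
fn representable_bound(range: f64) -> f64 {
    127.0 * range / 255.0
}

/// Solve 127 · (2 · x / 255) = bound for the half-range x.
fn corrected_half_range(bound: f64) -> f64 {
    bound / 127.0 * 255.0 / 2.0
}

/// Quantize `v` with the given half-range, then dequantize it again.
fn round_trip(v: f64, half_range: f64) -> f64 {
    let step = 2.0 * half_range / 255.0;
    (v / step).round().clamp(-127.0, 127.0) * step
}
```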
Out of a total of three games observed so far (self-play and against Leela), no erratic misplays have been observed, but more games need to be observed before any conclusion can be drawn.
I managed to catch the full trace of a broken move, and the issue seems to be the policy head, not the value head as I had thought so far: for whatever reason the pass move is given a huge prior value (0.97), which skews the search hugely in favour of the player that does not pass. I am still investigating exactly why the policy was wrong in this case.
After two weeks of debugging, I think we have found the issue that caused this erratic behaviour. It is related to a race condition in the Slots structure: the CUDA calls are asynchronous, and we return the memory back to the pool at the end of its lifetime (as defined by the scope).
This means that the CUDA function might get executed after the scope has ended (since it is asynchronous), at which point the memory has been returned to the pool and may have been reused by another thread.
The solution is obvious: extend the lifetime of all slots for the full duration of the CUDA calls.
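A minimal single-threaded sketch of that fix, using hypothetical `Pool`, `Slot`, and `Stream` types rather than the actual Slots structure: work enqueued on a stream takes ownership of its slot, so the slot is only returned to the pool after `synchronize`, not when the enclosing scope ends. No CUDA is involved; `Stream` only models an asynchronous queue.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical pool of reusable device buffers, identified by plain ids.
#[derive(Default)]
struct Pool {
    free: RefCell<Vec<usize>>,
    next: RefCell<usize>,
}

/// Hypothetical RAII handle that returns its buffer to the pool when dropped.
struct Slot {
    id: usize,
    pool: Rc<Pool>,
}

impl Pool {
    fn acquire(pool: &Rc<Pool>) -> Slot {
        let recycled = pool.free.borrow_mut().pop();
        let id = recycled.unwrap_or_else(|| {
            let mut next = pool.next.borrow_mut();
            *next += 1;
            *next - 1
        });
        Slot { id, pool: Rc::clone(pool) }
    }
}

impl Drop for Slot {
    fn drop(&mut self) {
        self.pool.free.borrow_mut().push(self.id);
    }
}

/// Stand-in for an asynchronous CUDA stream: queued work only "runs" on `synchronize`.
#[derive(Default)]
struct Stream {
    pending: Vec<usize>,
    completed: Vec<usize>,
    keep_alive: Vec<Slot>,
}

impl Stream {
    /// Enqueue work that uses `slot`. The stream takes ownership of the slot,
    /// so it cannot go back to the pool when the caller's scope ends.
    fn launch(&mut self, slot: Slot) {
        self.pending.push(slot.id);
        self.keep_alive.push(slot);
    }

    /// Wait for all queued work, and only then release the slots it used.
    fn synchronize(&mut self) {
        self.completed.append(&mut self.pending);
        self.keep_alive.clear(); // dropping the slots returns them to the pool
    }
}
```

Without `keep_alive`, the slot would be dropped at the end of the caller's scope and the next `acquire` could hand out the same buffer while the queued work still uses it, which is the race described above.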
After implementing the aforementioned fix I can no longer reproduce the issue, and the neural network plays consistently. After a 50-game round-robin between dg-d-128-2-3, dg-i-128, and leela, this is the final score:
dg-i-128 v dg-d-128-2-3 (50/50 games)
unknown results: 4 8.00%
board size: 19 komi: 7.5
wins black white avg cpu
dg-i-128 5 10.00% 2 8.00% 3 12.00% 3808.76
dg-d-128-2-3 41 82.00% 20 80.00% 21 84.00% 395.15
22 44.00% 24 48.00%
dg-i-128 v leela (50/50 games)
unknown results: 1 2.00%
board size: 19 komi: 7.5
wins black white avg cpu
dg-i-128 38 76.00% 21 84.00% 17 68.00% 2440.50
leela 11 22.00% 7 28.00% 4 16.00% 71.42
28 56.00% 21 42.00%
dg-d-128-2-3 v leela (50/50 games)
board size: 19 komi: 7.5
wins black white avg cpu
dg-d-128-2-3 47 94.00% 24 96.00% 23 92.00% 276.17
leela 3 6.00% 2 8.00% 1 4.00% 71.19
26 52.00% 24 48.00%
These results are consistent with training a normal 128-wide ResNet using floats. They point towards a few issues, however:
However, these issues are better solved in a different thread. For the moment the DP4A implementation is a success.
Based on some recent benchmarks (on a GTX 1080 Ti), using 8-bit quantized integers during inference could give us a significant performance improvement. No experiments have been performed yet, but other sources suggest that the drop in accuracy from using q8 instead of f16 (or f32) is insignificant. This is in line with what has been cited by Google and NVIDIA, who are pushing q8 during inference in their TPUs and TensorRT library.

Updated with new benchmarks using a fused conv-add-act kernel:
BATCH_SIZE | Speed-up
---|---
1 | 0.70x
16 | 3.07x
64 | 3.41x
128 | 3.62x