karpathy / llm.c

LLM training in simple, raw C/CUDA

Speedup CPU training by 10% using Memory Aligned Tensors #630

Open iVishalr opened 1 week ago

iVishalr commented 1 week ago

Speedup CPU training using Memory Aligned Tensors

This PR adds support for allocating aligned memory to speed up training on CPU by ~10%.

Why is this important?

Vectorized instruction sets such as SSE and AVX work best with memory-aligned arrays as input (aligned SSE loads, for example, require 16-byte alignment). Without an alignment guarantee, the compiler is forced to fall back to unaligned loads/stores or scalar code instead of the fast aligned AVX instructions, which makes the program slower.
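As a concrete illustration (not code from this PR), the sketch below allocates a buffer with C11 `aligned_alloc` at 32-byte (AVX-width) alignment and uses the GCC/Clang builtin `__builtin_assume_aligned` to tell the vectorizer that the pointer is aligned, so a simple loop can be compiled down to aligned AVX loads and stores:

```c
#include <stdlib.h>

#define ALIGNMENT 32  // 32 bytes = one 256-bit AVX register

// Scale a float array in place. The alignment hint lets GCC/Clang emit
// aligned vector loads/stores instead of unaligned ones or a scalar peel loop.
void scale(float* x, float s, size_t n) {
    float* xa = (float*)__builtin_assume_aligned(x, ALIGNMENT);
    for (size_t i = 0; i < n; i++) { xa[i] *= s; }
}

int main(void) {
    size_t n = 1024;
    // C11 aligned_alloc: the size must be a multiple of the alignment.
    float* x = aligned_alloc(ALIGNMENT, n * sizeof(float));
    if (x == NULL) { return 1; }
    for (size_t i = 0; i < n; i++) { x[i] = (float)i; }
    scale(x, 2.0f, n);
    free(x);  // memory from aligned_alloc is released with plain free()
    return 0;
}
```

Compiling with something like `gcc -O3 -mavx2` and inspecting the assembly shows the difference between the aligned and unaligned code paths.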

Minor Details

On Windows, aligned memory has to be allocated with _aligned_malloc() and released with _aligned_free(); both are declared in <malloc.h>. Hence, when using alignedMallocCheck, the user has to release the memory with freeCheck for compatibility with Windows.
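For reference, a minimal sketch of what such a cross-platform wrapper could look like is shown below. The names and signatures here (aligned_malloc_check, alignedMallocCheck, freeCheck) are assumptions for illustration and do not necessarily match the PR's actual implementation:

```c
// Illustrative sketch only -- names and signatures are assumptions, not the PR's code.
#ifndef _WIN32
#define _POSIX_C_SOURCE 200112L   // for posix_memalign
#endif
#include <stdio.h>
#include <stdlib.h>
#ifdef _WIN32
#include <malloc.h>               // _aligned_malloc / _aligned_free
#endif

void* aligned_malloc_check(size_t alignment, size_t size, const char* file, int line) {
    void* ptr = NULL;
#ifdef _WIN32
    ptr = _aligned_malloc(size, alignment);   // note: size first, then alignment
#else
    // alignment must be a power of two and a multiple of sizeof(void*)
    if (posix_memalign(&ptr, alignment, size) != 0) { ptr = NULL; }
#endif
    if (ptr == NULL) {
        fprintf(stderr, "aligned malloc failed at %s:%d\n", file, line);
        exit(EXIT_FAILURE);
    }
    return ptr;
}

void free_check(void* ptr) {
#ifdef _WIN32
    _aligned_free(ptr);   // _aligned_malloc memory must go through _aligned_free
#else
    free(ptr);            // posix_memalign memory is released with free()
#endif
}

#define alignedMallocCheck(alignment, size) \
    aligned_malloc_check(alignment, size, __FILE__, __LINE__)
#define freeCheck(ptr) free_check(ptr)

int main(void) {
    // e.g. swap a regular checked malloc call for the aligned variant at a call site
    size_t n = 1 << 20;  // 1M floats, just for illustration
    float* params_memory = (float*)alignedMallocCheck(32, n * sizeof(float));
    params_memory[0] = 1.0f;
    freeCheck(params_memory);
    return 0;
}
```

Because Windows memory from _aligned_malloc must be released with _aligned_free, routing every deallocation through a single freeCheck wrapper (rather than a bare free) keeps the call sites portable.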

Benchmarks

First, we train GPT-2 using train_gpt2.c compiled with the default B = 4 and T = 64 settings; this serves as the baseline for comparison. Then we replace the mallocCheck() calls that allocate the parameter and gradient memory in train_gpt2.c with alignedAllocCheck() and train again.

Similarly, we perform the same comparisons for B = 8 and T = 128.

To summarize the benchmark results:

| B | T   | Memory Aligned | Avg time per iter (ms) | Improvement   |
|---|-----|----------------|------------------------|---------------|
| 4 | 64  | False          | 1719.5517              | 0% (baseline) |
| 4 | 64  | True           | 1542.8133              | 10.2%         |
| 8 | 128 | False          | 7291.3815              | 0% (baseline) |
| 8 | 128 | True           | 6724.1067              | 7.8%          |

All tests were run on an AMD Ryzen 5 5600X (6 cores, 12 threads) running Ubuntu 22.04.

--------------------------------------
B = 4, T = 64, Memory Aligned = False
--------------------------------------

$ OMP_NUM_THREADS=6 ./train_gpt2 
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
padded_vocab_size: 50304
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124475904
train dataset num_batches: 1192
val dataset num_batches: 128
num_activations: 73347840
val loss 5.325522
step 0: train loss 4.677773 (took 2439.229593 ms)
step 1: train loss 5.191520 (took 1687.369726 ms)
step 2: train loss 4.438629 (took 1725.409072 ms)
step 3: train loss 4.138455 (took 1687.972396 ms)
step 4: train loss 4.144235 (took 1652.728419 ms)
step 5: train loss 3.834684 (took 1684.643112 ms)
step 6: train loss 4.298058 (took 1705.429080 ms)
step 7: train loss 4.280748 (took 1692.782453 ms)
step 8: train loss 4.249753 (took 1667.858138 ms)
step 9: train loss 4.391605 (took 1664.617172 ms)
val loss 4.416495
step 10: train loss 3.912613 (took 1666.815832 ms)
step 11: train loss 3.737814 (took 1667.392912 ms)
step 12: train loss 3.840918 (took 1658.961637 ms)
step 13: train loss 4.367944 (took 1655.599047 ms)
step 14: train loss 4.130484 (took 1684.592437 ms)
step 15: train loss 4.012578 (took 1706.997949 ms)
step 16: train loss 3.796071 (took 1690.676520 ms)
step 17: train loss 4.355926 (took 1682.068155 ms)
step 18: train loss 3.766851 (took 1687.825547 ms)
step 19: train loss 4.552072 (took 1682.065641 ms)
val loss 4.329332

With memory alignment enabled, we get better performance, as shown below:

--------------------------------------
B = 4, T = 64, Memory Aligned = True
--------------------------------------

$ OMP_NUM_THREADS=6 ./train_gpt2 
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
padded_vocab_size: 50304
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124475904
train dataset num_batches: 1192
val dataset num_batches: 128
num_activations: 73347840
val loss 5.325522
step 0: train loss 4.677773 (took 2290.617750 ms)
step 1: train loss 5.191520 (took 1503.890304 ms)
step 2: train loss 4.438629 (took 1498.036819 ms)
step 3: train loss 4.138455 (took 1500.274625 ms)
step 4: train loss 4.144235 (took 1510.857340 ms)
step 5: train loss 3.834684 (took 1504.627098 ms)
step 6: train loss 4.298058 (took 1526.357376 ms)
step 7: train loss 4.280748 (took 1507.125980 ms)
step 8: train loss 4.249753 (took 1497.674549 ms)
step 9: train loss 4.391605 (took 1541.745388 ms)
val loss 4.416495
step 10: train loss 3.912613 (took 1490.735758 ms)
step 11: train loss 3.737814 (took 1495.943422 ms)
step 12: train loss 3.840918 (took 1486.119872 ms)
step 13: train loss 4.367944 (took 1496.391885 ms)
step 14: train loss 4.130484 (took 1497.668309 ms)
step 15: train loss 4.012578 (took 1503.713878 ms)
step 16: train loss 3.796071 (took 1500.391717 ms)
step 17: train loss 4.355926 (took 1498.260206 ms)
step 18: train loss 3.766851 (took 1495.075108 ms)
step 19: train loss 4.552072 (took 1510.758671 ms)
val loss 4.329332

Now with B = 8 and T = 128:

--------------------------------------
B = 8, T = 128, Memory Aligned = False
--------------------------------------

$ OMP_NUM_THREADS=6 ./train_gpt2 
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
padded_vocab_size: 50304
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124475904
train dataset num_batches: 298
val dataset num_batches: 32
num_activations: 312265728
val loss 5.039043
step 0: train loss 4.593225 (took 8715.017275 ms)
step 1: train loss 4.847952 (took 7775.450341 ms)
step 2: train loss 3.834345 (took 7453.278627 ms)
step 3: train loss 4.154931 (took 7189.926664 ms)
step 4: train loss 4.167930 (took 7307.230269 ms)
step 5: train loss 3.753937 (took 7363.195882 ms)
step 6: train loss 4.064829 (took 7109.582196 ms)
step 7: train loss 3.866308 (took 7164.824190 ms)
step 8: train loss 3.791344 (took 7153.158293 ms)
step 9: train loss 3.812156 (took 7210.657537 ms)
val loss 4.094801
step 10: train loss 4.241575 (took 7158.736672 ms)
step 11: train loss 3.921602 (took 7097.800746 ms)
step 12: train loss 3.651697 (took 7102.333483 ms)
step 13: train loss 3.614878 (took 7123.471092 ms)
step 14: train loss 3.865428 (took 7128.344869 ms)
step 15: train loss 4.155818 (took 7137.803121 ms)
step 16: train loss 3.176409 (took 7213.129287 ms)
step 17: train loss 4.174030 (took 7120.851449 ms)
step 18: train loss 3.625389 (took 7116.744736 ms)
step 19: train loss 4.100704 (took 7186.094851 ms)
val loss 4.002238

-------------------------------------
B = 8, T = 128, Memory Aligned = True
-------------------------------------

$ OMP_NUM_THREADS=6 ./train_gpt2 
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
padded_vocab_size: 50304
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124475904
train dataset num_batches: 298
val dataset num_batches: 32
num_activations: 312265728
val loss 5.039043
step 0: train loss 4.593225 (took 7789.918821 ms)
step 1: train loss 4.847952 (took 6703.190899 ms)
step 2: train loss 3.834345 (took 6693.555808 ms)
step 3: train loss 4.154931 (took 6648.551273 ms)
step 4: train loss 4.167930 (took 6663.275699 ms)
step 5: train loss 3.753937 (took 6680.382998 ms)
step 6: train loss 4.064829 (took 6697.749351 ms)
step 7: train loss 3.866308 (took 6677.181456 ms)
step 8: train loss 3.791344 (took 6645.614926 ms)
step 9: train loss 3.812156 (took 6674.654943 ms)
val loss 4.094801
step 10: train loss 4.241575 (took 6631.632253 ms)
step 11: train loss 3.921602 (took 6629.068677 ms)
step 12: train loss 3.651697 (took 6685.387375 ms)
step 13: train loss 3.614878 (took 6661.784739 ms)
step 14: train loss 3.865428 (took 6657.601348 ms)
step 15: train loss 4.155818 (took 6663.147305 ms)
step 16: train loss 3.176409 (took 6679.609314 ms)
step 17: train loss 4.174030 (took 6727.468650 ms)
step 18: train loss 3.625389 (took 6610.613281 ms)
step 19: train loss 4.100704 (took 6661.746064 ms)
val loss 4.002238