azret opened 6 months ago
@azret - what about using C11 Atomics?
Would you be able to port it to atomic_compare_exchange_weak?
Sure - will let you know.
@rosslwheeler Unless it's explicitly stated, I don't think it's safe to assume C11 support. There's pretty broad compiler support for atomics via intrinsics (GCC/Clang/etc.), so maybe use those instead? You could also check whether the __STDC_NO_ATOMICS__ macro is defined: use C11 atomics if it isn't, and fall back to the compiler intrinsics otherwise.
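A minimal sketch of that suggestion, assuming a plain `float` accumulator: use C11 atomics when the compiler is C11 or later and does not define `__STDC_NO_ATOMICS__`, otherwise fall back to the GCC/Clang `__atomic` builtins. The name `atomic_add_f32` is hypothetical. Casting a `float*` to `_Atomic float*` assumes both types share size and representation, which holds on mainstream compilers but is not guaranteed by the standard. An MSVC-specific branch (for example via its `_Interlocked*` intrinsics) would still be needed and is not shown.

```c
#include <assert.h>

#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_ATOMICS__)
#include <stdatomic.h>

/* C11 path: compare-and-swap loop on an _Atomic float.
   Assumes _Atomic float has the same layout as float. */
static void atomic_add_f32(float* dest, float val) {
    _Atomic float* a = (_Atomic float*)dest;
    float old = atomic_load(a);
    /* On failure, atomic_compare_exchange_weak writes the current value
       into old, so the next attempt adds val to the fresh value. */
    while (!atomic_compare_exchange_weak(a, &old, old + val)) {
    }
}

#elif defined(__GNUC__) || defined(__clang__)

/* Fallback: the generic GCC/Clang __atomic builtins accept any type and are
   lock-free for a 4-byte float on common targets. Relaxed ordering is enough
   for pure accumulation, assuming the result is read only after the parallel
   region's implicit barrier. */
static void atomic_add_f32(float* dest, float val) {
    float old, desired;
    __atomic_load(dest, &old, __ATOMIC_RELAXED);
    do {
        desired = old + val;
    } while (!__atomic_compare_exchange(dest, &old, &desired, 1,
                                        __ATOMIC_RELAXED, __ATOMIC_RELAXED));
}

#else
#error "atomic_add_f32: no C11 atomics or GCC/Clang builtins available"
#endif
```

Because the function takes a plain `float*`, call sites such as `atomic_add_f32(&dbias[i], dout_bt[i])` look the same on either path.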
@azret - didn't get any speedup; was it supposed to? This is basically the same as your code above. It needs the MSVC OpenMP changes in the source, which I think you mentioned you've already made (if so, please don't use the alternative /TP workaround, since it fails with these C11 atomics).
#include <stdatomic.h>
...
void atomic_add(_Atomic float* dest, float val) {
    float old = *dest;  // reading an _Atomic object is an atomic load
    float new_value;
    do {
        new_value = old + val;
        // on failure, old is refreshed with the current value and the loop retries
    } while (!atomic_compare_exchange_weak(dest, &old, new_value));
}
...
void layernorm_backward(float* dinp, float* dweight, float* dbias,
                        float* dout, float* inp, float* weight, float* mean, float* rstd,
                        int B, int T, int C) {
    ...
            // dbias[i] += dout_bt[i];
            atomic_add((_Atomic float*)&dbias[i], dout_bt[i]);
            // gradient contribution to weight
            // dweight[i] += norm_bti * dout_bt[i];
            atomic_add((_Atomic float*)&dweight[i], norm_bti * dout_bt[i]);
            ...
            // dinp_bt[i] += dval;
            atomic_add((_Atomic float*)&dinp_bt[i], dval);
You'll also need these MSVC command-line options: /openmp:llvm /fp:fast /experimental:c11atomics /std:c++17 /std:c17
@jonathanmarvens - no worries. We're just experimenting at this point.
void atomic_add(_Atomic float* dest, float val) {
    float old = *dest;
    float new_value;
    do {
        new_value = old + val;
    } while (!atomic_compare_exchange_weak(dest, &old, new_value));
}
This does not look right. Can you make a small program that adds 10000 floating-point numbers in parallel and compares the result with adding the same 10000 numbers in a single thread? The sums should match.
Try this:
void test() {
    // verify the basic case: expect ~1.0
    _Atomic float sum = 0.999991;
    atomic_add(&sum, 0.000009);
    int B = 1024;
    int T = 1024;
    int C = 1024;
    float VALUE_1 = 0.0f;
    int b;
    int t;
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            VALUE_1 += 0.98992113;
        }
    }
    _Atomic float VALUE_2 = 0.0f;
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            atomic_add(&VALUE_2, 0.98992113);
        }
    }
    printf("%f6 %f6", VALUE_1, VALUE_2);
}
You forgot to move the ints out of the loop :-)
They are the same...1045264.5000006 1045264.5000006
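One caveat about that check: when every addend is the same value, each successful atomic update computes the next partial sum from the previous one, so the final result cannot depend on which thread performed which add. An exact match therefore shows that no updates were lost, but it cannot reveal order-dependent rounding. A sketch with varied addends, using reverse-order summation as a deterministic stand-in for a different thread interleaving (the 1/(i+1) data and the 1e-4 relative tolerance are arbitrary choices, and all function names are hypothetical):

```c
#include <assert.h>
#include <stdio.h>

/* Sum in index order, as a single thread would. */
static float sum_forward(const float* v, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; i++) s += v[i];
    return s;
}

/* Same values in the opposite order: stands in for another interleaving. */
static float sum_reverse(const float* v, int n) {
    float s = 0.0f;
    for (int i = n - 1; i >= 0; i--) s += v[i];
    return s;
}

/* Relative comparison: |a - b| <= rtol * max(|a|, |b|). */
static int nearly_equal(float a, float b, float rtol) {
    float diff = a > b ? a - b : b - a;
    float abs_a = a < 0.0f ? -a : a;
    float abs_b = b < 0.0f ? -b : b;
    float scale = abs_a > abs_b ? abs_a : abs_b;
    return diff <= rtol * scale;
}

static void compare_orders(void) {
    enum { N = 1000 };
    float v[N];
    for (int i = 0; i < N; i++) v[i] = 1.0f / (float)(i + 1);
    float f = sum_forward(v, N);
    float r = sum_reverse(v, N);
    /* The two sums may differ in the last bits, so == is the wrong test. */
    printf("forward %.7f reverse %.7f diff %g -> %s\n", f, r, (double)(f - r),
           nearly_equal(f, r, 1e-4f) ? "OK" : "MISMATCH");
}
```

With distinct addends, a parallel atomic sum should be compared against the single-threaded one with a tolerance like this rather than exact equality.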
Btw, I was being very lazy with some of these kernels, e.g. encoder_backward; it's very possible we don't need atomics there at all. Personally I'm a bit sus about them; I had some weird correctness issues earlier.
(The repo is very new. I was just trying to get it to match PyTorch quickly, and since some of these kernels aren't the biggest time sink, they haven't been revisited.)
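If encoder_backward turns out not to need atomics, one option is to give each thread whole channels instead of whole (b, t) positions. A minimal sketch, assuming the layouts dwte (V, C), dwpe (maxT, C), dout (B, T, C) and token ids inp (B, T); the name `encoder_backward_by_channel` is hypothetical. Every write to column i of dwte and dwpe happens in the iteration that owns i, so repeated tokens no longer collide across threads:

```c
#include <assert.h>
#include <stddef.h>

/* Each parallel iteration owns channel i and is the only writer of column i
   in dwte and dwpe, so no atomics are needed. */
void encoder_backward_by_channel(float* dwte, float* dwpe, const float* dout,
                                 const int* inp, int B, int T, int C) {
    #pragma omp parallel for
    for (int i = 0; i < C; i++) {
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                size_t bt = (size_t)b * T + t;
                float d = dout[bt * C + i];
                size_t ix = (size_t)inp[bt];  /* token id at (b, t) */
                dwte[ix * C + i] += d;
                dwpe[(size_t)t * C + i] += d;
            }
        }
    }
}
```

The trade-off is memory access: each iteration reads dout with a stride of C floats, which is less cache-friendly than the row-wise loop, so whether this beats atomics is something to measure rather than assume.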
We might need atomic accumulation in the backward pass if collisions on the same bucket are frequent enough to corrupt the gradient. I say "might" because I'm not sure yet; we'd have to measure. It might be training-data specific too, which is not very safe or predictable. The accumulated difference is very significant when you force collisions to occur.
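For layernorm_backward specifically, only dweight and dbias are shared across (b, t); each dinp row is written by a single (b, t), which matches the dinp check passing even without atomics in the run below. A hedged sketch that keeps per-thread copies of dweight and dbias and merges them once per thread; the name `layernorm_backward_priv` is hypothetical, and the math follows llm.c's serial layernorm_backward. The merge changes the summation order, so small rounding differences against the serial reference are still expected.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical variant: per-thread partial sums for dweight/dbias, merged once per thread. */
void layernorm_backward_priv(float* dinp, float* dweight, float* dbias,
                             const float* dout, const float* inp, const float* weight,
                             const float* mean, const float* rstd,
                             int B, int T, int C) {
    #pragma omp parallel
    {
        /* Thread-private accumulators (error handling kept minimal). */
        float* dw = calloc((size_t)C, sizeof(float));
        float* db = calloc((size_t)C, sizeof(float));
        if (!dw || !db) abort();

        #pragma omp for
        for (int bt = 0; bt < B * T; bt++) {
            const float* dout_bt = dout + (size_t)bt * C;
            const float* inp_bt = inp + (size_t)bt * C;
            float* dinp_bt = dinp + (size_t)bt * C;
            float mean_bt = mean[bt];
            float rstd_bt = rstd[bt];

            /* Two reductions over the channel dimension. */
            float dnorm_mean = 0.0f, dnorm_norm_mean = 0.0f;
            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean /= C;
            dnorm_norm_mean /= C;

            for (int i = 0; i < C; i++) {
                float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
                float dnorm_i = weight[i] * dout_bt[i];
                db[i] += dout_bt[i];             /* private: no race */
                dw[i] += norm_bti * dout_bt[i];  /* private: no race */
                /* Row bt is owned by this iteration, so a plain += is safe. */
                dinp_bt[i] += (dnorm_i - dnorm_mean - norm_bti * dnorm_norm_mean) * rstd_bt;
            }
        }

        /* Merge this thread's partial sums, one thread at a time. */
        #pragma omp critical
        for (int i = 0; i < C; i++) {
            dweight[i] += dw[i];
            dbias[i] += db[i];
        }
        free(dw);
        free(db);
    }
}
```

With C = 768 the private buffers are small; whether this is faster than the atomic version on a given machine is still an open question to measure.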
Here is a simple test run to compare.
layernorm_backward_cpu is the CPU reference; layernorm_backward_cpu_omp is parallelized over B and T.
// layernorm_backward_cpu
struct timespec start, end;
clock_gettime(CLOCK_MONOTONIC, &start);
for (int i = 0; i < NUM_ITERS; i++) {
    layernorm_backward_cpu(
        dinp, dweight, dbias,            // output
        dout, inp, weight, mean, rstd,   // input
        B, T, C);
}
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
printf("layernorm_backward_cpu (took %f ms)\n", time_elapsed_s * 1000);

// layernorm_backward_cpu_omp
float* d_dinp = make_zeros_float(B * T * C);
float* d_dweight = make_zeros_float(C);
float* d_dbias = make_zeros_float(C);
clock_gettime(CLOCK_MONOTONIC, &start);
for (int i = 0; i < NUM_ITERS; i++) {
    layernorm_backward_cpu_omp(
        d_dinp, d_dweight, d_dbias,      // output
        dout, inp, weight, mean, rstd,   // input
        B, T, C);
}
clock_gettime(CLOCK_MONOTONIC, &end);
time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
printf("layernorm_backward_cpu_omp (took %f ms)\n", time_elapsed_s * 1000);
Without atomicAdd:
layernorm_backward_cpu (took 2312.000000 ms)
layernorm_backward_cpu_omp (took 797.000000 ms)
Checking correctness...
dinp:
-63.765575 -63.765575
-5.533450 -5.533450
64.612320 64.612320
6.844718 6.844718
-61.464909 -61.464909
OK
dweight:
-5634.315430 -5562.363281
MISMATCH of dweight at 0: REF: -5634.315430 vs NEW: -5562.363281, DIFF: 71.952148
-1950.344482 -1888.767212
MISMATCH of dweight at 1: REF: -1950.344482 vs NEW: -1888.767212, DIFF: 61.577271
108.557045 91.927727
MISMATCH of dweight at 2: REF: 108.557045 vs NEW: 91.927727, DIFF: 16.629318
-729.878113 -704.006714
MISMATCH of dweight at 3: REF: -729.878113 vs NEW: -704.006714, DIFF: 25.871399
-2341.773682 -2262.675049
MISMATCH of dweight at 4: REF: -2341.773682 vs NEW: -2262.675049, DIFF: 79.098633
MISMATCH of dweight at 5: REF: 76.753174 vs NEW: 68.295410, DIFF: 8.457764
MISMATCH of dweight at 6: REF: 330.326080 vs NEW: 383.033966, DIFF: 52.707886
MISMATCH of dweight at 7: REF: -2848.663330 vs NEW: -2732.886230, DIFF: 115.777100
MISMATCH of dweight at 8: REF: 2761.050293 vs NEW: 2718.891602, DIFF: 42.158691
MISMATCH of dweight at 9: REF: 584.252319 vs NEW: 510.019684, DIFF: 74.232635
dbias:
2925.032715 2855.006104
MISMATCH of dbias at 0: REF: 2925.032715 vs NEW: 2855.006104, DIFF: 70.026611
4758.886719 4685.164551
MISMATCH of dbias at 1: REF: 4758.886719 vs NEW: 4685.164551, DIFF: 73.722168
8586.708008 8341.341797
MISMATCH of dbias at 2: REF: 8586.708008 vs NEW: 8341.341797, DIFF: 245.366211
9051.006836 8799.977539
MISMATCH of dbias at 3: REF: 9051.006836 vs NEW: 8799.977539, DIFF: 251.029297
-6168.655762 -6065.003418
MISMATCH of dbias at 4: REF: -6168.655762 vs NEW: -6065.003418, DIFF: 103.652344
MISMATCH of dbias at 5: REF: 2633.766357 vs NEW: 2505.043945, DIFF: 128.722412
MISMATCH of dbias at 6: REF: 13067.747070 vs NEW: 12786.553711, DIFF: 281.193359
MISMATCH of dbias at 7: REF: 9480.796875 vs NEW: 9303.245117, DIFF: 177.551758
MISMATCH of dbias at 8: REF: -6072.158691 vs NEW: -5987.854004, DIFF: 84.304688
MISMATCH of dbias at 9: REF: -789.956421 vs NEW: -698.946106, DIFF: 91.010315
With atomicAdd:
layernorm_backward_cpu (took 2281.000000 ms)
layernorm_backward_cpu_omp (took 1438.000000 ms)
Checking correctness...
dinp:
-63.765575 -63.765575
-5.533450 -5.533450
64.612320 64.612320
6.844718 6.844718
-61.464909 -61.464909
OK
dweight:
-5634.315430 -5634.321289
MISMATCH of dweight at 0: REF: -5634.315430 vs NEW: -5634.321289, DIFF: 0.005859
-1950.344482 -1950.346680
MISMATCH of dweight at 1: REF: -1950.344482 vs NEW: -1950.346680, DIFF: 0.002197
108.557045 108.558006
-729.878113 -729.877075
MISMATCH of dweight at 3: REF: -729.878113 vs NEW: -729.877075, DIFF: 0.001038
-2341.773682 -2341.782959
MISMATCH of dweight at 4: REF: -2341.773682 vs NEW: -2341.782959, DIFF: 0.009277
MISMATCH of dweight at 6: REF: 330.326080 vs NEW: 330.329285, DIFF: 0.003204
MISMATCH of dweight at 7: REF: -2848.663330 vs NEW: -2848.661865, DIFF: 0.001465
MISMATCH of dweight at 8: REF: 2761.050293 vs NEW: 2761.036621, DIFF: 0.013672
MISMATCH of dweight at 12: REF: 4411.086426 vs NEW: 4411.091797, DIFF: 0.005371
MISMATCH of dweight at 15: REF: 2047.633301 vs NEW: 2047.635376, DIFF: 0.002075
MISMATCH of dweight at 16: REF: 3462.655273 vs NEW: 3462.651611, DIFF: 0.003662
dbias:
2925.032715 2925.036621
MISMATCH of dbias at 0: REF: 2925.032715 vs NEW: 2925.036621, DIFF: 0.003906
4758.886719 4758.886230
8586.708008 8586.682617
MISMATCH of dbias at 2: REF: 8586.708008 vs NEW: 8586.682617, DIFF: 0.025391
9051.006836 9051.012695
MISMATCH of dbias at 3: REF: 9051.006836 vs NEW: 9051.012695, DIFF: 0.005859
-6168.655762 -6168.645508
MISMATCH of dbias at 4: REF: -6168.655762 vs NEW: -6168.645508, DIFF: 0.010254
MISMATCH of dbias at 5: REF: 2633.766357 vs NEW: 2633.772217, DIFF: 0.005859
MISMATCH of dbias at 6: REF: 13067.747070 vs NEW: 13067.738281, DIFF: 0.008789
MISMATCH of dbias at 7: REF: 9480.796875 vs NEW: 9480.806641, DIFF: 0.009766
MISMATCH of dbias at 8: REF: -6072.158691 vs NEW: -6072.162598, DIFF: 0.003906
MISMATCH of dbias at 9: REF: -789.956421 vs NEW: -789.950256, DIFF: 0.006165
MISMATCH of dbias at 10: REF: -2587.304688 vs NEW: -2587.274902, DIFF: 0.029785
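Reading the two runs: the atomic version differs from the reference by roughly 1e-6 to 1e-5 relative (e.g. 0.005859 on 5634.3), which is consistent with summing floats in a different order, while the run without atomics is off by more than 1% (e.g. 71.95 on 5634.3). An exact-match check flags both. A sketch of a comparison with combined absolute and relative tolerance; the name `count_mismatches` and the tolerances used below (atol 1e-3, rtol 1e-4) are assumptions chosen to separate these two cases, not values from the existing test harness.

```c
#include <assert.h>
#include <stdio.h>

static float abs_f(float x) { return x < 0.0f ? -x : x; }

/* Counts elements where |ref - got| > atol + rtol * |ref| and prints the
   first few, in the same style as the existing MISMATCH lines. */
static int count_mismatches(const char* name, const float* ref, const float* got,
                            int n, float atol, float rtol) {
    int bad = 0;
    for (int i = 0; i < n; i++) {
        float diff = abs_f(ref[i] - got[i]);
        float tol = atol + rtol * abs_f(ref[i]);
        if (!(diff <= tol)) {  /* negated form also flags NaN */
            if (bad < 10) {
                printf("MISMATCH of %s at %d: REF: %f vs NEW: %f, DIFF: %f\n",
                       name, i, ref[i], got[i], diff);
            }
            bad++;
        }
    }
    return bad;
}
```

Applied to a few of the logged values, the atomic results pass and the racy ones are still reported.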
I'm trying to parallelize layernorm_backward and encoder_backward, and I need some help making the CPU atomicAdd portable. I know there is already one for CUDA.
D:\Repos\llm.c>test_gpt2.exe
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
[State]
batch_size: 4
seq_len: 64
num_activations: 73323776
-43.431702 -43.431740
-39.836426 -39.836460
-43.065937 -43.066002
OK (LOGITS)
LOSS OK: 5.269984 5.270009
dwte TENSOR OK
dwpe TENSOR OK
dln1w TENSOR OK
dln1b TENSOR OK
dqkvw TENSOR OK
dqkvb TENSOR OK
dattprojw TENSOR OK
dattprojb TENSOR OK
dln2w TENSOR OK
dln2b TENSOR OK
dfcw TENSOR OK
dfcb TENSOR OK
dfcprojw TENSOR OK
dfcprojb TENSOR OK
dlnfw TENSOR OK
dlnfb TENSOR OK
step 0: loss 5.269984 (took 9829.000000 ms)
step 1: loss 4.059653 (took 10047.000000 ms)
step 2: loss 3.374920 (took 10157.000000 ms)
step 3: loss 2.800694 (took 10109.000000 ms)
step 4: loss 2.315437 (took 10094.000000 ms)
step 5: loss 1.849182 (took 10125.000000 ms)
step 6: loss 1.394891 (took 13360.000000 ms)
step 7: loss 0.999076 (took 15719.000000 ms)
step 8: loss 0.624470 (took 15250.000000 ms)
step 9: loss 0.376849 (took 15469.000000 ms)
loss ok at step 0: 5.269984 5.270007
loss ok at step 1: 4.059653 4.059707
loss ok at step 2: 3.374920 3.375123
loss ok at step 3: 2.800694 2.800783
loss ok at step 4: 2.315437 2.315382
loss ok at step 5: 1.849182 1.849029
loss ok at step 6: 1.394891 1.394656
loss ok at step 7: 0.999076 0.999147
loss ok at step 8: 0.624470 0.624080
loss ok at step 9: 0.376849 0.376511
overall okay: 1
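Another portable option worth considering, assuming OpenMP is already required for the parallel loops: OpenMP's own `atomic` construct, which supports scalar `x += expr` updates (float included) without relying on C11 <stdatomic.h> or compiler intrinsics. A minimal sketch; `atomic_add_omp` and `parallel_sum` are hypothetical names.

```c
#include <assert.h>

/* Atomically performs *dest += val via OpenMP's atomic construct.
   Without OpenMP enabled the pragma is ignored and this is a plain add. */
static void atomic_add_omp(float* dest, float val) {
    #pragma omp atomic
    *dest += val;
}

/* Usage: many iterations accumulating into one shared float. */
static float parallel_sum(const float* vals, int n) {
    float sum = 0.0f;
    #pragma omp parallel for
    for (int i = 0; i < n; i++) {
        atomic_add_omp(&sum, vals[i]);
    }
    return sum;
}
```

For a single scalar like this, `#pragma omp parallel for reduction(+:sum)` is the usual choice since it avoids synchronizing on every iteration; the atomic form matters when the target index is data-dependent, as with token rows in encoder_backward.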