karpathy / llm.c

LLM training in simple, raw C/CUDA
MIT License

Hardcoded block_size in kernels #261

Open azret opened 5 months ago

azret commented 5 months ago

Can we make the block_size in the kernels adaptive or parameterized? For example, 1024 is pretty big for my GPU with 12GB of memory.

I have to run with block_size = 32:

void fused_classifier3(float* logits, float* losses,
                      const float* dlosses, const int* targets,
                      int B, int T, int V, int P) {
    const int block_size = 1024;
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel3<<<grid_size, block_size>>>(logits, losses, NULL, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 551.78                 Driver Version: 551.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4080 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   42C    P3             18W /   73W |       0MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
karpathy commented 5 months ago

Yeah I think we should do that. I have a TODO to look more into the NVIDIA Occupancy Calculator, which I think might be helpful here.

ademeure commented 5 months ago

Not sure I understand why a larger block_size doesn’t work for you. It shouldn’t have anything to do with VRAM, and 12GB is plenty anyway!

What’s the exact error you are getting? As a quick experiment, can you try:

fused_classifier_kernel3<<<grid_size, block_size,8192>>>(logits, losses, NULL, dlosses, targets, B, T, V, P);

i.e. force the launch to request 8192 bytes of dynamic shared memory (the kernel does use shared memory, but we don’t explicitly declare it at launch, so there’s a small chance that’s related).

azret commented 5 months ago

Same error with 8192:

[CUDA ERROR] at file D:\SRC\llm.c\train_gpt2_fp32.cu:1079: too many resources requested for launch

void fused_classifier3(float* logits, float* losses,
                      const float* dlosses, const int* targets,
                      int B, int T, int V, int P) {
    const int block_size = 1024;
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel3<<<grid_size, block_size, 8192>>>(logits, losses, NULL, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}
| Name | Value | Type |
| -- | -- | -- |
| N | 256 | const int |
| B | 4 | int |
| T | 64 | int |

Given the maxGridSize, maxThreadsDim, etc. from the device properties below, what limits could I be running into?

[screenshot: devProps]

ademeure commented 5 months ago

Typically this happens when the compiled kernel uses more registers per thread than the maximum that still lets that many warps run on one SM…

The register file is 256 KiB per SM, i.e. 65,536 32-bit registers (4 bytes each), so with 1024 threads that’s a maximum of 64 registers per thread. I’m not sure why/how the compiler ends up using that many registers for you, maybe too aggressive unrolling…

Can you try adding the following to the kernel declaration:

__launch_bounds__(MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP)

so it becomes:

template <typename Type> __global__ void __launch_bounds__(1024, 1) fused_classifier_kernel3(Type* logits, Type* losses, Type* probs, const Type* dlosses, const int* targets, int B, int T, int V, int P) {
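A small sketch of the arithmetic behind the 64-register figure, assuming 65,536 32-bit registers per SM (256 KiB divided by 4 bytes) and the 255-registers-per-thread hardware cap of recent GPUs. `__launch_bounds__(1024, 1)` roughly asks nvcc to keep the kernel within this budget, spilling to local memory if it has to. The helper name `max_regs_per_thread` is hypothetical, and the model ignores register-allocation granularity, so the real budget can be slightly lower.

```c
/* Per-thread register budget so that `min_blocks` blocks of `block_size`
 * threads fit in one SM's register file at the same time.
 * Simplified model: ignores register-allocation granularity. */
static int max_regs_per_thread(int regs_per_sm, int block_size, int min_blocks)
{
    const int hw_cap = 255; /* max registers per thread on recent GPUs */
    long threads = (long)block_size * (long)min_blocks;
    if (threads <= 0)
        return hw_cap; /* degenerate input: only the hardware cap applies */
    long budget = (long)regs_per_sm / threads;
    return budget < hw_cap ? (int)budget : hw_cap;
}
```

With regs_per_sm = 65536, `max_regs_per_thread(65536, 1024, 1)` gives 64 and `max_regs_per_thread(65536, 512, 1)` gives 128. To see how many registers the compiler actually allocated, compile with `-Xptxas -v` or read `numRegs` from `cudaFuncGetAttributes`.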

azret commented 5 months ago

The __launch_bounds__(1024, 1) got past the "too many resources requested for launch" error. However, all the losses go kaboom with block_size = 1024, while everything passes with block_size = 32:

// With block_size=1024
allocated 474 MiB for AdamW optimizer state v
step 0: loss -inf (took 40.586900 ms)
step 1: loss -inf (took 61.810200 ms)
step 2: loss -nan(ind) (took 96.569500 ms)
step 3: loss -nan(ind) (took 99.150600 ms)
step 4: loss -nan(ind) (took 97.418500 ms)
step 5: loss -nan(ind) (took 99.267500 ms)
step 6: loss -nan(ind) (took 94.665500 ms)
step 7: loss -inf (took 92.860700 ms)
step 8: loss -inf (took 94.221800 ms)
step 9: loss -inf (took 98.063400 ms)
LOSS - NOT OK w/ tol = 2.000000e-04, faults = 5, maxdiff=inf
LOSS[0] -inf8 5.2700078, diff=inf
LOSS[1] -inf8 4.0597078, diff=inf
LOSS[7] -inf8 0.9991478, diff=inf
LOSS[8] -inf8 0.6240808, diff=inf
LOSS[9] -inf8 0.3765118, diff=inf
overall okay: NOT OK

I also noticed we don't use CEIL_DIV for the grid size in fused_classifier3. Is it designed that way?
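On the CEIL_DIV question: `grid_size = N` suggests one block per (b, t) row, with the block's threads striding over the V logits, so every element is covered for any block_size and no rounding up is needed. CEIL_DIV matters in the other common scheme, one thread per element. The host-side C simulation below checks both, under that assumption about the kernel's indexing; the function names are hypothetical and the loops only mimic the indexing, not the kernel's math.

```c
#include <stdlib.h>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

/* One block per row (grid = N); each thread strides over V by block_size.
 * Returns 1 if every (row, i) element is visited exactly once. */
static int covers_block_per_row(int N, int V, int block_size)
{
    size_t total = (size_t)N * (size_t)V;
    int *hits = calloc(total ? total : 1, sizeof *hits);
    if (!hits) return 0;
    for (int row = 0; row < N; row++)                 /* blockIdx.x  */
        for (int tid = 0; tid < block_size; tid++)    /* threadIdx.x */
            for (int i = tid; i < V; i += block_size) /* block-stride loop */
                hits[(size_t)row * V + i]++;
    int ok = 1;
    for (size_t k = 0; k < total; k++)
        if (hits[k] != 1) ok = 0;
    free(hits);
    return ok;
}

/* One thread per element: the grid must be rounded up with CEIL_DIV,
 * and threads past the end are masked off with a bounds check. */
static int covers_thread_per_element(int total, int block_size, int use_ceil)
{
    int grid = use_ceil ? CEIL_DIV(total, block_size) : total / block_size;
    int *hits = calloc(total > 0 ? (size_t)total : 1, sizeof *hits);
    if (!hits) return 0;
    for (int b = 0; b < grid; b++)
        for (int t = 0; t < block_size; t++) {
            int idx = b * block_size + t;
            if (idx < total) hits[idx]++; /* bounds check */
        }
    int ok = 1;
    for (int k = 0; k < total; k++)
        if (hits[k] != 1) ok = 0;
    free(hits);
    return ok;
}
```

In the block-per-row scheme, a block of 1024 threads on a row of V = 100 is still correct: threads whose index is past V simply skip the loop. In the thread-per-element scheme, plain integer division drops the tail (1000 elements with 256-thread blocks leaves the last 232 unvisited).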

azret commented 5 months ago

I'm running fp32, btw, with B = 4 and T = 64, so a pretty small batch.

ademeure commented 5 months ago

Is that with the previous change (8192 bytes of shared memory) still in place? If so, I’m at a loss :(

I have an RTX 4090, so the same generation but slightly older drivers. I might try upgrading next week to see if I hit the same issue, but…

I just realised your output says WDDM. Is this on Windows or WSL? I don’t think that typically matters, but some things, like timeouts and shared memory allocation restrictions, might be different…

azret commented 5 months ago

It's Windows.

rosslwheeler commented 5 months ago

I'm seeing the same error on Windows too. I do NOT see it on WSL2.

Same spot as above, in:

fused_classifier_kernel3<<<grid_size, block_size>>>(logits, losses, (Type*)NULL, dlosses, targets, B, T, V, P);

    train_gpt2_cu.exe!cudaCheck(cudaError error, const char * file, int line) Line 132  C++
>   train_gpt2_cu.exe!fused_classifier3<__nv_bfloat16>(__nv_bfloat16 * logits, __nv_bfloat16 * losses, const __nv_bfloat16 * dlosses, const int * targets, int B, int T, int V, int P) Line 1247    C++
    train_gpt2_cu.exe!gpt2_forward(GPT2 * model, int * inputs, int * targets, unsigned __int64 B, unsigned __int64 T) Line 1682 C++
    train_gpt2_cu.exe!main(int argc, char * * argv) Line 2278   C++
rosslwheeler commented 5 months ago

Also had to make these changes to get the file to compile with nvcc:

[screenshot of the changes; not preserved]
rosslwheeler commented 5 months ago

@azret, fixed it! I was missing the PFLAGS in the build. Can you check whether you are building with those new flags?

PeterZhizhin commented 5 months ago

For block sizes, maybe we should look into cudaOccupancyMaxPotentialBlockSize.

https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__HIGHLEVEL.html#group__CUDART__HIGHLEVEL_1gee5334618ed4bb0871e4559a77643fc1

It runs the occupancy calculation for a given kernel function.
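For reference, the runtime call is `cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, kernel, dynamicSMemBytes, blockSizeLimit)`; for fused_classifier3 only `blockSize` would matter, since the grid is fixed at N. To show the idea, here is a deliberately simplified host-side model in C that considers only three per-SM limits (registers, resident threads, resident blocks) and picks the block size with the most resident threads, preferring larger blocks on ties. It is not NVIDIA's algorithm: shared memory and allocation granularity are ignored, and the `SmLimits` struct and function names are hypothetical. Real limits would come from `cudaDeviceProp` (`regsPerMultiprocessor`, `maxThreadsPerMultiProcessor`, `maxBlocksPerMultiProcessor`, `maxThreadsPerBlock`).

```c
/* Simplified per-SM limits; fill from cudaDeviceProp in real code. */
typedef struct {
    int regs_per_sm;           /* cf. regsPerMultiprocessor */
    int max_threads_per_sm;    /* cf. maxThreadsPerMultiProcessor */
    int max_blocks_per_sm;     /* cf. maxBlocksPerMultiProcessor */
    int max_threads_per_block; /* cf. maxThreadsPerBlock */
} SmLimits;

/* How many blocks of block_size threads can be resident on one SM. */
static int resident_blocks(const SmLimits *l, int block_size, int regs_per_thread)
{
    int by_threads = l->max_threads_per_sm / block_size;
    int by_regs = l->regs_per_sm / (regs_per_thread * block_size);
    int n = by_threads < by_regs ? by_threads : by_regs;
    return n < l->max_blocks_per_sm ? n : l->max_blocks_per_sm;
}

/* Try every warp-multiple block size; keep the one with the most resident
 * threads per SM, preferring the larger block on ties. Returns 0 if the
 * kernel cannot fit even a single 32-thread block. */
static int pick_block_size(const SmLimits *l, int regs_per_thread)
{
    if (regs_per_thread < 1) regs_per_thread = 1;
    int best = 0, best_threads = 0;
    for (int bs = 32; bs <= l->max_threads_per_block; bs += 32) {
        int threads = resident_blocks(l, bs, regs_per_thread) * bs;
        if (threads > 0 && threads >= best_threads) {
            best_threads = threads;
            best = bs;
        }
    }
    return best;
}
```

With illustrative limits `{65536, 1536, 24, 1024}` (check your own `cudaDeviceProp`), a kernel needing 64 registers per thread gets 1024, one needing 72 gets 896, and one needing 128 gets 512.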

rosslwheeler commented 5 months ago

@PeterZhizhin, I bumped into the -G debug issue too: "[CUDA ERROR] at file train_gpt2.cu:1410: too many resources requested for launch". CUDA device debugging (-G) was enabled by default in the debug configuration for CUDA on Windows.

PeterZhizhin commented 5 months ago

@rosslwheeler Yes, I had exactly the same issue. It seems the kernel uses too many registers? Reducing the block size to 512 in debug mode makes the code work.
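A minimal sketch of one way a launcher could adapt instead of hardcoding 1024, assuming the kernel only needs its block size to be a whole number of warps (not verified for fused_classifier_kernel3): ask the runtime for the kernel's own launch limit and clamp to it. `cudaFuncGetAttributes` reports `maxThreadsPerBlock`, the largest block the compiled kernel can launch with given its resource usage, so a -G debug build would report a smaller value on its own. The helper `clamp_block_size` is hypothetical; the CUDA query appears only in a comment because the function itself is plain C.

```c
/* Hypothetical helper: clamp the requested block size to the kernel's own
 * limit and round down to a whole number of 32-thread warps.
 * In CUDA host code, kernel_max could come from:
 *     cudaFuncAttributes attr;
 *     cudaCheck(cudaFuncGetAttributes(&attr, fused_classifier_kernel3));
 *     int kernel_max = attr.maxThreadsPerBlock;
 * (use fused_classifier_kernel3<Type> for the templated version).
 * Returns 0 if no warp-multiple block size fits. */
static int clamp_block_size(int requested, int kernel_max)
{
    int bs = requested < kernel_max ? requested : kernel_max;
    bs -= bs % 32; /* round down to a warp multiple */
    return bs > 0 ? bs : 0;
}
```

Something like `block_size = clamp_block_size(1024, kernel_max)` would keep 1024 when the kernel fits and fall back to a smaller size (for example 512) when the build needs more registers per thread.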