ggerganov / ggml

Tensor library for machine learning
MIT License
11.26k stars 1.05k forks

Error in `ggml_get_rows` for large tensors with CUDA backend. #877

Open balisujohn opened 4 months ago

balisujohn commented 4 months ago

I found and isolated an error in `ggml_get_rows` with the CUDA backend: for tensors where the first dimension is greater than 65535, the program fails with the following output:

load_model: using CUDA backend
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce GTX 1070 Ti, compute capability 6.1, VMM: yes
main: compute buffer size: 4096.0000 KB
ggml_cuda_compute_forward: GET_ROWS failed
CUDA error: invalid configuration argument
  current device: 0, in function ggml_cuda_compute_forward at /home/john/errorisol/ggml/src/ggml-cuda.cu:2285
  err
GGML_ASSERT: /home/john/errorisol/ggml/src/ggml-cuda.cu:100: !"CUDA error"
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted (core dumped)

I expect it's related to this: https://stackoverflow.com/questions/12078080/max-number-of-threads-which-can-be-initiated-in-a-single-cuda-kernel.

This doesn't seem to happen for similarly large tensors with the CPU backend.

I provided a reference repo, a fresh fork of ggml in which the error is minimally and reproducibly demonstrated in simple-backend.cpp: https://github.com/balisujohn/ggml-get-rows-error

At a minimum, I think there should be an explicit guardrail (e.g., an error stating that the tensor exceeds the allowed dimensions for this operation on this backend). It would be nicer still if the operation could be extended to handle tensors larger than 65535; otherwise tortoise.cpp will need to decompose these calls into a lot of slicing and concatenating, which will probably be less efficient. I'm not against trying to make this change myself, but I want to hear others' thoughts before spending time on it.
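The guardrail idea could look something like this. This is only a hedged sketch, not the actual ggml-cuda code; the constant and the helper name `get_rows_fits_grid` are hypothetical, and it assumes the failing row count is the one that maps to `gridDim.y` in the kernel launch:

```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* NVIDIA documents 65535 as the max gridDim.y / gridDim.z. */
#define CUDA_MAX_GRID_Y 65535

/* Hypothetical pre-launch check: does a get_rows launch with this many
 * rows fit in a single grid? */
static bool get_rows_fits_grid(int64_t nrows) {
    return nrows <= CUDA_MAX_GRID_Y;
}

/* A backend could call this before dispatch and fail with a clear
 * message instead of the opaque "invalid configuration argument". */
static void check_get_rows(int64_t nrows) {
    if (!get_rows_fits_grid(nrows)) {
        fprintf(stderr,
            "get_rows: %lld rows exceeds the CUDA grid limit (%d)\n",
            (long long) nrows, CUDA_MAX_GRID_Y);
    }
}
```

An explicit check like this would at least turn the crash into an actionable error message while a real fix is worked out.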

slaren commented 4 months ago

The max number of blocks in the y or z dimensions is 65535 (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities). In llama.cpp this dimension typically represents the batch size, so it is always much smaller than that. It would be good to handle this case properly, either by changing the kernel or by launching multiple kernels.
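The second option (launching multiple kernels) amounts to host-side chunking: cover the rows in slices of at most 65535 and pass a row offset into each launch. A minimal sketch of that loop, with simulated launches so the splitting logic can be checked on the CPU (the function name and the counting are illustrative, not ggml's actual code):

```c
#include <stdint.h>

/* CUDA caps gridDim.y (and gridDim.z) at 65535. */
#define MAX_GRID_Y 65535

/* Hypothetical host-side splitter: covers `nrows` rows with one
 * simulated kernel launch per chunk of at most MAX_GRID_Y rows.
 * In a real backend each iteration would be a <<<grid, block>>>
 * launch that receives `off` so the kernel indexes the right rows;
 * here we just count launches and the rows covered. */
static int64_t launch_get_rows_chunked(int64_t nrows, int * n_launches) {
    int64_t covered = 0;
    *n_launches = 0;
    for (int64_t off = 0; off < nrows; off += MAX_GRID_Y) {
        int64_t chunk = (nrows - off < MAX_GRID_Y) ? (nrows - off)
                                                   : MAX_GRID_Y;
        covered += chunk; /* kernel would process rows [off, off+chunk) */
        (*n_launches)++;
    }
    return covered;
}
```

For the 70000-row case from the report this issues two launches (65535 rows, then 4465), so every row is still processed exactly once.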