ggerganov / llama.cpp

LLM inference in C/C++
MIT License

mpi : attempt inference of 65B LLaMA on a cluster of Raspberry Pis #2164

Open ggerganov opened 1 year ago

ggerganov commented 1 year ago

Now that distributed inference is supported thanks to the work of @evanmiller in #2099, it would be fun to try to use it for something cool. One such idea is to connect a bunch of Raspberry Pis on a local network and run the inference using MPI:

# sample cluster of 8 devices (replace with actual IP addresses of the devices)
$ cat ./hostfile
192.168.0.1:1
192.168.0.2:1
192.168.0.3:1
192.168.0.4:1
192.168.0.5:1
192.168.0.6:1
192.168.0.7:1
192.168.0.8:1

# build with MPI support
$ make CC=mpicc CXX=mpicxx LLAMA_MPI=1 -j

# run distributed inference over 8 nodes
$ mpirun -hostfile ./hostfile -n 8 ./main -m /mnt/models/65B/ggml-model-q4_0.bin -p "I believe the meaning of life is" -n 64

Here we assume that the 65B model data is located on a network share in /mnt and that mmap works over a network share. Not sure if that is the case - if not, then it would be more difficult to perform this experiment.

Looking for people with access to the necessary hardware to perform this experiment.

EdwardDali commented 1 year ago

I wonder if this capability could be integrated into https://github.com/pry0cc/axiom. Axiom lets you spin up multiple cloud instances in minutes and run distributed scripts/loads on them. It also lets you terminate the instances really quickly, so you can use them for just a couple of minutes if needed; this also keeps the bill low.

USBhost commented 1 year ago

I could try simulating it. 8 VMs with 8GB of ram.

hchenphd commented 1 year ago

Cool idea! I have some more powerful embedded devices (RISC-V CPUs) that can be integrated into a cluster. I'm looking forward to this experiment and am willing to deploy it on a RISC-V cluster.

theycallmeloki commented 1 year ago

Ordered the parts today; they should be here the same time tomorrow (6 x Raspberry Pi 4, 8GB variants). In the meantime, I'm setting up a local MPI cluster on VMs to test the inference. Pick up your Pi(s) now, a shortage is coming!

Green-Sky commented 1 year ago

and that mmap works over a network share.

I had issues with this in the past (a couple of weeks ago). It works at first, and I can read the full file without problems, but then it suddenly stops and kills the program.

USBhost commented 1 year ago

and that mmap works over a network share.

I had issues with this in the past (a couple of weeks ago). It works at first, and I can read the full file without problems, but then it suddenly stops and kills the program.

NFS or CIFS?

Green-Sky commented 1 year ago

NFS or CIFS?

CIFS, between two Ubuntu machines.

ggerganov commented 1 year ago

@theycallmeloki Hope I didn't set the expectations too high - even if this runs, the performance is expected to be really terrible. Likely a few (tens of) seconds per token for 65B. It's mostly a fun experiment - I don't think it would have any practical use.

USBhost commented 1 year ago

@theycallmeloki Hope I didn't set the expectations too high - even if this runs, the performance is expected to be really terrible. Likely a few (tens of) seconds per token for 65B. It's mostly a fun experiment - I don't think it would have any practical use.

The moment you said Raspberry Pi, I knew we were on the meme train.

ggerganov commented 1 year ago

Well, the same experiment can be done with a bunch of gaming GPUs (e.g. 8x 6GB or 6x 8GB), and it would make more sense in terms of performance. But running a 65B model on an RPi cluster sounds like a more historic achievement 😄

theycallmeloki commented 1 year ago

@ggerganov Nope, not at all. I was going through the discussions and realized there is some room to add value around the inference pipelines. I can also imagine that varying the size of the virtual nodes in the Pi cluster and tweaking the partitioning of the model could lead to better tokens/second. This setup costs roughly an order of magnitude less than any other off-the-shelf, self-hostable setup that can run a 65B model (I'm looking at the M2 Ultra Studios), so even if it's slow, I think it's likely worth it for the novelty alone 😄

JWNoctis commented 1 year ago

@theycallmeloki Hope I didn't set the expectations too high - even if this runs, the performance is expected to be really terrible. Likely a few (tens of) seconds per token for 65B. It's mostly a fun experiment - I don't think it would have any practical use.

The moment you said Raspberry Pi, I knew we were on the meme train.

FWIW, a certain popular but very early report of ~10s/token was a bit exaggerated. The actual number right now is closer to 1.6s/token on a Pi 4B 8GB for a q6_K-quantized 7B model, which barely fits alongside the OS and GUI. Even the Pi is memory-bandwidth bound in this use case, and -t 4 is actually a bit slower than -t 3. The board has somewhere around 4GB/s of memory bandwidth. Running headless might also speed things up a bit given the architecture.

@ggerganov Nope, not at all. I was going through the discussions and realized there is some room to add value around the inference pipelines. I can also imagine that varying the size of the virtual nodes in the Pi cluster and tweaking the partitioning of the model could lead to better tokens/second. This setup costs roughly an order of magnitude less than any other off-the-shelf, self-hostable setup that can run a 65B model (I'm looking at the M2 Ultra Studios), so even if it's slow, I think it's likely worth it for the novelty alone 😄

The actual cheapest option right now might be a (used) Ryzen 5800X/5700G with a matching motherboard, peripherals, and 64GB of the fastest DDR4 RAM you can find in matched 32GB modules. The latter has become quite cheap since the DDR5 rollout and can be had for the price of three to four Pi 4 8GB boards.

But no, that is not nearly as interesting!

theycallmeloki commented 1 year ago

I believe I might have gotten the local environment up and running on the Pis (I first ran a hello-world example to confirm that MPI itself was working smoothly).

I copied the 65B model to each Pi using scp (they have 256GB microSD cards, so I'm hoping I do not need to mmap from a network drive), compiled the binary following the MPI instructions, and copied ./main to all the Pis. Running with:

mpirun -hostfile ./hostfile -n 6 ./main -m /home/laneone/ggml-LLaMa-65B-q4_0.bin -p "I believe the meaning of life is" -n 128

I'm unable to determine how to split this model into smaller chunks so that it fits on the individual Pis. Right now I reckon it's trying to load the entire model on each Pi, which is probably why it is failing. Logs below:

main: build = 826 (975221e)
main: seed  = 1689230379
main: build = 826 (975221e)
main: seed  = 1689230379
main: build = 826 (975221e)
main: seed  = 1689230379
main: build = 826 (975221e)
main: seed  = 1689230379
main: build = 826 (975221e)
main: seed  = 1689230379
main: build = 826 (975221e)
main: seed  = 1689230379
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
error loading model: mmap failed: Cannot allocate memory
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
error loading model: mmap failed: Cannot allocate memory
error loading model: mmap failed: Cannot allocate memory
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[7058,1],4]
  Exit code:    1
--------------------------------------------------------------------------

ggerganov commented 1 year ago

Hm, my expectation is that mmap wouldn't require loading the entire model. Does it make a difference if you try this patch:

diff --git a/llama-util.h b/llama-util.h
index 43b6f05..1c0502f 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -177,10 +177,7 @@ struct llama_mmap {
         int fd = fileno(file->fp);
         int flags = MAP_PRIVATE;
         // prefetch/readahead impairs performance on NUMA systems
-        if (numa) { prefetch = 0; }
-#ifdef __linux__
-        if (prefetch) { flags |= MAP_POPULATE; }
-#endif
+        prefetch = 0;
         addr = mmap(NULL, file->size, PROT_READ | PROT_WRITE, flags, fd, 0);
         if (addr == MAP_FAILED) {
             throw std::runtime_error(format("mmap failed: %s", strerror(errno)));

I'm just poking here; probably someone who understands better how it works can chime in. The backup plan is to split the model into parts and have each node load only its own part, but this would require some extra work to achieve.

theycallmeloki commented 1 year ago

Yes, the patch was definitely helpful. Prior to this, I was also trying out a 7B model in parallel just to ensure the whole process was working; it used to fail with an error message similar to the one above, but after this patch, the 7B model seems to be running on the Pis:

mpirun -hostfile ./hostfile -n 6 ./main -m /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin -p "I believe the meaning of life is" -n 128
main: build = 826 (975221e)
main: seed  = 1689239657
main: build = 826 (975221e)
main: seed  = 1689239657
main: build = 826 (975221e)
main: seed  = 1689239657
main: build = 826 (975221e)
main: seed  = 1689239657
main: build = 826 (975221e)
main: seed  = 1689239657
main: build = 826 (975221e)
main: seed  = 1689239657
llama.cpp: loading model from /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin
llama.cpp: loading model from /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin
llama.cpp: loading model from /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin
llama.cpp: loading model from /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin
llama.cpp: loading model from /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin
llama.cpp: loading model from /home/laneone/vicuna-7b-1.1.ggmlv3.q6_K.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: model size = 7B
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_model_load_internal: mem required  = 7064.42 MB (+ 1026.00 MB per state)
llama_model_load_internal: mem required  = 7064.42 MB (+ 1026.00 MB per state)
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_model_load_internal: mem required  = 7064.42 MB (+ 1026.00 MB per state)
llama_model_load_internal: format     = ggjt v3 (latest)
llama_new_context_with_model: kv self size  =  256.00 MB
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_model_load_internal: mem required  = 7064.42 MB (+ 1026.00 MB per state)
llama_model_load_internal: mem required  = 7064.42 MB (+ 1026.00 MB per state)
llama_new_context_with_model: kv self size  =  256.00 MB
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 4096
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 32
llama_model_load_internal: n_layer    = 32
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 18 (mostly Q6_K)
llama_model_load_internal: n_ff       = 11008
llama_model_load_internal: model size = 7B
llama_model_load_internal: ggml ctx size =    0.08 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_model_load_internal: mem required  = 7064.42 MB (+ 1026.00 MB per state)
llama_new_context_with_model: kv self size  =  256.00 MB

system_info: n_threads = 4 / 4 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 128, n_keep = 0

 I believe the meaning of life is to find out who you are and what you want to do in this world. And then use your talents, strengths

However, when I load the 65B model, the same error is thrown as before the patch was in place (I'm not sure about this, tbh):

 mpirun -hostfile ./hostfile -n 6 ./main -m /home/laneone/ggml-LLaMa-65B-q4_0.bin -p "I believe the meaning of life is milady because" -n 128
main: build = 826 (975221e)
main: seed  = 1689240041
main: build = 826 (975221e)
main: seed  = 1689240041
main: build = 826 (975221e)
main: seed  = 1689240041
main: build = 826 (975221e)
main: seed  = 1689240041
main: build = 826 (975221e)
main: seed  = 1689240041
main: build = 826 (975221e)
main: seed  = 1689240041
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
error loading model: mmap failed: Cannot allocate memory
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
error loading model: mmap failed: Cannot allocate memory
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
error loading model: mmap failed: Cannot allocate memory
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/laneone/ggml-LLaMa-65B-q4_0.bin'
main: error: unable to load model
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[6572,1],4]
  Exit code:    1
--------------------------------------------------------------------------

I would like to work on splitting the model so that each node loads just the weights specific to it. Please give me some rough pointers on how I can approach this. I understand there are ckpt-to-diffusers and diffusers-to-ckpt conversion scripts; I could probably patch one side to split the model before writing it to disk, so that running both conversions ends up producing split files that each Pi can load individually.

In parallel, do you think adding swap storage equivalent to the model size would help? I can imagine it loading the entire model onto the microSD card and then looking up only the specific portions it needs for its inference, which might help in this regard.

theycallmeloki commented 1 year ago

Update: Swap seems to have done the trick. I added a 50GB swap file to each Pi and was able to run the 65B model using:

mpirun -hostfile ./hostfile -n 6 ./main -m /home/laneone/ggml-LLaMa-65B-q4_0.bin -p "I believe the meaning of life is milady because" -n 128

system_info: n_threads = 4 / 4 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 128, n_keep = 0

llama_model_load_internal: mem required  = 38610.47 MB (+ 5120.00 MB per state)
llama_new_context_with_model: kv self size  = 1280.00 MB
 I believe the meaning of life is milady because she is

I think swap was likely not the correct approach. As I understand it, the benefit of using MPI was to load only part of the model into each Pi's RAM so that inference could be faster. Now that the entire model can be brought up on one Pi, I could technically run 6 conversations in parallel, one on each Pi, and they'd be just as slow as right now (about one token every 10 to 12 minutes). So I'm leaning towards the split-model approach being better and will try along those lines. (Also, I can't run k8s on this with swap enabled, since etcd stashes into swap and messy things happen, and I would like to be able to do that down the line.)

USBhost commented 1 year ago

@theycallmeloki try setting vm.overcommit_memory to 1 https://www.kernel.org/doc/html/v5.1/vm/overcommit-accounting.html

ggerganov commented 1 year ago

@theycallmeloki

If everything works as expected, it won't swap. Something is still not OK, and I think it is swapping due to the large KV cache. Let's try to limit it by reducing the context length by a factor of 8:

diff --git a/llama.cpp b/llama.cpp
index 2d09d6c..3ebdca4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -129,11 +129,11 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
 static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,    682ull * MB },
-        { MODEL_7B,   1026ull * MB },
-        { MODEL_13B,  1608ull * MB },
-        { MODEL_30B,  3124ull * MB },
-        { MODEL_65B,  5120ull * MB },
+        { MODEL_3B,    682ull * MB / 8 },
+        { MODEL_7B,   1026ull * MB / 8 },
+        { MODEL_13B,  1608ull * MB / 8 },
+        { MODEL_30B,  3124ull * MB / 8 },
+        { MODEL_65B,  5120ull * MB / 8 },
     };
     return k_sizes;
 }

Btw, how much RAM does each RPi have?

USBhost commented 1 year ago

Looks like it's the time to make a few virtual machines to test this out.

theycallmeloki commented 1 year ago

@ggerganov @USBhost

I have now disabled the 50GB swap file on the Pis. Each Pi has 8GB of RAM (about 300MB is used by the headless OS after boot). I added the patch and also set vm.overcommit_memory = 1. The model now loads even without the swap storage, so I believe we are now able to generate tokens with the 65B model on a cluster of RPis! 😀

mpirun -hostfile ./hostfile -n 6 ./main -m /home/laneone/ggml-LLaMa-65B-q4_0.bin -p "I believe the meaning of life is milady because" -n 128 -c 256
main: build = 826 (975221e)
main: seed  = 1689406303
main: build = 826 (975221e)
main: seed  = 1689406303
main: build = 826 (975221e)
main: seed  = 1689406303
main: build = 826 (975221e)
main: seed  = 1689406303
main: build = 826 (975221e)
main: seed  = 1689406303
main: build = 826 (975221e)
main: seed  = 1689406303
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama.cpp: loading model from /home/laneone/ggml-LLaMa-65B-q4_0.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 256
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+  640.00 MB per state)
llama_new_context_with_model: kv self size  =  640.00 MB
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 256
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+  640.00 MB per state)
llama_new_context_with_model: kv self size  =  640.00 MB
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 256
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+  640.00 MB per state)
llama_new_context_with_model: kv self size  =  640.00 MB
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 256
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 256
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: n_head     = 64
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_model_load_internal: mem required  = 38610.47 MB (+  640.00 MB per state)
llama_model_load_internal: mem required  = 38610.47 MB (+  640.00 MB per state)
llama_new_context_with_model: kv self size  =  640.00 MB

llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32000
llama_model_load_internal: n_ctx      = 256
llama_model_load_internal: n_embd     = 8192
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 64
llama_model_load_internal: n_layer    = 80
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 22016
llama_model_load_internal: model size = 65B
llama_model_load_internal: ggml ctx size =    0.19 MB
llama_new_context_with_model: kv self size  =  640.00 MB
llama_model_load_internal: mem required  = 38610.47 MB (+  640.00 MB per state)
llama_new_context_with_model: kv self size  =  640.00 MB
system_info: n_threads = 4 / 4 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 256, n_batch = 512, n_predict = 128, n_keep = 0

I believe the meaning of life is milady because you have to

One thing I am worried about is that not all the Pis seem to be involved in the mpirun: only 2 of the 6 show spikes in compute, as shown here:

Screenshot from 2023-07-15 13-12-00

ggerganov commented 1 year ago

@theycallmeloki

I believe we are now able to generate tokens on a cluster of RPis! 65B model 😀

Let's goooo !! 😄

A point of contention that I am fearing is that not all the Pis are involved in the mpirun as only 2/6 of them seem to have spikes in compute as shown here

Yes, this is the expected behaviour. The MPI implementation currently supports only pipeline parallelisation: each node processes part of the pipeline (i.e. a few layers of the graph) and passes the results to the next node. This allows each node to "see" only part of the model, so the model is distributed across the cluster.

This is in contrast to tensor parallelisation where all nodes can work in parallel on each part of the graph. However, this would require all nodes to see the entire model, which in our experiment is not viable since we cannot fit the entire model in the RPi RAM.
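To make the pipeline mode concrete, here is a toy C sketch; it is not the actual llama.cpp MPI code. It assumes the 80 layers are split into contiguous, near-equal ranges, one per node, and `pipeline_layer_range`, `toy_layer`, `pipeline_forward` and `print_assignment` are hypothetical names; the real backend may assign layers differently. Because each stage has to wait for the previous stage's activations, only one node is computing for a given token at any moment, which is why the Pis should take turns rather than all spike at once.

```c
#include <assert.h>
#include <stdio.h>

/* Hypothetical helper: half-open layer range [*il0, *il1) handled by pipeline
 * stage `stage` when n_layer layers are split as evenly as possible over
 * n_stages nodes. The first (n_layer % n_stages) stages get one extra layer. */
static void pipeline_layer_range(int n_layer, int n_stages, int stage,
                                 int *il0, int *il1) {
    assert(n_layer > 0 && n_stages > 0 && stage >= 0 && stage < n_stages);
    const int base  = n_layer / n_stages;  /* layers every stage gets */
    const int extra = n_layer % n_stages;  /* stages that get one more */
    *il0 = stage * base + (stage < extra ? stage : extra);
    *il1 = *il0 + base + (stage < extra ? 1 : 0);
}

/* Toy stand-in for a transformer layer: adds 1 to every activation. */
static void toy_layer(float *x, int n) {
    for (int i = 0; i < n; i++) x[i] += 1.0f;
}

/* Push one token's activations through all stages in order. In a real MPI
 * run each stage is a separate process and the hand-off between stages is a
 * network message; here the stages simply run one after another. */
static void pipeline_forward(float *x, int n, int n_layer, int n_stages) {
    for (int s = 0; s < n_stages; s++) {
        int il0, il1;
        pipeline_layer_range(n_layer, n_stages, s, &il0, &il1);
        for (int il = il0; il < il1; il++) toy_layer(x, n);
        /* ...stage s would now send x to stage s + 1... */
    }
}

/* Print which layers each node would own, e.g. 80 layers on 6 nodes. */
static void print_assignment(int n_layer, int n_stages) {
    for (int s = 0; s < n_stages; s++) {
        int il0, il1;
        pipeline_layer_range(n_layer, n_stages, s, &il0, &il1);
        printf("node %d: layers %d..%d (%d layers)\n", s, il0, il1 - 1, il1 - il0);
    }
}
```

With `n_layer = 80` and 6 nodes this split gives ranges of 14, 14, 13, 13, 13 and 13 layers; with 80 nodes it degenerates to one layer per node.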

In any case, I consider the experiment already a success! Let us know what inference speed you observe on this setup.

Maybe check whether vm.overcommit_memory = 1 is necessary or whether we can do without it.

Also make sure the generation makes sense: "I believe the meaning of life is milady because you have to" does not sound very coherent, so there might still be some issue 😄

One more thing: maybe adding --mlock could have some positive effects too (not sure).

And another thing: update to latest master since there was a regression recently (32c54116318929c90fd7ae814cf9b5232cd44c36)

ggerganov commented 1 year ago

Reading your comment again, I might have misunderstood. The nodes should continuously take turns processing the next part of the pipeline. If only the same 2 nodes always have spikes, then there is definitely something wrong.

theycallmeloki commented 1 year ago

Can I freely parallelize the model further to pack the inference more densely? Is there a cutoff beyond which the work is more about constructing the DAG than actually computing the tokens? I'm asking because I am not too sure how MPI does things. For example:

When I run a hello world example with 6 processes, instead of every Pi responding, I get replies from only 2 of them:

mpirun -v -hostfile hostfile -np 6 ./hello_world
Hello world from processor spartan1, rank 1 out of 6 processors
Hello world from processor spartan1, rank 2 out of 6 processors
Hello world from processor spartan1, rank 3 out of 6 processors
Hello world from processor spartan1, rank 0 out of 6 processors
Hello world from processor spartan2, rank 4 out of 6 processors
Hello world from processor spartan2, rank 5 out of 6 processors

To verify it wasn't an access issue, I increased the count to 24 processes, and then all the Pis respond:

mpirun -v -hostfile hostfile -np 24 ./hello_world
Hello world from processor spartan1, rank 1 out of 24 processors
Hello world from processor spartan1, rank 0 out of 24 processors
Hello world from processor spartan1, rank 2 out of 24 processors
Hello world from processor spartan1, rank 3 out of 24 processors
Hello world from processor spartan3, rank 8 out of 24 processors
Hello world from processor spartan3, rank 9 out of 24 processors
Hello world from processor spartan3, rank 10 out of 24 processors
Hello world from processor spartan4, rank 13 out of 24 processors
Hello world from processor spartan4, rank 14 out of 24 processors
Hello world from processor spartan5, rank 16 out of 24 processors
Hello world from processor spartan3, rank 11 out of 24 processors
Hello world from processor spartan2, rank 6 out of 24 processors
Hello world from processor spartan6, rank 20 out of 24 processors
Hello world from processor spartan2, rank 7 out of 24 processors
Hello world from processor spartan4, rank 15 out of 24 processors
Hello world from processor spartan5, rank 17 out of 24 processors
Hello world from processor spartan6, rank 21 out of 24 processors
Hello world from processor spartan2, rank 4 out of 24 processors
Hello world from processor spartan4, rank 12 out of 24 processors
Hello world from processor spartan5, rank 18 out of 24 processors
Hello world from processor spartan6, rank 23 out of 24 processors
Hello world from processor spartan2, rank 5 out of 24 processors
Hello world from processor spartan5, rank 19 out of 24 processors
Hello world from processor spartan6, rank 22 out of 24 processors

ggerganov commented 1 year ago

Show the contents of hostfile

cc @evanmiller for advice

theycallmeloki commented 1 year ago

This is what I am using for hostfile

192.168.0.60:1
192.168.0.61:1
192.168.0.62:1
192.168.0.63:1
192.168.0.64:1
192.168.0.65:1

ggerganov commented 1 year ago

Try using the following:

192.168.0.60 slots=1
192.168.0.61 slots=1
192.168.0.62 slots=1
192.168.0.63 slots=1
192.168.0.64 slots=1
192.168.0.65 slots=1

theycallmeloki commented 1 year ago

Pretty sure that helped: now there are uniform spikes on at least 1 core of each Pi. I'll keep testing and try out the different parameters, like memory overcommit, as well.

ggerganov commented 1 year ago

The mpirun -v -hostfile hostfile -np 6 ./hello_world command should list each node exactly once. If that is not the case, there is no point in trying llama.cpp yet.

theycallmeloki commented 1 year ago

Yes, I think so too. I will try to debug what it might be related to. Just to confirm, this is what I used for the hello world (unless there's already a binary I can use to test the MPI side of things, I am not sure the setup is reproducible):

#include <stdio.h>
#include <mpi.h>

int main(int argc, char** argv) {
    // Initialize the MPI environment
    MPI_Init(NULL, NULL);

    // Total number of processes launched by mpirun (-np)
    int world_size;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);

    // Rank of this process, from 0 to world_size - 1
    int world_rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

    // Name of the node this rank is running on
    char processor_name[MPI_MAX_PROCESSOR_NAME];
    int name_len;
    MPI_Get_processor_name(processor_name, &name_len);

    printf("Hello world from processor %s, rank %d out of %d processors\n", processor_name, world_rank, world_size);

    // Shut down MPI before exiting
    MPI_Finalize();
    return 0;
}

ggerganov commented 1 year ago

Looks OK, though I would add a 1 or 2 second sleep somewhere in main since I am not sure what MPI would do if the process ends very quickly. I guess it might launch the next one on the same node -- not sure

theycallmeloki commented 1 year ago

It was definitely the missing sleep in the previous hello world run. I can confirm all 6 nodes are working correctly.

I noticed some time back that it was just slow. I didn't get it at first, but now I can appreciate what is being done. I generated a gif to visualize how the process went:

https://github.com/ggerganov/llama.cpp/assets/3431687/bbaf46f7-46f8-4333-be30-04b3e54d04c6

This is so much better! We're doing about 1 token every 10 seconds! Based on what I understand, context lengths up to 2k should be approachable, and if a model doesn't fit, simply adding more Pis should let even bigger models run with this pattern! :grin:

evanmiller commented 1 year ago

Try using the following:

192.168.0.60 slots=1
192.168.0.61 slots=1
192.168.0.62 slots=1
192.168.0.63 slots=1
192.168.0.64 slots=1
192.168.0.65 slots=1

I believe OpenMPI supports this syntax but MPICH does not.
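Why would the `host:1` lines leave only two Pis busy? One plausible explanation, which is an assumption and was not confirmed in this thread: the `host:N` form is MPICH's hostfile syntax, so OpenMPI may have ignored the `:1` and fallen back to one slot per CPU core (4 on a Raspberry Pi 4), filling every slot of a host before moving on to the next. The toy model below (`host_for_rank` is a hypothetical helper, not an MPI API) reproduces the placement printed by both hello-world runs above.

```c
#include <assert.h>

/* Hypothetical by-slot placement: ranks fill all slots of host 0, then
 * host 1, and so on. Returns the index of the host that runs `rank`
 * (0 = first line of the hostfile). */
static int host_for_rank(int rank, int slots_per_host, int n_hosts) {
    assert(slots_per_host > 0 && n_hosts > 0);
    /* no oversubscription: every rank must fit in an available slot */
    assert(rank >= 0 && rank < slots_per_host * n_hosts);
    return rank / slots_per_host;
}
```

With 4 slots per host and `-np 6`, ranks 0-3 land on host 0 (spartan1) and ranks 4-5 on host 1 (spartan2); with `-np 24`, every host gets four consecutive ranks. With `slots=1`, rank i lands on host i, the one-process-per-Pi layout the llama.cpp run needs.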

ggerganov commented 1 year ago

@theycallmeloki Does it work if you mmap the model from a network share? Wondering if we can avoid the SD cards

context lengths up to 2k should be approachable

Yes, we have currently cut the context by 8x since the KV cache becomes significant for 2048 tokens (~5GB). But if you keep adding RPi nodes, you can use the full context, since with each new node every node needs to load a smaller part of the model. With the current implementation you can have up to 80 nodes for a 65B LLaMA model, i.e. 1 layer per node.

We're doing about 1 token every 10 seconds!

This is surprisingly fast 😄 Btw, you can try with -t 3 - I wouldn't be surprised if it is faster than -t 4

Btw, does the generated text make sense?

theycallmeloki commented 1 year ago

I believe the generated text does make sense, for the most part:

I believe the meaning of life is that we are given a chance to learn, grow in love, and develop our relationships with God and with other people.

I believe the meaning of life is to give life meaning. We all want our lives to mean something and we are constantly trying to find ways of doing so.
The question then comes down to what does it mean for

I believe the meaning of life is to follow your heart and be happy.
I know what you mean, when I was younger in my late 30s and early 40s I had a big ego too, not so much now but back then that's for sure. It could get out of hand. That said, it still gets out of hand sometimes (too often) but I don't care anymore what others think about me - well maybe just a little bit, because when you grow older and have some health issues there is less tolerance from other people and they don't want to help you with your problems like before

The vm.overcommit_memory = 1 flag was unnecessary, so I dropped it.

This was with -t 4

llama_print_timings:        load time = 13541.38 ms
llama_print_timings:      sample time =   430.44 ms /    69 runs   (    6.24 ms per token,   160.30 tokens per second)
llama_print_timings: prompt eval time = 115689.80 ms /     8 tokens (14461.23 ms per token,     0.07 tokens per second)
llama_print_timings:        eval time = 892049.19 ms /    68 runs   (13118.37 ms per token,     0.08 tokens per second)
llama_print_timings:       total time = 1008285.02 ms

This was with -t 3

llama_print_timings:        load time = 13048.37 ms
llama_print_timings:      sample time =   682.20 ms /   128 runs   (    5.33 ms per token,   187.63 tokens per second)
llama_print_timings: prompt eval time = 106775.14 ms /     8 tokens (13346.89 ms per token,     0.07 tokens per second)
llama_print_timings:        eval time = 1566084.15 ms /   127 runs   (12331.37 ms per token,     0.08 tokens per second)
llama_print_timings:       total time = 1673736.34 ms

I am looking into the mlock flag and mmap over the network drive next.
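The per-token figures in these timings follow directly from the eval totals. A small sketch of that arithmetic (the helper names are hypothetical) shows that, in these two runs only, -t 3 was roughly 6% faster per generated token than -t 4; the runs generated different numbers of tokens, so this is indicative rather than a controlled comparison.

```c
#include <assert.h>

/* Average latency per generated token from an "eval time" line:
 * total eval milliseconds divided by the number of runs (tokens). */
static double ms_per_token(double eval_ms, int runs) {
    assert(runs > 0);
    return eval_ms / runs;
}

/* Throughput corresponding to a per-token latency in milliseconds. */
static double tokens_per_second(double ms_per_tok) {
    assert(ms_per_tok > 0.0);
    return 1000.0 / ms_per_tok;
}

/* How many times faster configuration b is than configuration a. */
static double speedup(double ms_per_tok_a, double ms_per_tok_b) {
    assert(ms_per_tok_b > 0.0);
    return ms_per_tok_a / ms_per_tok_b;
}
```

`ms_per_token(892049.19, 68)` is about 13118.37 ms (-t 4) and `ms_per_token(1566084.15, 127)` about 12331.37 ms (-t 3), so `speedup` of the former over the latter is about 1.06, while `tokens_per_second` for either rounds to the 0.08 printed in the log.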

theycallmeloki commented 1 year ago

Update:

Starting from master, I did not have to apply the prefetching/readahead patch. This actually affected the time it takes to load the model, as I think it's now able to load everything in parallel instead of sequentially (I am not sure whether I should keep the patch or remove it).

mlock is failing to allocate memory (this could be related to the above, but it is not a deal breaker, since generation continues to work, albeit slower):

warning: failed to mlock 101449728-byte buffer (after previously locking 1002799104 bytes): Cannot allocate memory
Try increasing RLIMIT_MLOCK ('ulimit -l' as root).
llama_new_context_with_model: kv self size  =  640.00 MB
warning: failed to mlock 101449728-byte buffer (after previously locking 1002799104 bytes): Cannot allocate memory
Try increasing RLIMIT_MLOCK ('ulimit -l' as root).
llama_new_context_with_model: kv self size  =  640.00 MB
warning: failed to mlock 101449728-byte buffer (after previously locking 1002799104 bytes): Cannot allocate memory
Try increasing RLIMIT_MLOCK ('ulimit -l' as root).
llama_new_context_with_model: kv self size  =  640.00 MB

To confirm, I have added the patch reducing the context window by 8x.
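The warning points at the locked-memory limit, which Linux exposes as `RLIMIT_MEMLOCK` (the same limit `ulimit -l` reports, in kilobytes). A minimal sketch, assuming a Linux system, for reading that limit with `getrlimit()`; calling it from a small program on every Pi would show whether the limit, rather than physical RAM, is what stops `--mlock`. The helper names are hypothetical.

```c
#include <assert.h>
#include <stdio.h>
#include <sys/resource.h>

/* Read the soft and hard RLIMIT_MEMLOCK limits (in bytes) for the calling
 * process. Returns 0 on success, -1 if getrlimit() fails. */
static int get_memlock_limit(rlim_t *soft, rlim_t *hard) {
    assert(soft != NULL && hard != NULL);
    struct rlimit rl;
    if (getrlimit(RLIMIT_MEMLOCK, &rl) != 0) {
        perror("getrlimit(RLIMIT_MEMLOCK)");
        return -1;
    }
    *soft = rl.rlim_cur;
    *hard = rl.rlim_max;
    return 0;
}

/* Print one limit, treating RLIM_INFINITY as "unlimited". */
static void print_limit(const char *label, rlim_t value) {
    if (value == RLIM_INFINITY)
        printf("%s: unlimited\n", label);
    else
        printf("%s: %llu bytes\n", label, (unsigned long long) value);
}
```

A soft limit far below the size llama.cpp tries to lock would be consistent with the "Cannot allocate memory" warnings; raising it is what the warning itself suggests, although locking still needs enough physical RAM on each node.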

ggerganov commented 1 year ago

Starting from master, I did not have to make the patch wrt prefetching readahead, this actually had an end impact on the time it takes to load up the model as I think now it's able to load everything in parallel instead of sequentially (I am not sure if I should keep it or remove it)

It's not clear from the explanation. I expect each node on the latest master to start loading its part of the model into RAM sequentially, similar to what we see in the video you posted. There is a way to make the loading happen in parallel, but it might need extra changes. In any case, I don't think the prefetch patch should be needed anymore.

To confirm I have added the patch reducing the context window by 8

This is an important thing to improve for MPI inference. Currently on master, each node allocates a KV cache of the following size:

2*n_embd*n_ctx*n_layer*sizeof(float16)

However, each node only uses the KV cache for the layers that it computes. So we are over-allocating a lot of memory. If we fix this, then the x8 patch won't be necessary and it will be an important step towards supporting interleaved parallel inference.
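Plugging the 65B numbers from the log above (`n_embd = 8192`, `n_layer = 80`) into that formula is a quick sanity check: it gives 640 MB for `n_ctx = 256`, matching the `kv self size` line, and 5120 MB (the ~5GB mentioned earlier) for `n_ctx = 2048`. The sketch below uses a hypothetical helper name, counts MB as 1024*1024 bytes as llama.cpp's log does, and also shows what a node would need if it allocated the cache only for its own layers, e.g. 14 of the 80.

```c
#include <assert.h>
#include <stdint.h>

#define MIB (1024ull * 1024ull)

/* Bytes for the K and V caches (factor 2) of n_layer layers, each storing
 * one n_embd-wide F16 vector (2 bytes per element) per context position:
 * 2 * n_embd * n_ctx * n_layer * sizeof(float16). */
static uint64_t kv_cache_bytes(uint64_t n_embd, uint64_t n_ctx, uint64_t n_layer) {
    const uint64_t sizeof_f16 = 2;
    return 2 * n_embd * n_ctx * n_layer * sizeof_f16;
}
```

For the full context, `kv_cache_bytes(8192, 2048, 80)` is 5120 MB per node today, while a node that owns 14 layers would only need `kv_cache_bytes(8192, 2048, 14)`, i.e. 896 MB.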

ponzienjoyer commented 1 year ago

I've been attempting this with two 16GB Rock Pi 5B devices and have had some success getting inference running across the two nodes thanks to the progress outlined in this thread:

airoboros 33b on two rk3588 devices

However, the inference when running across these two nodes is much slower than a 13B model on a single node; I can do 2 tokens/second on an RK3588 but running with MPI yields the following perf:

slow inference

I've been trying to run with --mlock, but adding this flag to my mpirun llama args causes the process to get killed. I'm assuming this is an OOM occurring while they're trying to mmap the model? Is there a way I can verify this is the cause of my pain, or is there a special way I have to configure mlock to work when running across multiple nodes?

shahizat commented 1 year ago

Hi everyone, has anyone experienced the following issue when running inference using mpirun on a cluster? If so, how did you solve it? Thank you in advance!

llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size =   71.84 MB
ggml_graph_get_node_idx: tensor layer_inp_32 not found

===================================================================================
=   BAD TERMINATION OF ONE OF YOUR APPLICATION PROCESSES
=   PID 915 RUNNING AT llama-mpi-job-mpiworker-1
=   EXIT CODE: 9
=   CLEANING UP REMAINING PROCESSES
=   YOU CAN IGNORE THE BELOW CLEANUP MESSAGES
===================================================================================
[proxy:0:0@llama-mpi-job-mpiworker-0] HYD_pmcd_pmip_control_cmd_cb (proxy/pmip_cb.c:480): assert (!closed) failed
[proxy:0:0@llama-mpi-job-mpiworker-0] HYDT_dmxu_poll_wait_for_event (lib/tools/demux/demux_poll.c:76): callback returned error status
[proxy:0:0@llama-mpi-job-mpiworker-0] main (proxy/pmip.c:127): demux engine error waiting for event
[mpiexec@llama-mpi-job-mpimaster-0] HYDT_bscu_wait_for_completion (lib/tools/bootstrap/utils/bscu_wait.c:109): one of the processes terminated badly; aborting
[mpiexec@llama-mpi-job-mpimaster-0] HYDT_bsci_wait_for_completion (lib/tools/bootstrap/src/bsci_wait.c:21): launcher returned error waiting for completion
[mpiexec@llama-mpi-job-mpimaster-0] HYD_pmci_wait_for_completion (mpiexec/pmiserv_pmci.c:197): launcher returned error waiting for completion
[mpiexec@llama-mpi-job-mpimaster-0] main (mpiexec/mpiexec.c:252): process manager error waiting for completion

cameronbunce commented 1 year ago

I'm running into what I think is a case of the processes not all returning when finished. I wonder if anyone is seeing similar or can help.

I've followed along with the instructions here, including the patch mentioned by @ggerganov on 14 July, and I have three Pi4 8G nodes in a Slurm cluster with OpenMPI

I'm launching the mpirun commands within Slurm batch jobs like this:

#!/bin/bash
#SBATCH --nodes=3
#SBATCH --ntasks-per-node=4

cd $SLURM_SUBMIT_DIR

echo "Master node: $(hostname)"

mpirun -hostfile /clusterfs/Projects/llama.cpp/hostfile -n 3 /clusterfs/Projects/llama.cpp/main -m /home/cameron/models/llama-13b.ggmlv3.q4_0.bin -p "The Grudge" -n 1024

and my slurm-.out file looks like this:

blah blah...
[end of text]

llama_print_timings:        load time =  3565.07 ms
llama_print_timings:      sample time =   482.41 ms /   234 runs   (    2.06 ms per token,   485.07 tokens per second)
llama_print_timings: prompt eval time =  7636.31 ms /     6 tokens ( 1272.72 ms per token,     0.79 tokens per second)
llama_print_timings:        eval time = 521833.66 ms /   233 runs   ( 2239.63 ms per token,     0.45 tokens per second)
llama_print_timings:       total time = 530008.40 ms

leading me to think that the output is all finished ( unless the models I have were trained on their own output :) )

but the resources are never released to allow the next job to run. I can cancel it and the next one runs just fine, and I presume I can set a time limit on the job, but I'm wondering if I'm missing an MPI configuration somewhere.

Full output:

out.txt

ggerganov commented 1 year ago

Yeah, the MPI backend is missing a lot of functionality, one of which is terminating upon end of text. Should be easy to fix, but it's not a priority atm.

Rane2021 commented 1 year ago

Does MPI support multi GPU device distribution?

justinh-rahb commented 1 year ago

@ggerganov firstly, I can't thank you and every other contributor enough for the hard work you do on this project. You mentioned before that each node only sees part of the model at a time due to memory constraints. Just wondering: what if that were not a limitation, what if one were to, say, have a cluster of M2 Mac Studios? My thinking at the moment is along these lines: 30x of those is both cheaper and easier to get than a DGX/HGX with 8xH100s right now, and I recently saw something on Twitter that makes me think I may not be the only one trying to find at least $6000 to spare right now 👀. So I am now desperately trying to think of a way I could use such a cluster to run llama.cpp at something approaching production scale, proving out our experiments and simply buying more of them until it makes sense to go GPU. Maybe MPI isn't the way; if not, any ideas? How's batch inference looking these days? Thanks again for your gift to the world that continues to change the calculus.

Green-Sky commented 1 year ago

@justinh-rahb should be possible, since the real bottleneck for token generation is memory throughput, having 2 machines theoretically doubles that (ignoring overhead).

You mentioned before about limitations with each node only seeing part of the model at a time due to memory constraint.

It is indeed an argument for MPI, but having enough memory is not an argument against MPI :)

also, it would be hilarious, if you could outbench datacenter gpus by that much :laughing:
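A back-of-the-envelope way to see when extra machines help, assuming generation is purely memory-bandwidth bound (every weight byte read once per token) and ignoring compute, network latency and synchronization. With the pipeline split the current MPI backend uses, the nodes work one after another on a single sequence, so their bandwidth does not add up; with a split where all nodes read their share at the same time (the row split ggerganov describes next), it does. The function names and the numbers in the example are hypothetical.

```c
#include <assert.h>

/* Upper bound on tokens/s when the weights are split evenly over n_nodes
 * nodes that run one after another (pipeline split): per-token time is the
 * sum of every node reading its share at its own bandwidth. */
static double tps_pipeline(double model_gb, double node_bw_gb_s, int n_nodes) {
    assert(model_gb > 0.0 && node_bw_gb_s > 0.0 && n_nodes > 0);
    double seconds = 0.0;
    for (int i = 0; i < n_nodes; i++)
        seconds += (model_gb / n_nodes) / node_bw_gb_s;
    return 1.0 / seconds;
}

/* Upper bound when all nodes read their share simultaneously (tensor/row
 * split): per-token time is a single share read at one node's bandwidth. */
static double tps_parallel(double model_gb, double node_bw_gb_s, int n_nodes) {
    assert(model_gb > 0.0 && node_bw_gb_s > 0.0 && n_nodes > 0);
    return 1.0 / ((model_gb / n_nodes) / node_bw_gb_s);
}
```

For a hypothetical 40 GB model on nodes with 10 GB/s each, four nodes give 0.25 tokens/s in the pipeline case (the same as one node that could hold the whole model) but 1 token/s in the parallel case.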

ggerganov commented 1 year ago

In theory, a parallel implementation would split the model tensors across rows to 32 nodes and the MPI code would synchronize the activations after each layer, which would give you ~32x the memory bandwidth of a single node. On top of that, each node would need just 1/32 of the total memory.
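A toy, single-process C sketch of such a row split for one matrix-vector product y = W x: each hypothetical "node" owns a contiguous block of rows of W, so it stores only that slice of the weights, and computes only the matching entries of y. In a real cluster the slices would live on different machines and the partial outputs would have to be exchanged after each layer; that exchange is not modeled here, and the function names are illustrative, not llama.cpp or ggml APIs.

```c
#include <assert.h>

/* Compute rows [r0, r1) of y = W x, where W is row-major with `cols` columns. */
static void matvec_rows(const float *W, const float *x, float *y,
                        int cols, int r0, int r1) {
    for (int r = r0; r < r1; r++) {
        float acc = 0.0f;
        for (int c = 0; c < cols; c++) acc += W[r * cols + c] * x[c];
        y[r] = acc;
    }
}

/* Simulate n_nodes nodes, each owning a contiguous block of rows. Every row
 * is assigned to exactly one node; the "gather" is implicit because all
 * nodes write into the same output buffer. */
static void matvec_row_split(const float *W, const float *x, float *y,
                             int rows, int cols, int n_nodes) {
    assert(rows > 0 && cols > 0 && n_nodes > 0);
    for (int p = 0; p < n_nodes; p++) {
        int r0 = (int)((long long) rows * p / n_nodes);
        int r1 = (int)((long long) rows * (p + 1) / n_nodes);
        matvec_rows(W, x, y, cols, r0, r1);  /* node p only touches these rows */
    }
}
```

The result is identical to an unsplit product for any node count; with more nodes than rows some nodes simply get an empty range.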

In practice, there will be synchronization points and network latency, and it might be tricky to update the code to support this mode of parallelization. You would also want to do this with a large model to compensate for the overhead from the previous points. Maybe the new 180B Falcon is a good candidate.

But in any case, I find this experiment very interesting and hopefully will give it a try when I have some time and hardware to do that. Running Q8 Falcon 180B on 4x Mac Studio cluster at 20 t/s sounds fun!

justinh-rahb commented 1 year ago

But in any case, I find this experiment very interesting and hopefully will give it a try when I have some time and hardware to do that. Running Q8 Falcon 180B on 4x Mac Studio cluster at 20 t/s sounds fun!

You like to party, fantastic 😬 If you're excited by it, I'm even more excited to read about it when you do!

davidwynter commented 1 year ago

What about 10 Orange Pi 5 16GB boards at $109 each on AliExpress? That gives me 160GB of RAM and 80 cores (40 at 2GHz and 40 at 2.4GHz), drawing 150W. I can fit about 16 of these into a long 1U case I have. Would it be possible to use the RKNN SDK (which supports quantization) to utilize the NPU on each board, at 6 TOPS each, 60 TOPS total? Might be a stupid question, but the idea intrigues me.

Green-Sky commented 1 year ago

Would it be possible to use the RKNN SDK

not atm

(supports quantization)

not ggml's custom quantizations

But otherwise it might work. Not sure it would give you the value back, though. :smile:

theycallmeloki commented 10 months ago

I would like to add to @davidwynter's point: I think the Orange Pi 5 32GB would be even better, and 32 of them would be a very interesting number to begin with, since that's 1TB of RAM! I did some research on the hardware and wrote a Kickstarter-style campaign on Mirror here (sorry if this comes across as spam), in case anyone would like to support me in purchasing the hardware and running the experiment.

monkeycc commented 10 months ago

Orange Pi 5 (16GB/32GB)

Can use large models

FNsi commented 10 months ago

Orange Pi 5

(16GB/32GB)

Can use large models

There's a phone called the Honor X50 (Snapdragon 6 Gen 1, 16GB RAM) in the same price range...