tairov / llama2.mojo

Inference Llama 2 in one file of pure 🔥
https://www.modular.com/blog/community-spotlight-how-i-built-llama2-by-aydyn-tairov
MIT License
2.09k stars 140 forks

Work around parallelize() issues and default to parallel matmul #9

Closed itramble closed 1 year ago

itramble commented 1 year ago

Increases performance from ~190 tok/s to ~400 tok/s on my machine. If you tune the number of threads with self.rt = Runtime(nthreads), it gets up to ~500 tok/s (the best I found was 12 threads).

There are a few issues with parallelize that need to be worked around:

  1. The single-arg version of parallelize creates a new runtime for each invocation, which is slow.
  2. The two-arg version is much slower than the three-arg version because querying the number of cores seems to be slow. Investigating this.
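
The cost described in the first point is the general one of rebuilding a worker pool on every call instead of creating it once and reusing it. As a rough analogy only, here is a sketch in Python with concurrent.futures rather than Mojo's Runtime and parallelize, whose internals are not shown in this thread. The names row_dot, matvec_new_pool and ReusedPool are hypothetical, and Python threads will not reproduce the speedups reported here; the sketch only shows the structure of the two patterns.

```python
from concurrent.futures import ThreadPoolExecutor


def row_dot(row, vec):
    """Dot product of one matrix row with a vector."""
    return sum(a * b for a, b in zip(row, vec))


def matvec_new_pool(matrix, vec, nthreads):
    """Builds and tears down a pool on every call (the slow pattern)."""
    with ThreadPoolExecutor(max_workers=nthreads) as pool:
        return list(pool.map(lambda row: row_dot(row, vec), matrix))


class ReusedPool:
    """Creates the pool once and reuses it, similar in spirit to keeping
    a runtime in self.rt and passing it to every parallel call."""

    def __init__(self, nthreads):
        self.nthreads = nthreads
        self.pool = ThreadPoolExecutor(max_workers=nthreads)

    def matvec(self, matrix, vec):
        # One task per row; the same worker threads serve every call.
        return list(self.pool.map(lambda row: row_dot(row, vec), matrix))

    def close(self):
        self.pool.shutdown()
```

Storing nthreads next to the pool also avoids asking the system for the core count on each call, which is the concern raised in the second point.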

tairov commented 1 year ago

Wow! Looks cool. Thanks @itramble for your efforts. Do you know what rt.parallelism_level() means?

On my VM I got the best performance with threads = 3. I made a silly assumption with nthreads = num_cores() // 2. This expression leads to error /__w/modular/modular/Kernels/mojo/builtin/_startup.mojo:70:1: error: no viable expansions found

I'm happy to merge this PR as is

itramble commented 1 year ago

Do you know what rt.parallelism_level() means?

It returns the number of threads that the Runtime was constructed with (by default Runtime() creates a runtime with num_threads=num_cores()).

This expression leads to error /__w/modular/modular/Kernels/mojo/builtin/_startup.mojo:70:1: error: no viable expansions found

Hm this diff works for me:

diff --git a/llama2.mojo b/llama2.mojo
index fa4d4a4..6b24b57 100644
--- a/llama2.mojo
+++ b/llama2.mojo
@@ -282,7 +282,7 @@ struct RunState:
         self.key_cache.alloc_zero()
         self.value_cache = Matrix3(config.n_layers, config.seq_len, config.dim)
         self.value_cache.alloc_zero()
-        self.rt = Runtime()
+        self.rt = Runtime(num_cores() // 2)

 struct TransformerWeights:

tairov commented 1 year ago

Ended up with self.rt = Runtime(num_cores() // 2).

Before, I tried aliasing it with alias nthreads = num_cores() // 2, which is why it was failing.

Best performance so far.
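
The values reported in this thread (3 threads on my VM, 12 threads on itramble's machine, num_cores() // 2 as the default) suggest the best thread count depends on the machine. Below is a minimal sketch of sweeping it, written in Python rather than Mojo. Here run_benchmark is a hypothetical callable that takes a thread count and returns tokens per second, and half_cores and pick_best are illustrative names, not part of llama2.mojo.

```python
import os


def half_cores(cores):
    """Default guess mirroring num_cores() // 2, but never below 1."""
    return max(1, cores // 2)


def candidate_thread_counts(cores):
    """Thread counts from 1 up to the core count, inclusive."""
    return list(range(1, max(1, cores) + 1))


def pick_best(candidates, run_benchmark):
    """Return (nthreads, tok_per_s) for the fastest candidate.

    run_benchmark is a hypothetical callable: nthreads -> tokens per second.
    """
    best = None
    for n in candidates:
        rate = run_benchmark(n)
        if best is None or rate > best[1]:
            best = (n, rate)
    return best


if __name__ == "__main__":
    # os.cpu_count() may return None, so fall back to 1.
    cores = os.cpu_count() or 1
    print("cores:", cores, "default guess:", half_cores(cores))
```

To use it, pass pick_best a function that runs the model with the given thread count and measures throughput, then put the winning value into the Runtime constructor.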

itramble commented 1 year ago

Before, I tried aliasing it with alias nthreads = num_cores() // 2, which is why it was failing.

Ah, you can only use alias for compile-time values, and num_cores() is a runtime value.