tracel-ai / models

Models and examples built with Burn
Apache License 2.0

Llama #35

Closed laggui closed 2 months ago

laggui commented 5 months ago

Bringing the first official Llama implementation to Burn! With pre-trained weights in mpk format (hosted on HF hub).

Currently, top-p sampling is done on the CPU before decoding, since Burn is missing categorical distribution sampling. We could improve that once everything else is done.
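For reference, CPU-side top-p (nucleus) sampling boils down to: sort tokens by descending probability, keep the smallest prefix whose cumulative mass reaches `p`, renormalize, and draw from that subset. A plain-Rust sketch (independent of Burn's tensor API; the function name and the explicit uniform draw `u` are illustrative, not the PR's actual code):

```rust
/// Top-p (nucleus) sampling over a probability vector, done on the CPU.
/// `u` is a uniform random value in [0, 1), passed in so the function is
/// deterministic and easy to test (illustrative sketch only).
fn sample_top_p(probs: &[f32], p: f32, u: f32) -> usize {
    // Sort token indices by descending probability (stable sort keeps ties in order).
    let mut idx: Vec<usize> = (0..probs.len()).collect();
    idx.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());

    // Keep the smallest prefix whose cumulative probability reaches p.
    let mut cum = 0.0;
    let mut kept = Vec::new();
    for &i in &idx {
        kept.push(i);
        cum += probs[i];
        if cum >= p {
            break;
        }
    }

    // Renormalize over the kept tokens and sample using the uniform draw.
    let total: f32 = kept.iter().map(|&i| probs[i]).sum();
    let mut acc = 0.0;
    for &i in &kept {
        acc += probs[i] / total;
        if u < acc {
            return i;
        }
    }
    *kept.last().unwrap()
}

fn main() {
    // With p = 0.6 only the two most likely tokens survive the cutoff.
    let probs = [0.1, 0.5, 0.2, 0.2];
    println!("sampled token: {}", sample_top_p(&probs, 0.6, 0.3));
}
```

Once Burn gains categorical distribution sampling, this step could stay on the device instead of round-tripping the probabilities to the CPU.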

Closes #20

laggui commented 5 months ago

Currently downloading the Llama 3 8B Instruct to have a chat mode available for Llama 3 as well.

Also need to update the README to provide a bit more info.

Otherwise everything is ready to go 💪

/edit Actually, a small note: even TinyLlama's record takes ~50s to load on my machine, so we could try to improve that, but that is on Burn's side.

laggui commented 5 months ago

Tested with wgpu and tch (gpu). I think this is ready for review!

TinyLlama results on my dev machine:

Wgpu

Loading record...
Loaded in 20s
Processing prompt: How many helicopters can a human eat in one sitting?
> It's impossible to know for certain how many helicopters a human can eat in one sitting. However, it's generally accepted that humans have a limited appetite and can only eat a small amount of food at a time.

50 tokens generated (3.5432 tokens/s)

Generation completed in 0m14s

LibTorch<f16>

Loading record...
Loaded in 18s
Processing prompt: How many helicopters can a human eat in one sitting?
> It's impossible to know for certain how many helicopters a human can eat in one sitting. However, it's generally accepted that humans have a limited appetite and can only eat a small amount of food at a time.

50 tokens generated (21.6305 tokens/s)

Generation completed in 0m2s

Pretty big difference 😅
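For a rough sense of scale, the tokens/s figures above work out to about a 6x gap between the two backends (illustrative arithmetic only, using the numbers from the logs):

```rust
// Rough backend speedup from the reported generation throughputs.
fn speedup(tch_tps: f64, wgpu_tps: f64) -> f64 {
    tch_tps / wgpu_tps
}

fn main() {
    // 21.6305 tok/s (LibTorch<f16>) vs 3.5432 tok/s (Wgpu), from the logs above.
    println!("LibTorch<f16> is ~{:.1}x faster than Wgpu here", speedup(21.6305, 3.5432));
}
```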

laggui commented 3 months ago

Weights have been updated to use the named mpk format (much faster now that the data is treated as bytes with serde). In follow-up PRs we will add quantization and support for Llama 3.1.