KerfuffleV2 / llm-samplers

A collection of LLM token samplers in Rust
MIT License

v0.0.7 #9

Closed KerfuffleV2 closed 8 months ago

KerfuffleV2 commented 8 months ago

v0.0.7

  1. Fix a bug where Mirostat2 sampled twice.
  2. Remove type variables from samplers.
  3. Rename build.rs to building.rs to avoid confusion.
  4. Minor cleanups.
  5. Add min-p sampler. See: https://github.com/ggerganov/llama.cpp/issues/3483#issuecomment-1783920998
  6. Add top-a sampler. See: https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method
  7. Try to avoid unnecessarily running softmax calculation.
  8. Add a try_from_iter_top_k method which pre-prunes the logits while building (and also results in the list starting out sorted).
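The min-p and top-a rules linked above can be sketched as simple probability filters. This is an illustrative standalone sketch, not the llm-samplers API: the function names and the plain `Vec<(u32, f32)>` token/probability representation are invented for this example.

```rust
/// Min-p: keep tokens whose probability is at least `p` times the top probability.
/// Illustrative only; not the llm-samplers implementation.
fn min_p_filter(probs: &[(u32, f32)], p: f32) -> Vec<(u32, f32)> {
    let max_p = probs.iter().map(|&(_, pr)| pr).fold(0.0f32, f32::max);
    let threshold = p * max_p;
    probs.iter().copied().filter(|&(_, pr)| pr >= threshold).collect()
}

/// Top-a: keep tokens whose probability is at least `a * max_p^2`.
/// Illustrative only; not the llm-samplers implementation.
fn top_a_filter(probs: &[(u32, f32)], a: f32) -> Vec<(u32, f32)> {
    let max_p = probs.iter().map(|&(_, pr)| pr).fold(0.0f32, f32::max);
    let threshold = a * max_p * max_p;
    probs.iter().copied().filter(|&(_, pr)| pr >= threshold).collect()
}

fn main() {
    let probs = vec![(0u32, 0.5f32), (1, 0.2), (2, 0.06), (3, 0.01)];
    // min-p with p = 0.1 keeps tokens with probability >= 0.05.
    println!("{:?}", min_p_filter(&probs, 0.1));
    // top-a with a = 0.2 keeps tokens with probability >= 0.2 * 0.25 = 0.05.
    println!("{:?}", top_a_filter(&probs, 0.2));
}
```

Both filters scale their cutoff with the top token's probability, so they prune aggressively when the model is confident and permissively when it is not.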
KerfuffleV2 commented 8 months ago

@ealmloff

I made some changes here to try to avoid re-running softmax if it had already been calculated. Is there a relatively easy way to run the benchmarks you mentioned before and see if there was an improvement? These changes could also definitely benefit from testing.
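The idea of avoiding a re-run of softmax can be sketched as a lazily computed, cached field. This is a hedged illustration under assumed names (`Logits`, a hypothetical `softmax` cache field), not the actual llm-samplers internals.

```rust
// Illustrative sketch of caching a softmax result so that several
// samplers in a chain don't each recompute it. Struct and field names
// are assumptions for this example, not the real llm-samplers types.
struct Logits {
    values: Vec<f32>,
    softmax: Option<Vec<f32>>, // hypothetical cache field
}

impl Logits {
    fn new(values: Vec<f32>) -> Self {
        Self { values, softmax: None }
    }

    /// Compute softmax once; later calls reuse the cached result.
    fn softmax(&mut self) -> &[f32] {
        if self.softmax.is_none() {
            // Subtract the max logit for numerical stability.
            let max = self.values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = self.values.iter().map(|v| (v - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            self.softmax = Some(exps.into_iter().map(|e| e / sum).collect());
        }
        self.softmax.as_deref().unwrap()
    }
}

fn main() {
    let mut logits = Logits::new(vec![1.0, 1.0]);
    println!("{:?}", logits.softmax()); // computed once
    println!("{:?}", logits.softmax()); // served from the cache
}
```

Any sampler that mutates the raw logits would of course need to invalidate the cache; that bookkeeping is omitted here.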

Or if you want to do it, I'm certainly happy to be lazy. :) Porting from 0.0.6 to 0.0.7 should be pretty straightforward, you basically only have to remove the type variables. (I added some info in the README changes for this pull.)

edit: I created https://github.com/floneum/floneum/pull/77

ealmloff commented 8 months ago

Performance seems to be about the same. Maybe a bit faster? I don't have a formal benchmark; I've just been running the phi example in release mode with a profiler.

Before https://github.com/floneum/floneum/pull/77:

(profiler screenshot, 2023-11-06 8:41 AM)

After https://github.com/floneum/floneum/pull/77:

(profiler screenshot, 2023-11-06 8:37 AM)
KerfuffleV2 commented 8 months ago

@ealmloff

Hmm, it seems like sorting is actually what dominates time consumption (not too surprising).

I just added a try_from_iter_top_k method for Logits that will pre-filter (and the result is also pre-sorted). I don't know the exact model/command you're using to produce that output, but you could try changing rphi/src/model.rs:

```rust
Ok(Logits::try_from_iter(logits)?)
```

to

```rust
Ok(Logits::try_from_iter_top_k(logits, 1000)?)
```

Basically, prune the logits as they are built, but use a high enough value that it won't interfere with regular sampling. I actually think you'll see the time spent in sample_token drop by at least half.
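The pruning-while-building idea can be sketched as keeping a small sorted list of the `k` largest entries while iterating. The helper below is illustrative only, with an invented name and representation; it is not the actual `try_from_iter_top_k` implementation.

```rust
/// Collect the `k` largest logits as (index, value) pairs, returned
/// sorted descending. Illustrative sketch, not the llm-samplers code.
fn top_k_prune(logits: impl IntoIterator<Item = f32>, k: usize) -> Vec<(usize, f32)> {
    let mut kept: Vec<(usize, f32)> = Vec::with_capacity(k + 1);
    for (idx, val) in logits.into_iter().enumerate() {
        // Insert at the correct descending position, then drop the
        // smallest entry once the list exceeds k.
        let pos = kept.partition_point(|&(_, v)| v > val);
        kept.insert(pos, (idx, val));
        if kept.len() > k {
            kept.pop();
        }
    }
    kept
}

fn main() {
    let pruned = top_k_prune(vec![0.1, 3.0, 2.0, -1.0, 5.0], 3);
    println!("{:?}", pruned);
}
```

Because the surviving list starts out sorted, downstream samplers can skip their own sort, which is where the profiler showed most of the time going.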

ealmloff commented 8 months ago

That is so much faster, thanks! Choosing the top 1000 tokens out of phi's list of 50,000 tokens makes sampling take almost no time.

(profiler screenshot, 2023-11-06 12:31 PM)
KerfuffleV2 commented 8 months ago

Oh, nice! I'm glad we were able to fix your original problem, even if it was in a roundabout way. Since it's so fast, you can probably even bump that value up a bit to make it less likely to affect the results.

Hopefully I didn't break anything with these changes. I will probably make a release in the next couple days, unless I find an issue.

Thanks for the feedback and benchmarking. It's been very helpful (also for poking me to get started on this; it might never have happened otherwise...)