KerfuffleV2 / llm-samplers

A collection of LLM token samplers in Rust
MIT License

More efficient softmax implementation #7

Closed. ealmloff closed this 8 months ago.

ealmloff commented 8 months ago

Thanks for making llm-samplers! I am using it in Kalosm to allow custom samplers for candle models. When I benchmarked the models, about half of the total time spent on the phi model went to sorting the logit array in the softmax function. After this change, it took ~2% of the total inference time in Kalosm.

The current softmax implementation can trigger sorting the logits. This PR changes the softmax implementation to find the maximum logit with a simple loop instead.
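A minimal sketch of the idea in the PR description. It assumes a hypothetical, simplified `Logit` struct, which is not the crate's actual type. The maximum is found with one linear scan, subtracted for numerical stability, and the results are exponentiated and normalized. Note that this leaves the logits in their original order, which turns out to matter to the other samplers later in the thread.

```rust
/// Hypothetical, simplified stand-in for a logit entry (not the crate's type).
#[derive(Debug, Clone)]
struct Logit {
    token_id: u32,
    logit: f32,
    prob: f32,
}

/// Numerically stable softmax that finds the maximum logit with a single
/// O(n) scan instead of an O(n log n) sort. The order of `logits` is left
/// unchanged, so callers that need sorted output must sort separately.
fn softmax_no_sort(logits: &mut [Logit]) {
    // Single pass to find the largest logit (NaN entries are skipped by f32::max).
    let max = logits
        .iter()
        .map(|l| l.logit)
        .fold(f32::NEG_INFINITY, f32::max);
    if !max.is_finite() {
        return; // empty slice or no finite maximum: nothing sensible to compute
    }
    // Subtracting the max keeps exp() from overflowing.
    let mut sum = 0.0f32;
    for l in logits.iter_mut() {
        l.prob = (l.logit - max).exp();
        sum += l.prob;
    }
    for l in logits.iter_mut() {
        l.prob /= sum;
    }
}

fn main() {
    let mut logits: Vec<Logit> = [1.0f32, 3.0, 2.0]
        .iter()
        .enumerate()
        .map(|(i, &logit)| Logit { token_id: i as u32, logit, prob: 0.0 })
        .collect();
    softmax_no_sort(&mut logits);
    for l in &logits {
        println!("token {} (logit {}) -> p = {:.4}", l.token_id, l.logit, l.prob);
    }
}
```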

ealmloff commented 8 months ago

On a somewhat related note: when sorting the logits, it might be worth trying an unstable sort for better performance, if you don't care about the order of equally weighted logits.
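For illustration, here is a hedged sketch of what the unstable-sort suggestion could look like. It again uses a hypothetical, stripped-down `Logit` struct. `sort_unstable_by` and `f32::total_cmp` are standard library methods. Whether the speedup matters for llm-samplers would need benchmarking.

```rust
#[derive(Debug, Clone, Copy)]
struct Logit {
    token_id: u32,
    logit: f32,
}

/// Sorts highest logit first with `sort_unstable_by`, which does not allocate
/// and is typically faster than a stable sort. `f32::total_cmp` gives a total
/// order (NaN included), so no `partial_cmp(...).unwrap()` is needed.
/// Entries with equal logits may end up in any relative order.
fn sort_desc_unstable(logits: &mut [Logit]) {
    logits.sort_unstable_by(|a, b| b.logit.total_cmp(&a.logit));
}

fn main() {
    let mut logits = vec![
        Logit { token_id: 0, logit: 0.5 },
        Logit { token_id: 1, logit: 2.0 },
        Logit { token_id: 2, logit: 1.0 },
    ];
    sort_desc_unstable(&mut logits);
    println!("{:?}", logits);
}
```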

KerfuffleV2 commented 8 months ago

Hi, thanks for the submission! I haven't had a chance to take a very in-depth look yet (sadly, I've been neglecting this project and have to slowly reload it into my brain), but my first thought was that there are various samplers that depend on the logits getting sorted.

In fact, this does cause a number of tests to fail (cargo test):

test tests::sampler::test_freq_presence ... FAILED
test tests::sampler::test_mirostat2 ... FAILED
test tests::sampler::test_repetition ... FAILED
test tests::sampler::test_tail_free ... FAILED
test tests::sampler::test_top_p ... FAILED

I'm a little surprised that Mirostat 1 and top-k didn't fail as well; maybe their tests are inadequate or they passed by coincidence. It's also quite possible that the failures come from bugs or faulty assumptions in the test code. However, it does seem to hit your assert for Mirostat 2 at least:

---- tests::sampler::test_mirostat2 stdout ----
thread 'tests::sampler::test_mirostat2' panicked at src/types.rs:187:13:
assertion failed: l.prob >= L::zero() && l.prob <= L::one()

if you don't care about the order of equally weighted logits

Unfortunately, I think we have to care about that since otherwise you could get different generations with the same seed. Generally users expect to be able to reproduce generations using a specific seed, and also that property can be really useful for debugging stuff from a developer perspective as well.
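One possible way to get both the speed of an unstable sort and reproducible results is sketched below. This is an assumption, not something llm-samplers does. A tie-breaker (here the token id) ensures no two entries compare equal. With such a strict total order, every correct sorting algorithm, stable or not, produces the same output for the same set of inputs. The `Logit` struct is a hypothetical simplification.

```rust
#[derive(Debug, Clone, Copy)]
struct Logit {
    token_id: u32,
    logit: f32,
}

/// Highest logit first; ties broken by the smaller token id.
/// Assuming token ids are unique, no two entries ever compare equal, so the
/// sorted order is fully determined by the values and the unstable sort
/// cannot introduce run-to-run differences.
fn sort_desc_deterministic(logits: &mut [Logit]) {
    logits.sort_unstable_by(|a, b| {
        b.logit
            .total_cmp(&a.logit)
            .then_with(|| a.token_id.cmp(&b.token_id))
    });
}

fn ids(v: &[Logit]) -> Vec<u32> {
    v.iter().map(|l| l.token_id).collect()
}

fn main() {
    // Same values in a different input order: the output order is identical.
    let mut a = vec![
        Logit { token_id: 3, logit: 1.0 },
        Logit { token_id: 1, logit: 1.0 },
        Logit { token_id: 2, logit: 2.0 },
    ];
    let mut b = vec![a[1], a[2], a[0]];
    sort_desc_deterministic(&mut a);
    sort_desc_deterministic(&mut b);
    println!("{:?} {:?}", ids(&a), ids(&b)); // prints [2, 1, 3] [2, 1, 3]
}
```

This order can differ from what a stable sort would give when the input isn't already ordered by token id. It is deterministic, but not necessarily identical to the current behavior.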

KerfuffleV2 commented 8 months ago

Some of the test failures can be fixed with:

diff --git a/src/tests.rs b/src/tests.rs
index 19ab30f..cabfa91 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -20,6 +20,7 @@ fn test_sampler_ll<S: Sampler<u32, f32>>(
         .expect("Bad logits");
     if use_sm {
         logits.softmax().expect("Softmax failed");
+        logits.ensure_sorted().expect("Sort failed");
     }
     let result_logits = sampler.sample(res, &mut logits).expect("Sampler error");
     vf(sampler, result_logits, expected)
@@ -79,7 +80,9 @@ fn validate_sm(
     logits: &mut Logits<u32, f32>,
     expected: &[f32],
 ) {
-    validate(sampler, logits.softmax().expect("Softmax failed"), expected);
+    logits.softmax().expect("Softmax failed");
+    logits.ensure_sorted().expect("Sort failed");
+    validate(sampler, logits, expected);
 }

 fn validate_eq(

Repetition penalty, etc. shouldn't really care whether the logits are sorted, as far as I recall, so those failures basically come down to me assuming that calling softmax results in sorted logits. However, I feel like most of the more complicated samplers are going to require sorting: mirostat, top-p, top-k, etc. I'm also pretty sure picking a token with rand_distrib needs them sorted as well. I'll have to look more closely at those samplers to double-check all that.
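The sketch below illustrates why a sampler like top-p depends on sorted input. It is a minimal nucleus-filtering sketch, not the crate's implementation, and it assumes a hypothetical `Logit` type with a precomputed `prob`. It keeps the shortest prefix whose cumulative probability reaches `p`. That prefix is "the most likely tokens" only when the input is in descending probability order.

```rust
#[derive(Debug, Clone, Copy)]
struct Logit {
    token_id: u32,
    prob: f32,
}

/// Top-p (nucleus) filtering that ASSUMES `logits` is already sorted by
/// probability, highest first. It accumulates probability along the prefix
/// and truncates once the running total reaches `p` (keeping at least one).
/// On unsorted input it silently keeps an arbitrary prefix instead.
fn top_p_sorted(logits: &mut Vec<Logit>, p: f32) {
    let mut cumulative = 0.0f32;
    let mut keep = logits.len();
    for (i, l) in logits.iter().enumerate() {
        cumulative += l.prob;
        if cumulative >= p {
            keep = i + 1;
            break;
        }
    }
    logits.truncate(keep.max(1));
}

fn ids(v: &[Logit]) -> Vec<u32> {
    v.iter().map(|l| l.token_id).collect()
}

fn main() {
    let mut sorted = vec![
        Logit { token_id: 0, prob: 0.5 },
        Logit { token_id: 1, prob: 0.3 },
        Logit { token_id: 2, prob: 0.15 },
        Logit { token_id: 3, prob: 0.05 },
    ];
    // Same distribution, but not sorted.
    let mut unsorted = vec![sorted[3], sorted[0], sorted[1], sorted[2]];
    top_p_sorted(&mut sorted, 0.75);
    top_p_sorted(&mut unsorted, 0.75);
    println!("sorted input keeps:   {:?}", ids(&sorted)); // [0, 1]
    println!("unsorted input keeps: {:?}", ids(&unsorted)); // [3, 0, 1]
}
```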

ealmloff commented 8 months ago

Sorry, I should have checked the tests first (I'm getting too used to CI catching my errors). Softmax is currently the only place sorting happens in the mirostat2 chain. When I looked at the code, it seemed like sorting wasn't necessary in softmax, but if the surrounding samplers depend on sorted logits, I don't think this optimization is useful.

KerfuffleV2 commented 8 months ago

Absolutely no problem at all, and I definitely appreciate that you wanted to contribute (also a fan of Dioxus, though I haven't done anything with it in a while!).

There are probably cases that can be optimized, like unnecessarily re-sorting or running softmax in a sampler chain that has multiple samplers that might want it. It also might be useful to make sorting and softmax user-visible samplers so people could use them if they wanted to.
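Here is a sketch of the "avoid unnecessary re-sorting" idea. It assumes a hypothetical `LogitBuf` wrapper that tracks whether its contents are known to be sorted. The `ensure_sorted` name echoes the method called in the test diff above, but this is not the crate's implementation.

```rust
#[derive(Debug, Clone, Copy)]
struct Logit {
    token_id: u32,
    logit: f32,
}

/// Hypothetical wrapper that remembers whether its contents are known to be
/// sorted, so several samplers in one chain can ask for sorted logits while
/// only paying for a sort when something actually changed.
struct LogitBuf {
    items: Vec<Logit>,
    sorted: bool,
    sort_count: usize, // only here to make the saving observable
}

impl LogitBuf {
    fn new(items: Vec<Logit>) -> Self {
        Self { items, sorted: false, sort_count: 0 }
    }

    /// Sorts highest logit first, but only if the order may be stale.
    /// A stable sort keeps ties in their original order (reproducible).
    fn ensure_sorted(&mut self) -> &[Logit] {
        if !self.sorted {
            self.items.sort_by(|a, b| b.logit.total_cmp(&a.logit));
            self.sorted = true;
            self.sort_count += 1;
        }
        &self.items
    }

    /// Mutable access (e.g. a repetition penalty adjusting logits) may
    /// invalidate the order, so the flag is cleared.
    fn items_mut(&mut self) -> &mut Vec<Logit> {
        self.sorted = false;
        &mut self.items
    }
}

fn main() {
    let mut buf = LogitBuf::new(vec![
        Logit { token_id: 0, logit: 0.1 },
        Logit { token_id: 1, logit: 0.9 },
    ]);
    buf.ensure_sorted(); // sorts
    buf.ensure_sorted(); // no-op: already sorted
    buf.items_mut()[1].logit = 5.0; // token 0 now has the top logit
    let top = buf.ensure_sorted()[0].token_id; // sorts again
    println!("top token {top}, sorts performed: {}", buf.sort_count);
    // prints: top token 0, sorts performed: 2
}
```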

I'm not 100% sure about your use case, but if you're using a chain of samplers that runs softmax when you don't actually need it, it may be possible to optimize things that way. If you don't mind answering, I'm curious about the sampler chain you were using when you noticed softmax accounting for so much time.

ealmloff commented 8 months ago

I'm not 100% sure about your use case, but if you're using a chain of samplers that runs softmax when you don't actually need it, it may be possible to optimize things that way. If you don't mind answering, I'm curious about the sampler chain you were using when you noticed softmax accounting for so much time.

I'm using this chain of samplers:

SamplerChainBuilder::from([
    (
        "repetition",
        SamplerSlot::new_chain(
            move || {
                Box::new(
                    SampleRepetition::default()
                        .penalty(repetition_penalty)
                        .last_n(repetition_penalty_range as usize),
                )
            },
            [],
        ),
    ),
    (
        "freqpresence",
        SamplerSlot::new_chain(
            move || Box::new(SampleFreqPresence::default().last_n(64)),
            [],
        ),
    ),
    (
        "seqrepetition",
        SamplerSlot::new_chain(move || Box::<SampleSeqRepetition>::default(), []),
    ),
    (
        "mirostat2",
        SamplerSlot::new_single(
            move || Box::new(SampleMirostat2::default().tau(tau).eta(eta).mu(mu)),
            Option::<SampleTopK>::None,
        ),
    ),
    (
        "temperature",
        SamplerSlot::new_single(
            move || Box::new(SampleTemperature::default().temperature(temperature)),
            Option::<SampleTemperature>::None,
        ),
    ),
    (
        "randdistrib",
        SamplerSlot::new_static(|| Box::<SampleRandDistrib>::default()),
    ),
]).into_chain()

KerfuffleV2 commented 8 months ago

Hmm, that might not be working the way you want (also possible I'm crazy and forgot how my own stuff works). It looks like you have Mirostat, then temperature, then random distribution sampling in the chain. However, Mirostat basically does top-k and then picks a token. So it appears you're not using the token the mirostat sampler would pick in the case where Mirostat is enabled. Also it seems strange to run temperature after Mirostat already picked a token.

There's some info about suggested chains/ordering here: https://docs.rs/llm-samplers/latest/llm_samplers/#suggested-chainsordering. Note that I'm definitely not a sampler expert. I basically just ported the algorithms from llama.cpp, and the recommendations there follow what llama.cpp does as well.

Basically, if you use Mirostat 1 or 2, you want to make sure you run that second set of samplers in the order shown (other samplers are considered incompatible). If you don't use Mirostat, then you want to run the first set of samplers shown, in that order. The Mirostat samplers pick a token, so you don't want to have a rand distrib sampler in that case.

This is annoying, right? Because you have to build different chains depending on whether using Mirostat or not. Unfortunately that's just how it is, although it's a pain. The good news is, it basically already has been dealt with by yours truly when I added support for llm-samplers to the llm project over here: https://github.com/rustformers/llm/blob/main/crates/llm-base/src/samplers.rs
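The branching described above can be sketched with a hypothetical `Stage` enum standing in for real samplers. This is not the llm-samplers API; the `samplers.rs` file linked above is the real reference. The filtering samplers recommended for the non-Mirostat path are omitted. The sketch only shows that exactly one token-picking sampler ends the chain.

```rust
/// Hypothetical stand-ins for real samplers; not the llm-samplers API.
#[derive(Debug, Clone, PartialEq)]
enum Stage {
    Repetition,
    FreqPresence,
    SeqRepetition,
    Temperature(f32),
    Mirostat2 { tau: f32, eta: f32 },
    RandDistrib,
}

#[derive(Debug, Clone, Copy)]
struct Mirostat2Params {
    tau: f32,
    eta: f32,
}

/// Builds a chain that always ends in exactly one token-picking stage:
/// Mirostat 2 when it is configured, random-distribution sampling otherwise.
fn build_chain(temperature: f32, mirostat: Option<Mirostat2Params>) -> Vec<Stage> {
    // Penalty samplers are shared by both paths.
    let mut chain = vec![Stage::Repetition, Stage::FreqPresence, Stage::SeqRepetition];
    match mirostat {
        Some(m) => {
            chain.push(Stage::Temperature(temperature));
            // Mirostat picks the token itself, so no RandDistrib afterwards.
            chain.push(Stage::Mirostat2 { tau: m.tau, eta: m.eta });
        }
        None => {
            // Filtering samplers (top-k, top-p, ...) would go here; omitted.
            chain.push(Stage::Temperature(temperature));
            chain.push(Stage::RandDistrib);
        }
    }
    chain
}

fn main() {
    println!("{:?}", build_chain(0.8, Some(Mirostat2Params { tau: 5.0, eta: 0.1 })));
    println!("{:?}", build_chain(0.8, None));
}
```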

By the way, you may want to change this to avoid future confusion:

        SamplerSlot::new_single(
            move || Box::new(SampleMirostat2::default().tau(tau).eta(eta).mu(mu)),
            Option::<SampleTopK>::None,
        ),

It actually works because that parameter accepts dyn BuildableSampler, but that's a case of poor design on my part: you shouldn't be able to pass a different type there. (In case it's not obvious what I'm talking about, that's a SampleMirostat2 slot, but the None is an Option<SampleTopK>.)
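One way to turn that mistake into a compile error is sketched below with hypothetical stand-in types, not the crate's actual `SamplerSlot`. The constructor is generic over a single concrete sampler type `S`, so the builder and the initial value must agree.

```rust
/// Hypothetical, minimal stand-ins; not the llm-samplers types.
trait Sampler {
    fn name(&self) -> &'static str;
}

struct SampleMirostat2;
struct SampleTopK;

impl Sampler for SampleMirostat2 {
    fn name(&self) -> &'static str { "mirostat2" }
}
impl Sampler for SampleTopK {
    fn name(&self) -> &'static str { "topk" }
}

struct SamplerSlot {
    build: Box<dyn Fn() -> Box<dyn Sampler>>,
    current: Option<Box<dyn Sampler>>,
}

impl SamplerSlot {
    /// Generic over one concrete `S`: the builder's output and the initial
    /// value must be the same type, so a mismatched `None` is rejected.
    fn new_single<S, F>(build: F, initial: Option<S>) -> Self
    where
        S: Sampler + 'static,
        F: Fn() -> S + 'static,
    {
        Self {
            build: Box::new(move || Box::new(build()) as Box<dyn Sampler>),
            current: initial.map(|s| Box::new(s) as Box<dyn Sampler>),
        }
    }
}

fn main() {
    let slot = SamplerSlot::new_single(|| SampleMirostat2, None::<SampleMirostat2>);
    let built = (slot.build)();
    println!("built: {}, preset: {}", built.name(), slot.current.is_some());
    // This would now fail to compile with a type mismatch:
    // SamplerSlot::new_single(|| SampleMirostat2, Option::<SampleTopK>::None);
    let topk = SamplerSlot::new_single(|| SampleTopK, Some(SampleTopK));
    println!("topk preset: {:?}", topk.current.as_ref().map(|s| s.name()));
}
```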

ealmloff commented 8 months ago

Oops, thanks for letting me know! I changed it to this:

SamplerChainBuilder::from([
    (
        "repetition",
        SamplerSlot::new_chain(
            move || {
                Box::new(
                    SampleRepetition::default()
                        .penalty(repetition_penalty)
                        .last_n(repetition_penalty_range as usize),
                )
            },
            [],
        ),
    ),
    (
        "freqpresence",
        SamplerSlot::new_chain(
            move || Box::new(SampleFreqPresence::default().last_n(64)),
            [],
        ),
    ),
    (
        "seqrepetition",
        SamplerSlot::new_chain(move || Box::<SampleSeqRepetition>::default(), []),
    ),
    (
        "temperature",
        SamplerSlot::new_single(
            move || Box::new(SampleTemperature::default().temperature(temperature)),
            Option::<SampleTemperature>::None,
        ),
    ),
    (
        "mirostat2",
        SamplerSlot::new_single(
            move || Box::new(SampleMirostat2::default().tau(tau).eta(eta).mu(mu)),
            Option::<SampleMirostat2>::None,
        ),
    ),
])
.into_chain()

KerfuffleV2 commented 8 months ago

That works as long as the Mirostat option is always specified, but if it's not, that slot will be empty and you'll have no token-picking sampler. That's basically the issue I was talking about: you need to build a different chain depending on whether there's Mirostat or not. If Mirostat is enabled, you want the Mirostat sampler at the end (and no other token-picking sampler). Otherwise you want the rand distrib sampler at the end most likely.

Unfortunately, this part of the crate isn't really documented very well. Reading my own docs trying to remember what I did, I can see it definitely falls short.

edit: If you want a sort of hacky approach that probably works in this case (as opposed to the more complete stuff from llm) then I think the items in the chain implement HasSamplerMetaData so you could check if the last item is named mirostat (whatever the exact name is) and if not then you know Mirostat didn't get specified so you can push in a rand distrib sampler.
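Here is a rough sketch of that hacky check. The `NamedSampler` trait and its `name()` method are hypothetical stand-ins for whatever metadata the real `HasSamplerMetaData` trait exposes. Matching on a `"mirostat"` name prefix is also an assumption about the naming.

```rust
/// Hypothetical metadata trait; not the real HasSamplerMetaData API.
trait NamedSampler {
    fn name(&self) -> &str;
}

/// Hypothetical placeholder sampler that only carries a name.
struct Named(&'static str);

impl NamedSampler for Named {
    fn name(&self) -> &str {
        self.0
    }
}

/// If the chain does not already end with a Mirostat sampler, append a
/// random-distribution sampler so that something still picks the token.
fn ensure_token_picker(chain: &mut Vec<Box<dyn NamedSampler>>) {
    let ends_with_mirostat = chain
        .last()
        .map(|s| s.name().starts_with("mirostat"))
        .unwrap_or(false);
    if !ends_with_mirostat {
        chain.push(Box::new(Named("randdistrib")));
    }
}

fn names(chain: &[Box<dyn NamedSampler>]) -> Vec<&str> {
    chain.iter().map(|s| s.name()).collect()
}

fn main() {
    let mut with_mirostat: Vec<Box<dyn NamedSampler>> =
        vec![Box::new(Named("temperature")), Box::new(Named("mirostat2"))];
    let mut without: Vec<Box<dyn NamedSampler>> = vec![Box::new(Named("temperature"))];
    ensure_token_picker(&mut with_mirostat);
    ensure_token_picker(&mut without);
    println!("{:?}", names(&with_mirostat));
    println!("{:?}", names(&without));
}
```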

Since you're trying to do configurable plugin type stuff, you'll probably need to resort to the full approach sooner or later though.

KerfuffleV2 commented 8 months ago

^^^^ Just in case you didn't see the edit there with a possible (fairly) easy solution. (If you already saw it, you can just disregard this.)

ealmloff commented 8 months ago

edit: If you want a sort of hacky approach that probably works in this case (as opposed to the more complete stuff from llm) then I think the items in the chain implement HasSamplerMetaData so you could check if the last item is named mirostat (whatever the exact name is) and if not then you know Mirostat didn't get specified so you can push in a rand distrib sampler.

Thanks, I just made the samplers static for now. Currently this is only for the simplest Rust API (not integrated with the UI yet). If I expose samplers in the UI, I will probably end up stealing some of your code in llm.rs. Kalosm currently supports three versions of the API: 1) the default Mirostat version with some very basic configuration, 2) the user supplies their own sampler implementation (which could be built from a new sampler chain), and 3) the user supplies a parser that the output must conform to. I have a wrapper that converts a parser to a sampler, but as the context size grows, parsing the entire text generated up to the current point becomes a bottleneck, so I need to either introduce a stateful sampler or stick with the more raw API I currently use. I know llm.rs plans to support the same kind of structured generation at some point (https://github.com/rustformers/llm/issues/235). If structured generation is something you would be interested in me upstreaming into this library, let me know and I can open an issue to discuss the API.
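The toy sketch below illustrates the stateful-sampler idea. All names are hypothetical; this is not Kalosm's or llm-samplers' API. The "parser" accepts balanced parentheses and keeps its nesting depth between steps. Each candidate token is checked against that saved state instead of re-parsing everything generated so far. For simplicity, it picks greedily among the allowed tokens rather than sampling.

```rust
/// Toy incremental parser state for balanced parentheses: just the depth.
#[derive(Debug, Clone, Copy, Default)]
struct ParenState {
    depth: u32,
}

impl ParenState {
    /// Try to extend the state with `text`; None if it would become invalid.
    /// Cost is O(len(text)), independent of how much was generated before.
    fn advance(self, text: &str) -> Option<ParenState> {
        let mut depth = self.depth;
        for c in text.chars() {
            match c {
                '(' => depth += 1,
                ')' => depth = depth.checked_sub(1)?,
                _ => return None,
            }
        }
        Some(ParenState { depth })
    }
}

/// Hypothetical stateful "parser sampler": skips tokens the parser rejects,
/// greedily picks the highest-logit allowed token, and stores the new state.
struct ParserSampler {
    state: ParenState,
}

impl ParserSampler {
    fn sample(&mut self, vocab: &[&str], logits: &[f32]) -> Option<usize> {
        let mut best: Option<(usize, f32, ParenState)> = None;
        for (id, (&text, &logit)) in vocab.iter().zip(logits).enumerate() {
            if let Some(next) = self.state.advance(text) {
                if best.map_or(true, |(_, l, _)| logit > l) {
                    best = Some((id, logit, next));
                }
            }
        }
        let (id, _, next) = best?;
        self.state = next; // remember progress; no re-parse on the next step
        Some(id)
    }
}

fn main() {
    let vocab = ["(", ")", "x", "()"];
    // "x" has the highest logit but is never valid for this grammar.
    let logits = [1.0f32, 2.0, 5.0, 0.5];
    let mut sampler = ParserSampler { state: ParenState::default() };
    let mut out = String::new();
    for _ in 0..4 {
        match sampler.sample(&vocab, &logits) {
            Some(id) => out.push_str(vocab[id]),
            None => break,
        }
    }
    println!("{out} (open parens: {})", sampler.state.depth);
}
```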

KerfuffleV2 commented 8 months ago

If structured generation is something you would be interested in me upstreaming into this library

Definitely! This actually ties into something else: the poor state of the sampler resource system. You'd really want a good way to store your parser state (so you can avoid reparsing everything repeatedly like you said).

I started trying to revamp it, but ran into issues and kind of gave up. I'll write you a better response in a few days (please ping me if you don't hear back). Even before you said that, I was already thinking that, with your Dioxus experience, you'd probably have some good ideas for how to solve that (assuming you'd be interested in helping with or contributing to that kind of thing).