pixelspark opened 1 year ago
Having looked at this some more, one other issue I encounter with the current TopPTopK sampler is that the biases provided actually replace the predicted logits, whereas for biasing towards a specific output (e.g. ASCII-only) you would want to fully remove certain tokens from the set (and/or possibly adjust logits by adding/subtracting). I am probably better off re-implementing the TopPTopK sampler for this use case.
The above proposal could still be useful, but then there should be a way to indicate how the bias should be applied (either 'override', 'scale' or 'add'?).
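For illustration, a minimal sketch of what such a parameter could look like (the BiasMode enum and apply_bias helper are hypothetical names for this sketch, not part of the existing sampler API):

```rust
/// Hypothetical ways a per-token bias could be combined with the predicted logit.
enum BiasMode {
    /// Replace the predicted logit with the bias value (the current behaviour).
    Override,
    /// Multiply the predicted logit by the bias value.
    Scale,
    /// Add the bias value to the predicted logit.
    Add,
}

/// Apply a bias to a single logit according to the chosen mode.
fn apply_bias(logit: f32, bias: f32, mode: &BiasMode) -> f32 {
    match mode {
        BiasMode::Override => bias,
        BiasMode::Scale => logit * bias,
        BiasMode::Add => logit + bias,
    }
}
```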
This seems pretty reasonable to me. The current token bias implementation was added for parity with llama.cpp upstream a few months ago; we can definitely parameterise it better.
What I'm thinking about is replacing the current TopPTopK sampler with one that's composed of modular pieces that can be swapped out - that is, instead of baking in the repetition penalty and bias implementations, they can be swapped out or removed entirely to meet your requirements. That would also help address the naming issue we have here - you can swap in a biaser that overrides, scales or adds as required.
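A rough sketch of that kind of composition, assuming a made-up LogitTransform trait (none of these names are the crate's actual API):

```rust
/// Hypothetical trait for a swappable piece of the sampling pipeline,
/// e.g. a repetition penalty or a token biaser.
trait LogitTransform {
    /// Adjust the raw logits in place before top-k/top-p selection.
    fn transform(&self, previous_tokens: &[u32], logits: &mut [f32]);
}

/// A sampler assembled from interchangeable parts instead of baked-in behaviour.
struct ComposableSampler {
    /// Applied in order; leave empty to sample from the raw logits.
    transforms: Vec<Box<dyn LogitTransform>>,
    top_k: usize,
    top_p: f32,
}

impl ComposableSampler {
    fn prepare_logits(&self, previous_tokens: &[u32], logits: &mut [f32]) {
        for transform in &self.transforms {
            transform.transform(previous_tokens, logits);
        }
        // ...then run the usual top-k/top-p selection over the adjusted logits.
    }
}
```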
I'm still on holiday and not able to commit any substantive design time to this, but that's where my mind's at. Would that work for your use case?
Well, eventually I just fully duplicated the TopPTopK sampler and added in my changes, and it works great (the samplers themselves are already easily swappable, which is great)! So my use case is not as important anymore, but I guess making the API a bit smoother would still be a good thing.
In some cases it is useful to bias a model's output towards a large set of tokens (e.g. all tokens that are ASCII-only). Currently you'd have to filter all tokens and create a Vec<(TokenId, f32)> containing a potentially large number of elements.

My suggested solution would be to replace TokenBias with a trait that allows the user to 'filter' tokens. A sample implementation allowing only ASCII is sketched below.
A default implementation could be provided that accepts a Vec<(TokenId, f32)> and simply looks up the associated bias value (using a HashMap behind the scenes would be a bit faster, I presume). The sampler need only call the bias_for_token function once for each unique TokenId it encounters (and could possibly cache the result).
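Building on the trait sketched above, a possible default implementation backed by a HashMap could look like this (again, the concrete names are illustrative):

```rust
use std::collections::HashMap;

/// Possible default implementation: explicit per-token biases built from a
/// `Vec<(TokenId, f32)>`, stored in a HashMap for fast lookup.
struct ExplicitTokenBias {
    biases: HashMap<TokenId, f32>,
}

impl ExplicitTokenBias {
    fn new(pairs: Vec<(TokenId, f32)>) -> Self {
        Self {
            biases: pairs.into_iter().collect(),
        }
    }
}

impl TokenBias for ExplicitTokenBias {
    fn bias_for_token(&self, token_id: TokenId, _token_bytes: &[u8]) -> Option<f32> {
        // Look up the bias for this token, if one was specified.
        self.biases.get(&token_id).copied()
    }
}
```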