ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Classification Example #894

Closed: vjagannath786 closed this 1 month ago

vjagannath786 commented 1 month ago

Looking for some examples where LLMs like Llama or Gemma are used for classification, and also whether they can return the logits.

awni commented 1 month ago

I don't know of any examples where they are used for classification, but if any come across our path we will add them here.
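One common way to use a causal LM as a classifier (not something proposed in this thread) is to prompt it with the input text and compare the model's next-token log probabilities for each candidate label. The plain-Python sketch below covers only that scoring step. It assumes each label maps to a single token id and that you already have the logits for the last prompt position (for an mlx_lm model, something like `model(prompt_tokens[None])[0, -1]`). `classify_from_logits` and the toy token ids are hypothetical.

```python
import math
from typing import Dict, List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def classify_from_logits(
    last_logits: Sequence[float], label_token_ids: Dict[str, int]
) -> Dict[str, float]:
    """Hypothetical helper: score each label by its next-token probability.

    Assumes every label is a single token id in the vocabulary.
    """
    logprobs = log_softmax(last_logits)
    label_lp = {label: logprobs[tid] for label, tid in label_token_ids.items()}
    # Renormalize over the candidate labels only, so the scores sum to 1.
    norm = log_softmax(list(label_lp.values()))
    return {label: math.exp(lp) for label, lp in zip(label_lp, norm)}


# Toy 5-token vocabulary; token 1 stands for "positive", token 3 for "negative".
scores = classify_from_logits([0.1, 2.0, -1.0, 0.5, 1.0], {"positive": 1, "negative": 3})
print(max(scores, key=scores.get), scores)
```

Renormalizing over the labels ignores probability mass the model puts on unrelated tokens, which is usually what you want when the label set is fixed.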

You can get the log probabilities (logprobs) by calling generate_step directly:

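```python
from typing import Union

import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer

from mlx_lm.tokenizer_utils import TokenizerWrapper
# generate_step lives in mlx_lm.utils in the mlx_lm version this was written
# against; newer releases may define it in a different module.
from mlx_lm.utils import generate_step


def generate_with_logits(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    **kwargs,
):
    """Generate text and return it along with the logprobs of every step."""
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    detokenizer.reset()
    all_logprobs = []

    # generate_step yields (token, logprobs) pairs; zipping with range()
    # caps generation at max_tokens.
    for (token, logprobs), n in zip(
        generate_step(prompt_tokens, model, **kwargs),
        range(max_tokens),
    ):
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)
        # Keep this step's logprobs over the whole vocabulary.
        all_logprobs.append(logprobs)

    detokenizer.finalize()

    return detokenizer.text, all_logprobs
```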
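`generate_with_logits` returns log probabilities rather than raw logits. Assuming each step's logprobs are the log-softmax of that step's logits (`logits - logsumexp(logits)`), the raw logits can only be recovered up to an additive constant per step, but that constant cancels in softmax and argmax, so logprobs are enough for ranking or classifying. The plain-Python sketch below illustrates this and shows how to sum the logprobs of the chosen tokens into a sequence score, assuming you also collect the generated token ids in the loop. `log_softmax` and `sequence_logprob` are hypothetical helpers, not part of mlx_lm.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: x - logsumexp(x).
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def sequence_logprob(
    step_logprobs: Sequence[Sequence[float]], token_ids: Sequence[int]
) -> float:
    # Sum the log probability each step assigned to the token actually chosen.
    return sum(lp[t] for lp, t in zip(step_logprobs, token_ids))


# Shifting every logit by the same constant leaves the logprobs unchanged.
a = log_softmax([1.0, 2.0, 3.0])
b = log_softmax([11.0, 12.0, 13.0])
print(all(abs(x - y) < 1e-9 for x, y in zip(a, b)))
```

With MLX arrays you would index each step's logprobs vector the same way; the plain lists here just keep the sketch self-contained.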