keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Add scoring mode to MistralCausalLM #1521

Closed RyanMullins closed 5 months ago

RyanMullins commented 5 months ago

Adds the .score() function introduced with Gemma (https://github.com/keras-team/keras-nlp/pull/1448) to the MistralCausalLM model class. As with Gemma, this function supports a variety of interpretability use cases with Mistral by providing an API for scoring generated sequences (as logits or as loss) with gradient tracking enabled. Use cases include salience maps, patching, and training data attribution.
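To make the two scoring outputs concrete, here is a minimal pure-Python sketch of what scoring a sequence "as logits" versus "as loss" means. It is an illustration only, not the keras-nlp implementation: `score_sequence`, `log_softmax`, and the `scoring_mode` argument are hypothetical names chosen for this example, and the real method operates on model tensors with gradient tracking rather than on Python lists.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def score_sequence(logits, target_ids, scoring_mode="logits"):
    """Hypothetical helper: score a sequence from precomputed logits.

    logits: list of per-position rows, each of length vocab_size.
    target_ids: list of target token ids, one per position.
    """
    if scoring_mode == "logits":
        # Return the raw per-position logits unchanged.
        return logits
    if scoring_mode == "loss":
        if len(logits) != len(target_ids):
            raise ValueError("logits and target_ids must have equal length")
        # Per-token cross-entropy: negative log-probability of each target.
        return [-log_softmax(row)[t] for row, t in zip(logits, target_ids)]
    raise ValueError(f"unknown scoring_mode: {scoring_mode!r}")


# Toy example: 2 positions, vocabulary of 4 tokens.
example_logits = [[0.0, 0.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]]
example_targets = [1, 0]
per_token_loss = score_sequence(example_logits, example_targets, "loss")
```

In an interpretability workflow, a per-token loss of this kind would be differentiated with respect to the input embeddings to build a salience map, which is why the PR keeps gradient tracking on.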

This is a direct port of the implementation and tests from Gemma, so hopefully that helps ease the review process.

mattdangerw commented 5 months ago

Thank you!