keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

Promote the local variable per_token_loss_fn in the score method to an instance attribute, so that the loss function can be modified #1539

Closed deveshklt closed 3 months ago

deveshklt commented 3 months ago

Is your feature request related to a problem? Please describe.

I want to use a custom loss function for per_token_loss_fn, so I'd like it to be accessible outside the class for modification.

Describe the solution you'd like

Change this:

per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)

to this:

self.per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)

Describe alternatives you've considered

Additional context
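A hypothetical sketch of how the requested attribute could be used once exposed (illustration only; the attribute does not exist today, and the preset and ignore_class value are placeholders):

import keras
import keras_nlp

lm = keras_nlp.models.CausalLM.from_preset("some_preset")

# Hypothetical: swap in a variant loss before calling score(),
# e.g. one that skips padding tokens (assuming pad id 0).
lm.per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none", ignore_class=0
)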

mattdangerw commented 3 months ago

@deveshklt can you explain more of your overall use case? E.g. what other loss you would like to pass, and for what end purpose?

The score() function is new functionality being added by @RyanMullins, in particular for interpretability applications (though it's a general function with many uses). Feedback is welcome as we build it out, but understanding the user journeys here will help.

deveshklt commented 3 months ago

My use case is comparing two sentences and calculating the divergence between them. I want to use Kullback-Leibler (KL) divergence loss for this, which means using a different per-token loss function.

RyanMullins commented 3 months ago

Hi @deveshklt! As @mattdangerw mentioned, I've added a .score() API to KerasNLP's implementations of Gemma, Llama, Mistral, and GPT-2. This function is inspired by the scoring mode in Google's T5X modeling framework (GitHub, paper); you provide a tokenized representation of a sequence, and this API computes either the logits or the per-token loss for that sequence from the model, depending on the value of scoring_mode. If run in "logits" mode (the default), you can compute any custom loss you like from the tensor this API returns (an intentional design choice to support use cases like yours).

import keras
import keras_nlp

lm = keras_nlp.models.CausalLM.from_preset("some_preset")

generations = ... # I assume you already have a list of strings here.

preprocessed = lm.preprocessor.generate_preprocess(generations)
generation_ids = preprocessed["token_ids"]
padding_mask = preprocessed["padding_mask"]

logits = lm.score(
    token_ids=generation_ids,
    padding_mask=padding_mask
)

model_loss_fn = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)
# Compute loss with model_loss_fn

kldiv_loss_fn = keras.losses.KLDivergence(...)
# Compute loss with kldiv_loss_fn

# ...and so on with the other loss functions you're exploring.
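As an aside, if the default per-token cross-entropy is all you need, scoring_mode="loss" returns it directly. A minimal sketch, assuming the signature's target_ids argument and next-token targets built by shifting the sequence with keras.ops.roll:

# Targets are the next token at each position.
target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)

per_token_loss = lm.score(
    token_ids=generation_ids,
    padding_mask=padding_mask,
    scoring_mode="loss",
    target_ids=target_ids,
)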

I assume that since you're interested in KL divergence, you have a dataset with some ground truth that you can use as the value of y_true. In that case, you can use the .score() API to compute the logits for the ground-truth and generation sequences and then pass those into the loss function as, for example, kldiv_loss_fn(gt_logits, gen_logits).
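For instance, a minimal sketch of that comparison (assuming ground_truths is your list of reference strings and that both batches tokenize to the same length; note that keras.losses.KLDivergence expects probability distributions, so the logits are softmaxed first):

gt_preprocessed = lm.preprocessor.generate_preprocess(ground_truths)
gt_logits = lm.score(
    token_ids=gt_preprocessed["token_ids"],
    padding_mask=gt_preprocessed["padding_mask"],
)

# KLDivergence compares probability distributions, so convert logits first.
gt_probs = keras.ops.softmax(gt_logits, axis=-1)
gen_probs = keras.ops.softmax(logits, axis=-1)

# reduction="none" keeps a per-token divergence of shape (batch, seq_len).
kldiv_loss_fn = keras.losses.KLDivergence(reduction="none")
per_token_kl = kldiv_loss_fn(gt_probs, gen_probs)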

deveshklt commented 3 months ago

Thank you @RyanMullins for the explanation.