nlp-with-transformers / notebooks

Jupyter notebooks for the Natural Language Processing with Transformers book
https://transformersbook.com/
Apache License 2.0

How to get the loss returned along with the logits in TensorFlow #79

Open tansaku opened 1 year ago

tansaku commented 1 year ago

Information

The question or comment is about chapter: 2 (the "Error Analysis" section)

Question or comment

The book shows a really interesting example of getting the loss returned along with the predicted label, in the "Error Analysis" section of chapter 2:

Before moving on, we should investigate our model’s predictions a little bit further. A simple yet powerful technique is to sort the validation samples by the model loss. When we pass the label during the forward pass, the loss is automatically calculated and returned. Here’s a function that returns the loss along with the predicted label:

import torch
from torch.nn.functional import cross_entropy

def forward_pass_with_label(batch):
    # Place all input tensors on the same device as the model
    inputs = {k:v.to(device) for k,v in batch.items()
              if k in tokenizer.model_input_names}

    with torch.no_grad():
        output = model(**inputs)
        pred_label = torch.argmax(output.logits, axis=-1)

        loss = cross_entropy(output.logits, batch["label"].to(device),
                             reduction="none")
    # Place outputs on CPU for compatibility with other dataset columns
    return {"loss": loss.cpu().numpy(),
            "predicted_label": pred_label.cpu().numpy()}

Does anyone have any idea how to do something similar with a TensorFlow-based approach?

I've been reading the documentation for the TF model predict function (https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict), but I can't immediately see anything that would correspond to the same behaviour. Although, to be honest, I'm not really following what the example code is doing...

Is it taking the validation dataset, re-predicting the output for each item in it, and then calculating the loss as the cross-entropy of the output logits against what the correct label should have been...?

So, to do that with TF, I'd need to take the validation set and do something similar...?
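
For what it's worth, here is a minimal sketch of one way this could look in TensorFlow. It is not from the book: it assumes model is a Hugging Face TF*ForSequenceClassification model and tokenizer its matching tokenizer (the TF variants of the chapter's setup), and the helper name simply mirrors the PyTorch version:

import tensorflow as tf

def tf_forward_pass_with_label(batch):
    # Keep only the tensors the model expects as inputs
    inputs = {k: v for k, v in batch.items()
              if k in tokenizer.model_input_names}
    # Forward pass in inference mode; HF TF models expose .logits
    logits = model(inputs, training=False).logits
    pred_label = tf.argmax(logits, axis=-1)
    # Per-example loss, analogous to reduction="none" in PyTorch
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        batch["label"], logits, from_logits=True)
    # Return NumPy arrays for compatibility with other dataset columns
    return {"loss": loss.numpy(),
            "predicted_label": pred_label.numpy()}

Here sparse_categorical_crossentropy with from_logits=True plays the role of cross_entropy(..., reduction="none"): it returns one loss value per example, which you can then use to sort the validation samples by loss, just as in the PyTorch version.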