huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

compute_metric(eval_pred) in trainer is not mini-batch #31667

Open SamYuen101234 opened 1 month ago

SamYuen101234 commented 1 month ago

I am trying to implement a custom `compute_metrics` function for `Trainer`. The logits and labels it receives are NumPy arrays over the full evaluation set, and my evaluation logits have shape (1000, 43, 50257), so the computation can't fit on a 24 GB L4 GPU on Colab. Is there any way to process the data in mini-batches, e.g. with a dataloader, instead of receiving one full NumPy array?

```python
import numpy as np
from datasets import load_metric

# eval_pred contains the full evaluation set, not a single mini-batch
def compute_metrics(eval_pred):
    accuracy_metric = load_metric("accuracy")
    logits, labels = eval_pred

    # Get predictions (next-word prediction)
    predictions = np.argmax(logits, axis=-1)

    # Shift labels to the left
    labels_shifted = labels[:, 1:].flatten()
    predictions_shifted = predictions[:, :-1].flatten()

    # Create an attention mask based on labels (assuming -100 is padding)
    attention_mask_shifted = (labels[:, 1:] != -100).flatten()

    # Remove padding tokens using the attention mask
    predictions_shifted = predictions_shifted[attention_mask_shifted]
    labels_shifted = labels_shifted[attention_mask_shifted]

    # Compute accuracy
    if len(labels_shifted) == 0:
        return {"accuracy": 0.0}
    accuracy = accuracy_metric.compute(predictions=predictions_shifted, references=labels_shifted)
    return {"accuracy": accuracy["accuracy"]}
```
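One way to avoid materializing the full `(num_examples, seq_len, vocab_size)` logits array is `Trainer`'s `preprocess_logits_for_metrics` argument, which runs per evaluation batch and lets you keep only the argmax token ids before accumulation. A minimal sketch, assuming your transformers version supports this argument (the accuracy logic below is an illustration, not the exact code above; `logits.argmax(-1)` works for both torch tensors and NumPy arrays):

```python
import numpy as np

def preprocess_logits_for_metrics(logits, labels):
    # Runs once per evaluation batch: keep only the predicted token ids,
    # shrinking (batch, seq_len, vocab_size) down to (batch, seq_len).
    return logits.argmax(-1)

def compute_metrics(eval_pred):
    # eval_pred.predictions now holds token ids, not logits
    predictions, labels = eval_pred
    # Position t predicts token t+1; ignore -100 padding positions
    mask = labels[:, 1:] != -100
    correct = (predictions[:, :-1] == labels[:, 1:]) & mask
    total = mask.sum()
    return {"accuracy": float(correct.sum() / total) if total else 0.0}

# trainer = Trainer(..., compute_metrics=compute_metrics,
#                   preprocess_logits_for_metrics=preprocess_logits_for_metrics)
```

With this, only `(num_examples, seq_len)` integer arrays reach `compute_metrics`, which is ~50,000x smaller than the raw logits for a 50,257-token vocabulary.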
amyeroberts commented 1 month ago

Hi @SamYuen101234, thanks for raising an issue!

This question is best placed in our forums; we try to reserve GitHub issues for feature requests and bug reports.

github-actions[bot] commented 19 hours ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.