huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

How to obtain batch index of validation dataset? #33228

Open SoumiDas opened 2 weeks ago

SoumiDas commented 2 weeks ago

Hi,

How can I get the batch ID/index of the eval dataset inside preprocess_logits_for_metrics()?

Thanks in advance!

nbroad1881 commented 1 week ago

You could put a global variable that keeps track of how many samples have been seen in preprocess_logits_for_metrics.


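counter = 0  # number of eval samples seen so far
bs = 16      # eval batch size

def preprocess_logits_for_metrics(logits, labels):
    global counter

    # Index of the current batch, taken before its samples are counted
    batch_idx = counter // bs
    counter += logits.shape[0]

    # The Trainer expects the (possibly processed) logits back
    return logits

The batch index is the integer division counter // bs, computed before the current batch is added to counter. You would also need to reset counter after completing evaluation, otherwise the next evaluation run starts from a stale value.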
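
If the global variable feels fragile, the same idea can be wrapped in a small callable object. This is only a sketch: BatchIndexTracker, samples_seen, and batch_indices are hypothetical names, not part of transformers, and it assumes every batch except possibly the last has exactly batch_size samples on the current process.

```python
class BatchIndexTracker:
    """Hypothetical helper: counts eval samples to derive the current batch index."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.samples_seen = 0
        self.batch_indices = []  # one entry per call, kept for inspection

    def reset(self):
        # Call before (or after) each evaluation run so indices start at 0 again.
        self.samples_seen = 0
        self.batch_indices = []

    def __call__(self, logits, labels):
        # Assumption: if the model returns a tuple, its first element holds the logits.
        if isinstance(logits, tuple):
            logits = logits[0]

        # Index of the current batch, computed before counting its samples.
        batch_idx = self.samples_seen // self.batch_size
        self.batch_indices.append(batch_idx)
        self.samples_seen += logits.shape[0]

        # Return the logits so they can still be used for metrics.
        return logits
```

An instance could then be passed as preprocess_logits_for_metrics=tracker when building the Trainer, with tracker.reset() called around each trainer.evaluate(). In a multi-process setup each process would presumably count only its own batches, so batch_size here would be the per-device value.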