SoumiDas opened this issue 2 weeks ago
You could use a global variable in `preprocess_logits_for_metrics` that keeps track of how many samples have been seen so far:
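```python
counter = 0  # number of eval samples seen so far
bs = 16      # must match per_device_eval_batch_size

def preprocess_logits_for_metrics(logits, labels):
    global counter
    # Index of the current batch, computed before counting its samples
    # (only the last batch can be smaller than bs, so this stays correct).
    batch_idx = counter // bs
    counter += logits.shape[0]
    return logits  # Trainer expects the (possibly processed) logits back
```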
The batch index is derived from the value of `counter`. You would also need to reset `counter` to 0 after evaluation finishes; otherwise the indices keep growing across evaluation runs.
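If a global variable feels fragile, the same counting can live in a small callable object with a reset method. This is a minimal sketch, assuming a fixed eval batch size equal to `per_device_eval_batch_size`, a single process (in distributed evaluation each process would keep its own count), and logits passed as a single array or tensor rather than a tuple. `BatchIndexTracker` is a hypothetical name, not part of transformers.

```python
class BatchIndexTracker:
    """Hypothetical helper: the counter logic from the reply above,
    kept in an object so it can be reset between evaluation runs."""

    def __init__(self, batch_size):
        # Assumed to equal per_device_eval_batch_size; only the last batch may be smaller.
        self.batch_size = batch_size
        self.samples_seen = 0
        self.batch_indices = []  # batch index recorded on every call

    def __call__(self, logits, labels):
        # Index of the batch being processed, computed before counting it.
        batch_idx = self.samples_seen // self.batch_size
        self.batch_indices.append(batch_idx)
        # len() returns the size of the first (batch) dimension for lists,
        # NumPy arrays and torch tensors alike.
        self.samples_seen += len(logits)
        return logits  # hand the logits back unchanged

    def reset(self):
        # Call once evaluation finishes so the next run starts at batch 0.
        self.samples_seen = 0
        self.batch_indices = []


# Usage, simulating three eval batches of 16, 16 and 5 samples:
tracker = BatchIndexTracker(batch_size=16)
for n in (16, 16, 5):
    tracker([[0.0]] * n, None)
print(tracker.batch_indices)  # [0, 1, 2]
tracker.reset()
```

An instance could be passed as `preprocess_logits_for_metrics=tracker` when constructing the `Trainer`, since that argument accepts any callable. One possible place to call `tracker.reset()` is at the end of your `compute_metrics` function, because `Trainer` normally calls it once after the evaluation loop has processed all batches (this may differ if per-batch metric computation is enabled).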
Hi,
I wanted to know how we can fetch the batch ID/index of the eval dataset in `preprocess_logits_for_metrics()`.
Thanks in advance!