```
    135 rel_indices = (num_relevant != 0).nonzero().squeeze()
    136 rel_count = num_relevant[rel_indices]
--> 138 if rel_indices.shape[0] > 0:
    139     for index, k in enumerate(ks):
    140         rel_labels = topk_labels[rel_indices, : int(k)]

IndexError: tuple index out of range
```
Goals :soccer:
Implementation Details :construction:
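https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/23d5e3ba73b3d490400e45bf4feba94ed473432f/transformers4rec/torch/ranking_metric.py#L134
`(num_relevant != 0).nonzero()` returns a tensor of shape `(n, 1)`, where `n` is the number of rows with at least one relevant item. Calling `squeeze()` without a `dim` argument removes every size-1 dimension, so when `n == 1` (for example, whenever the batch size is 1) the result is a 0-d tensor and `rel_indices.shape[0]` raises `IndexError: tuple index out of range`. I've added the `dim=1` argument so that only the second dimension is squeezed and the batch dimension is kept.
Testing Details :mag: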
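The shape behavior behind this fix can be checked in isolation with a small PyTorch sketch. It assumes `num_relevant` is a 1-D tensor of per-row relevant-item counts; the tensor values, `topk_labels` and `k` below are made up for illustration and are not the test code from this PR.

```python
import torch

# Hypothetical per-row counts of relevant items for a batch of size 1.
num_relevant = torch.tensor([3])

# nonzero() returns one row per non-zero entry and one column per input
# dimension, so for a 1-D input with one non-zero entry the shape is (1, 1).
nz = (num_relevant != 0).nonzero()

# squeeze() with no dim drops every size-1 dimension -> 0-d tensor.
old = nz.squeeze()
# squeeze(dim=1) drops only the column dimension -> shape (1,).
new = nz.squeeze(dim=1)

print(nz.shape, old.shape, new.shape)

# old.shape is an empty torch.Size (a tuple), so indexing it fails.
try:
    old.shape[0]
except IndexError as err:
    print("squeeze():", err)

print("squeeze(dim=1):", new.shape[0])

# The 1-D indices still select rows of a (hypothetical) top-k label matrix.
topk_labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
k = 2
rel_labels = topk_labels[new, :k]
print(rel_labels.shape)
```

With more than one non-zero row, both calls produce the same 1-D result; they only differ when exactly one row has relevant items.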
Test code for error reproduction
Error message