tsongsun opened 9 months ago
When I run main.py, I encounter an error. After printing the shapes of preds and labels, I found that the problem appears to come from the call confusion_matrix[phase].add(preds, labels).
It looks like torch.mean, called without a dim argument, reduces outputs to a 0-dimensional tensor, so preds ends up 0-dimensional as well, which triggers this error in ConfusionMeter. I changed the outputs line to
outputs = torch.mean(outputs).unsqueeze(0)
and it works now.
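A minimal sketch of the shape problem and the fix, assuming PyTorch is installed. The score values, the 0.5 threshold used to turn the averaged score into a binary class, and the add_to_confusion helper are all hypothetical illustrations; add_to_confusion is a simplified stand-in for a confusion meter's add(), not torchnet's actual ConfusionMeter code. It only shows why a meter that expects one prediction per sample along the first dimension rejects a 0-dimensional tensor, and how unsqueeze(0) restores that dimension.

```python
import torch

# Hypothetical scores for one sample (e.g. averaged over several crops).
outputs = torch.tensor([0.2, 0.9, 0.7])
label = torch.tensor([1])  # one ground-truth class index, shape (1,)

# With no `dim` argument, torch.mean reduces over all elements and
# returns a 0-dimensional tensor: shape torch.Size([]).
averaged = torch.mean(outputs)
print(averaged.shape)  # torch.Size([])

# unsqueeze(0) adds a leading dimension, so the result has shape (1,):
# one entry per sample, matching the labels.
fixed = torch.mean(outputs).unsqueeze(0)
print(fixed.shape)  # torch.Size([1])


def add_to_confusion(matrix, preds, labels):
    """Hypothetical stand-in for a confusion meter's add(); expects 1-D inputs."""
    if preds.dim() != 1 or labels.dim() != 1:
        raise ValueError(
            f"expected 1-D preds/labels, got {tuple(preds.shape)} and {tuple(labels.shape)}"
        )
    if preds.shape[0] != labels.shape[0]:
        raise ValueError("preds and labels must have the same length")
    for p, t in zip(preds.tolist(), labels.tolist()):
        matrix[t][p] += 1  # rows: true class, columns: predicted class


matrix = [[0, 0], [0, 0]]

# Hypothetical binary decision: threshold the averaged score at 0.5.
try:
    add_to_confusion(matrix, (averaged > 0.5).long(), label)
except ValueError as err:
    print("0-D preds rejected:", err)

# With the extra dimension, preds has shape (1,) and is accepted.
add_to_confusion(matrix, (fixed > 0.5).long(), label)
print(matrix)  # [[0, 0], [0, 1]]
```

If the goal is to average over one axis only (for example across crops) rather than over every element, passing dim and keepdim=True to torch.mean is another way to keep a leading dimension; which one is right depends on the shape of outputs in main.py.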