rookiecm closed this issue 3 years ago.
Thanks for sharing your code. I am a little confused about the accuracy computation in lines 69-70 of wrapper.py:
acc_i = (torch.argmax(image_logits) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t()) == ground_truth).sum()
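It seems that torch.argmax, when called without a dim argument, returns the index of the maximum value across all dimensions (a single index into the flattened tensor), while ground_truth holds one target index per row or column. Should we change it to the following?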
acc_i = (torch.argmax(image_logits, 0) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t(), 0) == ground_truth).sum()
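A minimal sketch of the difference on a toy input, assuming a square similarity matrix whose matching pairs lie on the diagonal, so that ground_truth is torch.arange(N). That target layout is an assumption, since the rest of wrapper.py isn't shown here, and the helper names buggy_counts and fixed_counts are hypothetical.

```python
import torch


def buggy_counts(image_logits, ground_truth):
    # No dim argument: argmax returns a single index into the flattened N*N tensor,
    # which is then broadcast-compared against every entry of ground_truth
    acc_i = (torch.argmax(image_logits) == ground_truth).sum()
    acc_t = (torch.argmax(image_logits.t()) == ground_truth).sum()
    return acc_i.item(), acc_t.item()


def fixed_counts(image_logits, ground_truth):
    # dim=0: one argmax per column, giving a length-N tensor aligned with ground_truth
    acc_i = (torch.argmax(image_logits, 0) == ground_truth).sum()
    # argmax of the transpose over dim 0 equals argmax of the original over dim 1 (per row)
    acc_t = (torch.argmax(image_logits.t(), 0) == ground_truth).sum()
    return acc_i.item(), acc_t.item()


# Toy 3x3 similarity matrix; the matching pairs sit on the diagonal
image_logits = torch.tensor([[0.9, 0.1, 0.0],
                             [0.2, 0.95, 0.3],
                             [0.1, 0.2, 0.6]])
ground_truth = torch.arange(3)  # assumed targets: item i matches item i

print(torch.argmax(image_logits))                 # tensor(4): flat index of 0.95
print(buggy_counts(image_logits, ground_truth))   # (0, 0)
print(fixed_counts(image_logits, ground_truth))   # (3, 3)
```

Every pair is matched correctly here, yet the version without dim reports zero correct, because the flat index 4 never equals any of the per-item targets 0, 1 or 2.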
Thanks.
It looks like it! Thanks for catching this bug. I'll send a fix right now.
Made an update and also condensed it into a single term due to the diagonal being the same. Should look good now!