Zasder3 / train-CLIP

A PyTorch Lightning solution to training OpenAI's CLIP from scratch.
MIT License

About the accuracy computation. #6

Closed rookiecm closed 3 years ago

rookiecm commented 3 years ago

Thanks for sharing your code. I am a little confused about the accuracy computation on lines 69-70 of wrapper.py:

acc_i = (torch.argmax(image_logits) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t()) == ground_truth).sum()

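It seems that torch.argmax without a dim argument returns the index of the maximum value across all dimensions (an index into the flattened tensor), while ground_truth holds one target index per row or column. Should we change it to the following?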

acc_i = (torch.argmax(image_logits, 0) == ground_truth).sum()
acc_t = (torch.argmax(image_logits.t(), 0) == ground_truth).sum()

Thanks.
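To make the difference concrete, here is a minimal sketch of how torch.argmax behaves with and without a dim argument. The 3x3 matrix is a made-up example, not data from the repository, and the variable names other than image_logits and ground_truth are hypothetical.

```python
import torch

# Hypothetical similarity matrix: rows are images, columns are texts.
image_logits = torch.tensor([
    [0.9, 0.1, 0.0],
    [0.2, 0.8, 0.3],
    [0.7, 0.1, 0.4],
])
ground_truth = torch.arange(3)  # image i is paired with text i

# Without dim, argmax flattens the tensor and returns one scalar index.
flat_idx = torch.argmax(image_logits)
# The scalar is broadcast against ground_truth, so the count is not a
# per-sample accuracy.
buggy_count = (flat_idx == ground_truth).sum()

# With dim, argmax returns one index per row (dim=1) or per column (dim=0).
best_text_per_image = torch.argmax(image_logits, dim=1)
best_image_per_text = torch.argmax(image_logits, dim=0)

correct_images = (best_text_per_image == ground_truth).sum()
correct_texts = (best_image_per_text == ground_truth).sum()

print(flat_idx.item(), buggy_count.item())
print(best_text_per_image.tolist(), correct_images.item())
print(best_image_per_text.tolist(), correct_texts.item())
```

Note that dim=1 on image_logits picks the best text for each image, while dim=0 picks the best image for each text. The two counts can differ (as in this example), so which one is named acc_i matters if they are reported separately; their sum is unaffected.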

Zasder3 commented 3 years ago

It looks like it! Thanks for catching this bug. I'll send a fix right now.

Zasder3 commented 3 years ago

I pushed an update and also condensed the accuracy into a single term, since the diagonal is the same in both directions. It should look good now!
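The updated code is not shown in this thread. As a rough sketch of what a single combined accuracy term could look like, assuming a square logits matrix whose correct pairs lie on the diagonal, the helper below averages the two directions. The function name contrastive_accuracy is hypothetical and is not taken from the repository.

```python
import torch


def contrastive_accuracy(image_logits: torch.Tensor) -> torch.Tensor:
    """Mean of image-to-text and text-to-image top-1 accuracy.

    Assumes image_logits is (n, n) with matching pairs on the diagonal.
    """
    n = image_logits.shape[0]
    # The same diagonal targets serve both directions.
    ground_truth = torch.arange(n, device=image_logits.device)
    # Best text for each image (row-wise argmax).
    acc_i = (torch.argmax(image_logits, dim=1) == ground_truth).sum()
    # Best image for each text (column-wise argmax).
    acc_t = (torch.argmax(image_logits, dim=0) == ground_truth).sum()
    return (acc_i + acc_t).float() / (2 * n)


if __name__ == "__main__":
    # A perfectly diagonal matrix should score 1.0.
    print(contrastive_accuracy(torch.eye(4)).item())
```

Taking the column-wise argmax of image_logits is equivalent to taking the row-wise argmax of image_logits.t(), so no explicit transpose is needed.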