QData / C-Tran

General Multi-label Image Classification with Transformers
MIT License

Question about the diag_mask in CTran.py file #23

Closed. iamstarlee closed this issue 1 year ago

iamstarlee commented 1 year ago

    output = self.output_linear(label_embeddings)
    diag_mask = torch.eye(output.size(1)).unsqueeze(0).repeat(output.size(0),1,1).cuda()
    output = (output*diag_mask).sum(-1)

I'm wondering what role this diag_mask plays in the output value. Why should the output be multiplied by a diagonal matrix? A reply would be much appreciated!

jacklanchantin commented 1 year ago

This is a way to obtain the logit for each individual class. For each sample, the output has dimension LxL (one row of L scores per label embedding); multiplying by diag_mask and summing over the last dimension gives the Lx1 vector of per-class logits.
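To make the shapes concrete, here is a minimal sketch of the masking step on a dummy tensor. The sizes (2 images, 3 labels) and the `arange` stand-in for `self.output_linear(label_embeddings)` are assumptions for illustration only, and the `.cuda()` call is dropped so it runs on the CPU. It also shows that `torch.diagonal` appears to give the same result without building a mask.

```python
import torch

# Hypothetical sizes for illustration: a batch of 2 images and 3 labels.
batch_size, num_labels = 2, 3

# Stand-in for self.output_linear(label_embeddings): shape (B, L, L).
# Row i of each LxL matrix holds the L scores computed from label embedding i.
output = torch.arange(batch_size * num_labels * num_labels, dtype=torch.float32)
output = output.reshape(batch_size, num_labels, num_labels)

# Same mask as in the snippet from the issue, kept on the CPU here.
diag_mask = torch.eye(output.size(1)).unsqueeze(0).repeat(output.size(0), 1, 1)

# Zero out everything off the diagonal, then sum each row: shape (B, L).
logits = (output * diag_mask).sum(-1)

# Equivalent alternative: read the diagonal of each LxL matrix directly.
logits_diag = torch.diagonal(output, dim1=-2, dim2=-1)

print(logits)       # tensor([[ 0.,  4.,  8.], [ 9., 13., 17.]])
print(logits_diag)  # same values
```

Both forms pick entry [i][i] from each row, so each label ends up with exactly one logit.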

iamstarlee commented 1 year ago

Thanks for your attention! I know this calculates the logits for each individual class, but I don't understand why only the diagonal entries of the output are kept rather than summing directly.

jacklanchantin commented 1 year ago

The elements on the diagonal correspond to those individual classes. Otherwise you are summing over all classes and they end up being the same values.
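The difference between keeping the diagonal and summing directly can be seen on a small hand-made example. The numbers below are made up, and the reading of entry [i][j] as "the score for class j produced from label embedding i" is an assumption based on the shapes in the snippet, not something confirmed in the thread.

```python
import torch

# Hypothetical scores for one image and 3 labels, shape (1, 3, 3).
# Assumed reading: entry [i][j] is the score for class j from label embedding i.
output = torch.tensor([[[ 2.0, -1.0,  0.5],
                        [ 0.25, 1.5, -0.75],
                        [-0.5,  1.0,  3.0]]])

# Identity mask with a leading batch dimension so it broadcasts over the batch.
mask = torch.eye(output.size(1)).unsqueeze(0)

# Diagonal only: label i keeps the score for its own class i.
per_class = (output * mask).sum(-1)

# Direct sum: label i's value mixes its scores for every class.
summed = output.sum(-1)

print(per_class)  # tensor([[2.0000, 1.5000, 3.0000]])
print(summed)     # tensor([[1.5000, 1.0000, 3.5000]])
```

With the direct sum, each label's value is no longer the score for its own class, because the off-diagonal scores for the other classes are folded into it.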

iamstarlee commented 1 year ago

Alright! Thank you very much! You are so nice! Your paper inspired me a lot; it is really good work. I am trying to classify sewer defects using C-Tran, but the recall scores are not satisfactory. I hope I can find a solution soon.