microsoft / robustdg

Toolkit for building machine learning models that generalize to unseen domains and are robust to privacy and other attacks.
MIT License

Implementation of match-dg #18

Closed YiDongOuYang closed 3 years ago

YiDongOuYang commented 3 years ago

https://github.com/microsoft/robustdg/blob/514a3d92c8bf55d839a36ed0af654a63480dca8c/algorithms/match_dg.py In train_ctr_phase, we compute the mutual information between z_{d_i} and z_{d_j} in the feature space. However, in train_erm_phase, we compute the mutual information between logits_{d_i} and logits_{d_j} in the logit space.

Am I understanding this correctly? Why is it different?

divyat09 commented 3 years ago

Thanks for your question! @YiDongOuYang

In the train_ctr_phase, the objective is to learn a matching function (not a classifier) using iterative contrastive learning, hence we discard the final classification layer of ResNet-50 and use the 512-dimensional features (the input to the classification layer for downstream tasks). Therefore we need to operate in the feature space during the train_ctr_phase, as the task is self-supervised (contrastive) representation learning.
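To make the feature-space matching concrete, here is a minimal numpy sketch of a contrastive loss on matched feature pairs across two domains. The function name, shapes, and the NT-Xent-style formulation are illustrative assumptions, not the repo's actual implementation.

```python
import numpy as np

def contrastive_match_loss(z_di, z_dj, tau=0.1):
    """Illustrative contrastive loss between matched feature batches.

    z_di, z_dj: (N, D) L2-normalized features from domains d_i and d_j;
    row k of z_di is matched with row k of z_dj (same inferred object).
    All other rows in z_dj act as negatives for that anchor.
    """
    sim = z_di @ z_dj.T / tau                      # (N, N) cosine similarities
    logits = sim - sim.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positives lie on the diagonal; minimize their negative log-probability.
    return -np.mean(np.diag(log_prob))

# Toy check: correctly matched pairs give a lower loss than mismatched ones.
rng = np.random.default_rng(0)
z = rng.normal(size=(8, 512))
z /= np.linalg.norm(z, axis=1, keepdims=True)
loss_matched = contrastive_match_loss(z, z)
loss_shuffled = contrastive_match_loss(z, np.roll(z, 1, axis=0))
```

Because the classification head is discarded in this phase, the loss sees only the 512-dimensional representations, which is why it must operate in feature space.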

In the train_erm_phase, by contrast, the task is to learn a classifier that generalizes across domains, so we use the empirical risk minimization loss along with our object-based matching loss (with objects inferred using the matching function learned in the train_ctr_phase). However, the object-based matching loss can be applied at any intermediate representation layer during the train_erm_phase (as described in the paper's supplementary implementation details). It is therefore not necessary to compute the matching penalty in the logit space; we use the logits for simplicity.
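A minimal numpy sketch of the combined objective described above: cross-entropy (ERM) on both domains plus a matching penalty applied, for simplicity, to the logits of matched pairs. The function names, the squared-distance form of the penalty, and the weighting parameter `lam` are illustrative assumptions.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def erm_plus_match_loss(logits_di, logits_dj, labels, lam=1.0):
    """ERM loss on both domains + matching penalty on paired logits.

    logits_di, logits_dj: (N, C) logits for matched pairs across domains;
    labels: (N,) integer class labels shared by each matched pair.
    """
    n = labels.shape[0]
    # Empirical risk: cross-entropy on each domain's predictions.
    ce = -np.mean(np.log(softmax(logits_di)[np.arange(n), labels])) \
         - np.mean(np.log(softmax(logits_dj)[np.arange(n), labels]))
    # Matching penalty: pull the logits of matched pairs together.
    match = np.mean(np.sum((logits_di - logits_dj) ** 2, axis=1))
    return ce + lam * match

# Toy check: identical matched logits incur no penalty; separated ones do.
rng = np.random.default_rng(0)
L = rng.normal(size=(8, 5))
y = rng.integers(0, 5, size=8)
loss_same = erm_plus_match_loss(L, L.copy(), y)
loss_far = erm_plus_match_loss(L, L + 1.0, y)
```

The same `match` term could instead be computed on any intermediate layer's representations; only the choice of which tensors to pass in would change.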

Note: the mutual information referred to in your question is what I call the object-based matching loss in my reply above.

YiDongOuYang commented 3 years ago

Thank you very much for the clear explanation! That resolves my concern.