yfzhang114 / AdaNPC

This is an official PyTorch implementation of the ICML 2023 paper AdaNPC and the SIGKDD paper DRM.
MIT License

Lack the implementation of KNN and EM #1

Closed libo-huang closed 1 year ago

libo-huang commented 1 year ago

Line 149 in the class KNN(Algorithm) is:

https://github.com/yfzhang114/AdaNPC/blob/c280d6be2b524a895a554a7bb650ae99f5b11eea/domainbed/algorithms.py#L147-L149

Eq.(4) in the paper is,

[screenshot of Eq. (4) from the paper]

Thanks for your project.

My question is: where is the detailed implementation of $L_{KNN}$, i.e., the code for Eq. (4) (see above) and for the EM algorithm? Detailed questions are shown below:

Although line 149 in the class KNN(Algorithm) (see above) replaces F.cross_entropy(a, b) with F.nll_loss(torch.log(a), b), I think it still does not implement Eq. (4).

This is because F.nll_loss(torch.log(a), b) gives the same result as F.cross_entropy applied to the corresponding logits (cross-entropy is just log-softmax followed by NLL), so the substitution does not change the loss; see "How is PyTorch's Cross Entropy function related to softmax, log softmax, and NLL".
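The claimed equivalence is easy to check directly. A minimal sketch (with made-up logits and targets, not the repo's actual tensors): since `F.cross_entropy` is `log_softmax` followed by `F.nll_loss`, feeding softmax probabilities through `torch.log` into `F.nll_loss` reproduces the cross-entropy value.

```python
import torch
import torch.nn.functional as F

# Random logits and targets for illustration only.
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 0])

# Standard cross-entropy on the logits.
ce = F.cross_entropy(logits, targets)

# NLL on the log of the softmax probabilities -- the pattern used in the code.
probs = F.softmax(logits, dim=1)
nll = F.nll_loss(torch.log(probs), targets)

# The two losses coincide, so the substitution is cosmetic.
assert torch.allclose(ce, nll, atol=1e-6)
```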

Besides, there seems to be no implementation of EM in this project.

yfzhang114 commented 1 year ago
  1. EM means we update the memory and train the network alternately. The update of the whole memory bank is detailed in https://github.com/yfzhang114/AdaNPC/blob/c280d6be2b524a895a554a7bb650ae99f5b11eea/domainbed/scripts/train.py#L250, which is just what we said in the paper: we only periodically update $B_{k,\theta,D_S}$ and keep it fixed the rest of the time.
  2. Here the implementation is similar to contrastive learning: we use a classification task instead of minimizing representation similarity directly.
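The alternating scheme in point 1 can be sketched as follows. This is a hypothetical outline, not the repo's code: `train_with_periodic_memory`, `featurizer`, and `update_every` are illustrative names, and the actual bank update lives in the linked `train.py`.

```python
import torch

def train_with_periodic_memory(featurizer, loader, steps, update_every):
    """Alternate between refreshing a feature memory bank and training.

    Sketch only: the bank is rebuilt every `update_every` steps ("E-step")
    and kept frozen in between while the network trains ("M-step").
    """
    bank = None
    for step in range(steps):
        if step % update_every == 0:
            # Refresh the whole memory bank from current features, no grads.
            with torch.no_grad():
                bank = torch.cat([featurizer(x) for x, _ in loader])
        # One ordinary training step against the frozen bank would go here.
        ...
    return bank
```

The point is simply that the bank update and the gradient step never happen in the same role at the same time, which is what the reply calls "EM".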