Albert0147 / AaD_SFDA

Code for our NeurIPS 2022 (spotlight) paper 'Attracting and Dispersing: A Simple Approach for Source-free Domain Adaptation'

A question about implementation details #3

Closed: rickyang1114 closed this issue 1 year ago

rickyang1114 commented 1 year ago

Dear author,

I'm an undergraduate student who is quite interested in SFDA, and I have been following your work from G-SFDA and NRC to AaD. I really appreciate your work and the wonderful performance it achieves.

Recently, I have been trying to understand AaD. However, I became a little confused about some implementation details when comparing your paper with the source code.

From my understanding, B_i in the div term contains all other items in a mini-batch except those in C_i. In other words, items that are among the k nearest neighbors should be excluded from B_i, as presented in the paper. However, in your code, which I copied below:

mask = torch.ones((inputs_target.shape[0], inputs_target.shape[0]))
diag_num = torch.diag(mask)
mask_diag = torch.diag_embed(diag_num)
mask = mask - mask_diag
if args.noGRAD:
    copy = softmax_out.T.detach().clone()
else:
    copy = softmax_out.T  # .detach().clone()  #
dot_neg = softmax_out @ copy  # batch x batch
dot_neg = (dot_neg * mask.cuda()).sum(-1)  # batch
neg_pred = torch.mean(dot_neg)
loss += neg_pred * alpha

it seems that only the diagonal entries of mask are set to 0, rather than the entries of the k nearest neighbors.
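
To make this observation concrete, here is a minimal check of the mask construction quoted above. The batch size of 4 is a hypothetical value chosen for illustration; the rest mirrors the first four lines of the snippet. It only confirms that the mask equals a matrix of ones minus the identity, so every off-diagonal pair contributes to the sum.

```python
import torch

# Hypothetical batch size, standing in for inputs_target.shape[0].
batch_size = 4

# Same construction as in the quoted snippet.
mask = torch.ones((batch_size, batch_size))
diag_num = torch.diag(mask)             # 1-D tensor holding the diagonal (all ones)
mask_diag = torch.diag_embed(diag_num)  # identity matrix built from that diagonal
mask = mask - mask_diag                 # zeros on the diagonal, ones elsewhere

# An equivalent, shorter way to build the same mask.
mask_short = 1.0 - torch.eye(batch_size)
```

In other words, each sample is only excluded from its own B_i; nothing in this construction depends on the nearest-neighbor indices.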

I suppose I must have misunderstood something, so I created this issue hoping to get your answer. I would appreciate it if you could answer my question. Looking forward to your reply.

Albert0147 commented 1 year ago

Hi, actually we mentioned on page 4 that we use all other features in the mini-batch as B_i, and we also explained that B_i may intersect with C_i (we think this is fine, as it should not happen frequently).

However, you can try excluding the potential neighbors from B_i; it may improve the performance.
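
A minimal sketch of how that exclusion could look, for readers who want to try it. It is not code from the repository: the function names `dispersion_mask` and `dispersion_term`, and the arguments `batch_idx` (dataset indices of the samples in the mini-batch) and `neighbor_idx` (dataset indices of each sample's K nearest neighbors), are assumptions made for illustration. It assumes the neighbors are identified by dataset index, so an in-batch item can be matched against them.

```python
import torch


def dispersion_mask(batch_idx: torch.Tensor, neighbor_idx: torch.Tensor) -> torch.Tensor:
    """Build a (B, B) mask that is 0 where j == i or j is a neighbor of i.

    batch_idx:    (B,)   dataset indices of the mini-batch samples (assumed name).
    neighbor_idx: (B, K) dataset indices of each sample's K nearest neighbors (assumed name).
    """
    # Compare every neighbor index of sample i with every batch index j:
    # (B, K, 1) == (1, 1, B) broadcasts to (B, K, B); reduce over K.
    is_neighbor = (neighbor_idx.unsqueeze(2) == batch_idx.view(1, 1, -1)).any(dim=1)
    mask = (~is_neighbor).float()
    mask.fill_diagonal_(0.0)  # a sample is never in its own B_i
    return mask


def dispersion_term(softmax_out: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over the batch of the masked pairwise dot products of predictions."""
    dot_neg = softmax_out @ softmax_out.T  # batch x batch
    return (dot_neg * mask.to(softmax_out.device)).sum(-1).mean()


# Toy example: 4 samples, 2 neighbors each; some neighbors fall inside the batch.
batch_idx = torch.tensor([10, 11, 12, 13])
neighbor_idx = torch.tensor([[11, 99], [10, 98], [97, 96], [12, 95]])
mask = dispersion_mask(batch_idx, neighbor_idx)

softmax_out = torch.softmax(torch.randn(4, 3), dim=1)
neg_pred = dispersion_term(softmax_out, mask)
```

The resulting `mask` could replace the diagonal-only mask in the quoted snippet. Whether this actually improves accuracy is untested here.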

rickyang1114 commented 1 year ago

My doubts are clarified. Thanks for your reply.