I think the code is here:
if i != j and l_g == l_q:
    cs_ij_ = (inputs[i] * inputs[j]).sum(0)  # cross-similarity between parts of samples i and j
    cs_batch = cs_batch + cs_ij_.detach()
    cs_ij = torch.diag(cs_ij_)  # cross-similarity in diagonal form
    s_constr_ = (ss_diff_norm[i] * ss_diff_norm[j]).sum(0)  # per-pair structure term M
    ss_batch = ss_batch + s_constr_.detach()
    W = cs_ij + (s_constr_ - self.ss_mean)  # affinity matrix, centered by the moving average \hat{M}
    if use_matching_loss:
        x_optim, _ = self.IQP_solver(W, self.lambd)
        x_optim = torch.from_numpy(x_optim).cuda(device_id)
    else:
        x_optim = torch.from_numpy(np.ones(W.size(0))).cuda(device_id)
    matching_targets.append(x_optim)
    matching_logit.append(matching_inputs[i] * matching_inputs[j])
    loss += -torch.matmul(torch.matmul(x_optim, W), x_optim) + ((lambd[i] + lambd[j]) * x_optim).sum() / 2

loss = loss / len(matching_targets)
cs_batch = cs_batch / len(matching_targets)
ss_batch = ss_batch / len(matching_targets)
self.ss_mean = self.momentum * self.ss_mean + (1 - self.momentum) * ss_batch  # moving-average update of \hat{M}
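For readers puzzling over the loss line inside the loop, here is a minimal, self-contained sanity check of that matching term (the quadratic score -x^T W x plus the averaged linear penalty). The shapes and the lambd_i / lambd_j names are illustrative assumptions, not the repo's actual values:

import torch

# Illustrative only: P parts, a P x P affinity W, and a matching vector x.
P = 6
W = torch.rand(P, P)
x_optim = torch.ones(P)   # stand-in for the IQP solution
lambd_i = torch.rand(P)   # assumed per-part penalties
lambd_j = torch.rand(P)

loss = -torch.matmul(torch.matmul(x_optim, W), x_optim) \
       + ((lambd_i + lambd_j) * x_optim).sum() / 2
print(loss)  # scalar: higher affinity between matched parts lowers the loss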
Thanks. OK, I guess \hat{M}_{i,j} is represented by the buffer

self.register_buffer('ss_mean', torch.zeros((part_num, part_num)))

which you update as

self.ss_mean = self.momentum * self.ss_mean + (1 - self.momentum) * ss_batch
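For anyone else landing here, a minimal sketch of the pattern being described: a non-trainable buffer updated as an exponential moving average. The part_num and momentum values are assumptions for illustration, not taken from the repo:

import torch
import torch.nn as nn

class MovingAverageDemo(nn.Module):
    # Sketch only: keeps \hat{M} as a buffer (saved with the state_dict,
    # never touched by the optimizer) and updates it with an exponential
    # moving average of the per-batch statistics.
    def __init__(self, part_num=6, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.register_buffer('ss_mean', torch.zeros((part_num, part_num)))

    @torch.no_grad()
    def update(self, ss_batch):
        # \hat{M} <- momentum * \hat{M} + (1 - momentum) * M_batch
        self.ss_mean = self.momentum * self.ss_mean + (1 - self.momentum) * ss_batch

demo = MovingAverageDemo()
for _ in range(3):
    demo.update(torch.rand(6, 6))  # per-batch statistics M
print(demo.ss_mean)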
Yes.
Thanks for your reply.
Hi, in Equation 8 you said \hat{M} is the moving average of M. What do you mean by that, and how did you implement \hat{M}? I can't really locate that part in your code. Thank you.