Closed: ShiCrazy closed this issue 1 year ago.
def inst_eval_out(self, A, h, classifier):
    """Compute an instance-level loss where every top-attended instance gets a target from ``create_negative_targets``.

    The ``self.k_sample`` instances with the highest scores in ``A`` are
    gathered from ``h`` and scored by ``classifier``. The resulting logits
    are compared, via ``self.instance_loss_fn``, against the targets built
    by ``self.create_negative_targets``.

    NOTE(review): the name suggests this is the "out-of-class" counterpart
    of ``inst_eval``. Presumably it is called for a class branch that does
    NOT match the bag label. In that case the instances this branch attends
    to most are treated as negatives for it (a mutual-exclusivity
    assumption between classes). The caller is not visible here, so
    confirm against it.

    Args:
        A: attention scores. A 1-D tensor is reshaped to ``(1, N)``. If
            ``A`` has several rows, only its LAST row ranks the instances.
        h: instance features, indexed along dim 0 by the selected ids
            (assumes ``h.shape[0] == A.shape[-1]`` — TODO confirm).
        classifier: callable mapping the selected rows of ``h`` to logits.
            Assumed to be 2-D ``(k_sample, n_classes)``, since the
            prediction takes ``topk`` along dim 1.

    Returns:
        A tuple ``(instance_loss, p_preds, p_targets)``:

        - ``instance_loss``: the output of ``self.instance_loss_fn``.
        - ``p_preds``: the top-1 class index per selected instance.
        - ``p_targets``: the targets used for the loss.

    Raises:
        RuntimeError: raised by ``torch.topk`` when the last dimension of
            ``A`` has fewer than ``self.k_sample`` entries.
    """
    device = h.device
    if A.dim() == 1:
        # Promote a flat score vector to one row so the [-1] below selects it.
        A = A.view(1, -1)
    # Indices of the k_sample highest scores in the last row of A.
    top_p_ids = torch.topk(A, self.k_sample).indices[-1]
    top_p = torch.index_select(h, dim=0, index=top_p_ids)
    # Targets for all k_sample selected instances come from
    # create_negative_targets (presumably the "not this class" label —
    # confirm against its definition).
    p_targets = self.create_negative_targets(self.k_sample, device)
    logits = classifier(top_p)
    p_preds = torch.topk(logits, 1, dim=1).indices.squeeze(1)
    instance_loss = self.instance_loss_fn(logits, p_targets)
    return instance_loss, p_preds, p_targets
# Line quoted from inst_eval_out: the targets for all k_sample instances
# selected by torch.topk are built by create_negative_targets.
p_targets = self.create_negative_targets(self.k_sample, device)
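I don't understand why the method `create_negative_targets` is used to build the targets for the instances selected by `top_p_ids = torch.topk(A, self.k_sample)[1][-1]`, i.e. the `k_sample` instances with the highest attention scores. This is especially confusing when I compare the method `inst_eval_out` with the method `inst_eval`. Could you explain why?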
def inst_eval_out(self, A, h, classifier):
    """Compute an instance-level loss where every top-attended instance gets a target from ``create_negative_targets``.

    The ``self.k_sample`` instances with the highest scores in ``A`` are
    gathered from ``h`` and scored by ``classifier``. The resulting logits
    are compared, via ``self.instance_loss_fn``, against the targets built
    by ``self.create_negative_targets``.

    NOTE(review): the name suggests this is the "out-of-class" counterpart
    of ``inst_eval``. Presumably it is called for a class branch that does
    NOT match the bag label. In that case the instances this branch attends
    to most are treated as negatives for it (a mutual-exclusivity
    assumption between classes). The caller is not visible here, so
    confirm against it.

    Args:
        A: attention scores. A 1-D tensor is reshaped to ``(1, N)``. If
            ``A`` has several rows, only its LAST row ranks the instances.
        h: instance features, indexed along dim 0 by the selected ids
            (assumes ``h.shape[0] == A.shape[-1]`` — TODO confirm).
        classifier: callable mapping the selected rows of ``h`` to logits.
            Assumed to be 2-D ``(k_sample, n_classes)``, since the
            prediction takes ``topk`` along dim 1.

    Returns:
        A tuple ``(instance_loss, p_preds, p_targets)``:

        - ``instance_loss``: the output of ``self.instance_loss_fn``.
        - ``p_preds``: the top-1 class index per selected instance.
        - ``p_targets``: the targets used for the loss.

    Raises:
        RuntimeError: raised by ``torch.topk`` when the last dimension of
            ``A`` has fewer than ``self.k_sample`` entries.
    """
    device = h.device
    if A.dim() == 1:
        # Promote a flat score vector to one row so the [-1] below selects it.
        A = A.view(1, -1)
    # Indices of the k_sample highest scores in the last row of A.
    top_p_ids = torch.topk(A, self.k_sample).indices[-1]
    top_p = torch.index_select(h, dim=0, index=top_p_ids)
    # Targets for all k_sample selected instances come from
    # create_negative_targets (presumably the "not this class" label —
    # confirm against its definition).
    p_targets = self.create_negative_targets(self.k_sample, device)
    logits = classifier(top_p)
    p_preds = torch.topk(logits, 1, dim=1).indices.squeeze(1)
    instance_loss = self.instance_loss_fn(logits, p_targets)
    return instance_loss, p_preds, p_targets
# Line quoted from inst_eval_out: the targets for all k_sample instances
# selected by torch.topk are built by create_negative_targets.
p_targets = self.create_negative_targets(self.k_sample, device)
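I don't understand why the method `create_negative_targets` is used to build the targets for the instances selected by `top_p_ids = torch.topk(A, self.k_sample)[1][-1]`, i.e. the `k_sample` instances with the highest attention scores. This is especially confusing when I compare the method `inst_eval_out` with the method `inst_eval`. Could you explain why?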