mahmoodlab / CLAM

Data-efficient and weakly supervised computational pathology on whole slide images - Nature Biomedical Engineering
http://clam.mahmoodlab.org
GNU General Public License v3.0

A little question in /models/model_clam.py #180

Closed by ShiCrazy 1 year ago

ShiCrazy commented 1 year ago

```python
def inst_eval_out(self, A, h, classifier):
    device = h.device
    if len(A.shape) == 1:
        A = A.view(1, -1)
    top_p_ids = torch.topk(A, self.k_sample)[1][-1]
    top_p = torch.index_select(h, dim=0, index=top_p_ids)
    p_targets = self.create_negative_targets(self.k_sample, device)
    logits = classifier(top_p)
    p_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
    instance_loss = self.instance_loss_fn(logits, p_targets)
    return instance_loss, p_preds, p_targets
```
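For reference, here is a standalone check of what the indexing in `top_p_ids` selects. The attention values and `k_sample = 2` below are made up for illustration. `torch.topk` returns `(values, indices)` along the last dimension, so `[1]` takes the indices and `[-1]` keeps only the last row; when `A` is 1-D, the `A.view(1, -1)` line turns it into a single row first, so `[-1]` just picks that row.

```python
import torch

# Hypothetical attention scores: one row per attention branch, one column per patch.
A = torch.tensor([[0.10, 0.50, 0.20, 0.15, 0.05],
                  [0.30, 0.05, 0.40, 0.10, 0.20]])
k_sample = 2

# topk works on the last dimension by default and returns (values, indices),
# sorted from largest to smallest.
values, indices = torch.topk(A, k_sample)
# indices has shape (2, k_sample); [-1] keeps the indices of the last row only.
top_p_ids = indices[-1]
print(top_p_ids)  # tensor([2, 0])

# A 1-D score vector is reshaped to a single row first, as in inst_eval_out.
a = torch.tensor([0.05, 0.60, 0.25, 0.10]).view(1, -1)
print(torch.topk(a, k_sample)[1][-1])  # tensor([1, 2])
```

In both cases the selected patches are the ones with the highest attention scores.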

```python
p_targets = self.create_negative_targets(self.k_sample, device)
```

I don't understand why `create_negative_targets` is used to build the targets here, given that `top_p_ids = torch.topk(A, self.k_sample)[1][-1]` selects the `k_sample` patches with the highest attention scores. This is especially confusing when I compare `inst_eval_out` with `inst_eval`. Could you explain why?
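For comparison, below is a simplified sketch of how I read the two sampling schemes. It is my own re-implementation with hypothetical function names (`sample_in_class`, `sample_out_of_class`), not code from the repository, and it assumes that `inst_eval` labels the top-k patches 1 and the bottom-k patches 0, while `inst_eval_out` (as quoted above) takes only the top-k patches and labels them all 0.

```python
import torch

def create_positive_targets(length, device):
    # Pseudo-label 1 for every sampled patch (assumed to mirror the repo helper).
    return torch.full((length,), 1, device=device, dtype=torch.long)

def create_negative_targets(length, device):
    # Pseudo-label 0 for every sampled patch.
    return torch.full((length,), 0, device=device, dtype=torch.long)

def sample_in_class(A, k):
    """Assumed view of inst_eval: top-k patches -> 1, bottom-k patches -> 0."""
    A = A.view(1, -1) if A.dim() == 1 else A
    top_ids = torch.topk(A, k)[1][-1]              # highest attention
    bottom_ids = torch.topk(-A, k, dim=1)[1][-1]   # lowest attention
    ids = torch.cat([top_ids, bottom_ids])
    targets = torch.cat([create_positive_targets(k, A.device),
                         create_negative_targets(k, A.device)])
    return ids, targets

def sample_out_of_class(A, k):
    """View of inst_eval_out as quoted: only top-k patches, all labelled 0."""
    A = A.view(1, -1) if A.dim() == 1 else A
    top_ids = torch.topk(A, k)[1][-1]              # highest attention
    return top_ids, create_negative_targets(k, A.device)

scores = torch.tensor([0.30, 0.05, 0.40, 0.10, 0.20])
print(sample_in_class(scores, 2))      # (tensor([2, 0, 1, 3]), tensor([1, 1, 0, 0]))
print(sample_out_of_class(scores, 2))  # (tensor([2, 0]), tensor([0, 0]))
```

The same highest-attention patches (indices 2 and 0) get label 1 in one scheme and label 0 in the other, which is the difference I am asking about.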