szc19990412 / TransMIL

TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification

Where is the multiple instance learning exactly? #11

Closed · LXYTSOS closed 2 years ago

LXYTSOS commented 2 years ago

I found the same thing: the code seems to use the slide-level features and the slide label directly for training, as shown below:

    def __getitem__(self, idx):
        # assumes: from pathlib import Path; import random; import torch
        slide_id = self.data[idx]
        label = int(self.label[idx])
        full_path = Path(self.feature_dir) / f'{slide_id}.pt'
        features = torch.load(full_path)

        #----> shuffle the instance order within the bag
        if self.shuffle:
            index = list(range(features.shape[0]))
            random.shuffle(index)
            features = features[index]

        return features, label
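
Each item the dataset returns is a variable-length set of pre-extracted patch features plus a single slide-level label, i.e. the usual "bag of instances" structure in MIL. A self-contained sketch with synthetic tensors standing in for the real `.pt` feature files (shapes are illustrative only):

    import torch

    # One slide = one bag: n patch-level feature vectors (instances)
    n_patches, feat_dim = 137, 1024              # n varies from slide to slide
    features = torch.randn(n_patches, feat_dim)  # the bag of instances
    label = torch.tensor(1)                      # one label for the whole slide (bag)

    # Only the bag-level label is available; there are no per-patch labels.
    print(features.shape, label.item())

The training step then consumes one such bag at a time:
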
    def training_step(self, batch, batch_idx):
        #---->inference
        data, label = batch
        results_dict = self.model(data=data, label=label)
        logits = results_dict['logits']
        Y_prob = results_dict['Y_prob']
        Y_hat = results_dict['Y_hat']

        #---->loss
        loss = self.loss(logits, label)

        #---->acc log
        Y_hat = int(Y_hat)
        Y = int(label)
        self.data[Y]["count"] += 1
        self.data[Y]["correct"] += (Y_hat == Y)

        torch.cuda.empty_cache()
        return {'loss': loss} 
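
The per-class `count` / `correct` counters above are presumably aggregated elsewhere at epoch end; a hedged sketch of how per-class and balanced accuracy could be computed from them (the aggregation code and the example values are my assumption, not from the repo):

    # Hypothetical epoch-end aggregation over the per-class counters
    data = [{"count": 40, "correct": 36}, {"count": 60, "correct": 48}]
    per_class_acc = [d["correct"] / d["count"] for d in data if d["count"] > 0]
    balanced_acc = sum(per_class_acc) / len(per_class_acc)
    print(f"balanced accuracy: {balanced_acc:.3f}")  # 0.850

The model's forward pass operates on the whole bag at once:
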
    def forward(self, **kwargs):
        # assumes: import numpy as np; import torch; import torch.nn.functional as F

        h = kwargs['data'].float() #[B, n, 1024]

        h = self._fc1(h) #[B, n, 512]

        #---->pad
        #---->pad the sequence length up to the next perfect square
        H = h.shape[1]
        _H = _W = int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1) #[B, N, 512]

        #---->cls_token
        B = h.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1).cuda()
        h = torch.cat((cls_tokens, h), dim=1)

        #---->Translayer x1
        h = self.layer1(h) #[B, N, 512]

        #---->PPEG
        h = self.pos_layer(h, _H, _W) #[B, N, 512]

        #---->Translayer x2
        h = self.layer2(h) #[B, N, 512]

        #---->cls_token
        h = self.norm(h)[:,0]

        #---->predict
        logits = self._fc2(h) #[B, n_classes]
        Y_hat = torch.argmax(logits, dim=1)
        Y_prob = F.softmax(logits, dim=1)
        results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
        return results_dict
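
The padding step in `forward` repeats the first `add_length` tokens so that the sequence length becomes a perfect square, which lets `pos_layer` (the PPEG module) reshape the token sequence into an `_H x _W` map for its convolutions. A standalone illustration of just that step:

    import numpy as np
    import torch

    B, H, C = 1, 137, 512                 # batch, patch tokens, feature dim
    h = torch.randn(B, H, C)

    _H = _W = int(np.ceil(np.sqrt(H)))    # 12, since 12 * 12 = 144 >= 137
    add_length = _H * _W - H              # 144 - 137 = 7 tokens to append
    h = torch.cat([h, h[:, :add_length, :]], dim=1)
    print(h.shape)                        # torch.Size([1, 144, 512])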

If I understand correctly, I didn't find any code relevant to multiple instance learning. Please let me know if I missed something.

Originally posted by @LXYTSOS in https://github.com/szc19990412/TransMIL/issues/9#issuecomment-1081516320

laico12345 commented 2 months ago

Hi, I noticed that you marked this question as completed, but I am still struggling with it. Could you tell me where the code for MIL is?