clemsgrs / hipt

Re-implementation of HIPT

Implementation of mini-batch-wise padding #8

Closed bryanwong17 closed 1 year ago

bryanwong17 commented 1 year ago

Hi, since the 3rd stage of HIPT is trained only with train_batch_size=1, I tried padding each sequence in a mini-batch to the largest number of regions in that mini-batch, as follows:

collate_features function

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_features(batch, label_type: str = "int"):
    idx = torch.LongTensor([item[0] for item in batch])

    # pad mini-batch wise: pad every sequence in the current mini-batch
    # to the longest sequence length found in that mini-batch
    feature = pad_sequence([item[1] for item in batch], batch_first=True, padding_value=0)

    if label_type == "float":
        label = torch.FloatTensor([item[2] for item in batch])
    elif label_type == "int":
        label = torch.LongTensor([item[2] for item in batch])
    return [idx, feature, label]
```
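For reference, here is a minimal usage sketch of the collate function with a DataLoader. The dataset class and its sizes are hypothetical; it only assumes each item is an (idx, feature, label) tuple where feature is a [M_i, 192] tensor of region embeddings with a variable M_i per slide.

```python
import torch
from torch.utils.data import DataLoader, Dataset

# hypothetical toy dataset: each slide has a variable number of regions M_i,
# each region represented by a 192-dim embedding
class DummyRegionDataset(Dataset):
    def __init__(self, num_slides: int = 8):
        self.features = [
            torch.randn(torch.randint(5, 20, (1,)).item(), 192) for _ in range(num_slides)
        ]
        self.labels = [int(i % 2) for i in range(num_slides)]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i):
        return i, self.features[i], self.labels[i]

loader = DataLoader(DummyRegionDataset(), batch_size=4, collate_fn=collate_features)
for idx, feature, label in loader:
    # feature is padded to the longest slide in this mini-batch
    print(idx.shape, feature.shape, label.shape)  # e.g. [4], [4, M_max, 192], [4]
```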

GlobalHIPT forward function

```python
def forward(self, x):

    # x: [B, M, 192] with mini-batch padding (originally [M, 192] for batch_size = 1)
    x = self.global_phi(x)

    # in nn.TransformerEncoderLayer, batch_first defaults to False
    # hence, input is expected to be of shape (seq_length, batch, emb_size)

    # original batch_size = 1 version:
    # x = self.global_transformer(x.unsqueeze(1)).squeeze(1)
    # att, x = self.global_attn_pool(x)
    # att = torch.transpose(att, 1, 0)
    # att = F.softmax(att, dim=1)
    # x_att = torch.mm(att, x)
    # x_wsi = self.global_rho(x_att)
    # logits = self.classifier(x_wsi)

    # pad mini-batch wise version
    x = self.global_transformer(x.transpose(0, 1))  # x: [M, B, 192]
    att, x = self.global_attn_pool(x)               # att: [M, B, 1], x: [M, B, 192]
    att = torch.transpose(att, 0, 1)                # att: [B, M, 1]
    att = torch.transpose(att, 1, 2)                # att: [B, 1, M]
    att = F.softmax(att, dim=-1)
    x = torch.transpose(x, 0, 1)                    # x: [B, M, 192]
    x_att = torch.bmm(att, x)                       # x_att: [B, 1, 192]
    x_wsi = self.global_rho(x_att.squeeze(1))       # x_wsi: [B, 192]
    logits = self.classifier(x_wsi)                 # logits: [B, num_classes]

    return logits
```
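To sanity-check the shape flow of the batched attention pooling independently of the model, here is a small sketch with dummy tensors standing in for the transformer and attention-pool outputs (B, M and the 192-dim size are placeholders):

```python
import torch
import torch.nn.functional as F

B, M, D = 4, 12, 192              # 4 slides, each padded to 12 regions of 192 dims
x = torch.randn(M, B, D)          # stand-in for the transformer output, (seq, batch, emb)
att = torch.randn(M, B, 1)        # stand-in for raw attention scores from the attention pool

att = torch.transpose(att, 0, 1)  # [B, M, 1]
att = torch.transpose(att, 1, 2)  # [B, 1, M]
att = F.softmax(att, dim=-1)      # weights sum to 1 over the M regions of each slide
x = torch.transpose(x, 0, 1)      # [B, M, D]
x_att = torch.bmm(att, x)         # [B, 1, D] weighted sum of region embeddings per slide
print(x_att.squeeze(1).shape)     # torch.Size([4, 192])
```

Note that in this sketch, as in the forward pass above, the softmax runs over all M positions, including any padded ones.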

I have tested this and it works with train_batch_size > 1. I hope it helps, and please let me know if you find any errors in my implementation.

bryanwong17 commented 1 year ago

When I tried a bigger batch size, the results were worse than with batch_size=1. If you could point out the problem, I would appreciate it.