Thanos-DB / FullyConvolutionalTransformer

[WACV 2023] Official implementation of The Fully Convolutional Transformer for Medical Image Segmentation
https://chaitanya-kaul.github.io/

pytorch model version #3

Closed: xiehou-design closed this issue 1 year ago

xiehou-design commented 1 year ago

Thanks for your work! But I found some problems in the PyTorch model code. I used your code for a binary semantic segmentation experiment and the results were very bad: the model predicts everything as the background class and does not converge at all. I made sure the data processing in my experiment is correct; I swapped in the simplest UNet model to verify the data pipeline, and the UNet gives correct results.

Here is my FCT model:

from torch import nn
import torch
import torch.nn.functional as F
import numpy as np

class Attention(nn.Module):
    def __init__(self,
                 channels,
                 num_heads,
                 proj_drop=0.0,
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_q=1,
                 padding_kv=1,
                 attention_bias=True
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.num_heads = num_heads
        self.proj_drop = proj_drop

        self.conv_q = nn.Conv2d(channels, channels, kernel_size, stride_q, padding_q, bias=attention_bias,
                                groups=channels)
        self.layernorm_q = nn.LayerNorm(channels, eps=1e-5)
        self.conv_k = nn.Conv2d(channels, channels, kernel_size, stride_kv, padding_kv, bias=attention_bias,
                                groups=channels)
        self.layernorm_k = nn.LayerNorm(channels, eps=1e-5)
        self.conv_v = nn.Conv2d(channels, channels, kernel_size, stride_kv, padding_kv, bias=attention_bias,
                                groups=channels)
        self.layernorm_v = nn.LayerNorm(channels, eps=1e-5)

        self.attention = nn.MultiheadAttention(embed_dim=channels,
                                               bias=attention_bias,
                                               num_heads=1)

    def _build_projection(self, x, qkv):

        if qkv == "q":
            x1 = F.relu(self.conv_q(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_q(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "k":
            x1 = F.relu(self.conv_k(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_k(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "v":
            x1 = F.relu(self.conv_v(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_v(x1)
            proj = x1.permute(0, 3, 1, 2)
        else:
            raise ValueError("qkv must be 'q', 'k' or 'v'")

        return proj

    def forward_conv(self, x):
        q = self._build_projection(x, "q")
        k = self._build_projection(x, "k")
        v = self._build_projection(x, "v")

        return q, k, v

    def forward(self, x):
        q, k, v = self.forward_conv(x)
        q = q.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        k = k.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        v = v.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        q = q.permute(0, 2, 1)
        k = k.permute(0, 2, 1)
        v = v.permute(0, 2, 1)
        # MultiheadAttention returns a tuple (output, attention weights), so only x1[0] is used below
        x1 = self.attention(query=q, value=v, key=k, need_weights=False)

        x1 = x1[0].permute(0, 2, 1)
        x1 = x1.view(x1.shape[0], x1.shape[1], np.sqrt(x1.shape[2]).astype(int), np.sqrt(x1.shape[2]).astype(int))
        x1 = F.dropout(x1, self.proj_drop)

        return x1

class Transformer(nn.Module):

    def __init__(self,
                 channels,
                 num_heads,
                 proj_drop=0.0,
                 attention_bias=True,
                 padding_q=1,
                 padding_kv=1,
                 stride_kv=1,
                 stride_q=1):
        super().__init__()

        self.attention_output = Attention(channels=channels,
                                          num_heads=num_heads,
                                          proj_drop=proj_drop,
                                          padding_q=padding_q,
                                          padding_kv=padding_kv,
                                          stride_kv=stride_kv,
                                          stride_q=stride_q,
                                          attention_bias=attention_bias,
                                          )

        self.conv1 = nn.Conv2d(channels, channels, 3, 1, padding=1)
        self.layernorm = nn.LayerNorm(self.conv1.out_channels, eps=1e-5)
        self.wide_focus = Wide_Focus(channels, channels)

    def forward(self, x):
        x1 = self.attention_output(x)
        x1 = self.conv1(x1)
        x2 = torch.add(x1, x)
        # LayerNorm normalizes over the last dimension, so permute to channels-last, normalize, then permute back
        x3 = x2.permute(0, 2, 3, 1)
        x3 = self.layernorm(x3)
        x3 = x3.permute(0, 3, 1, 2)
        x3 = self.wide_focus(x3)
        x3 = torch.add(x2, x3)
        return x3

class Wide_Focus(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 padding_number=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number * 2, dilation=2)
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number * 3, dilation=3)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.gelu(x1)
        x1 = F.dropout(x1, 0.1)
        x2 = self.conv2(x)
        x2 = F.gelu(x2)
        x2 = F.dropout(x2, 0.1)
        x3 = self.conv3(x)
        x3 = F.gelu(x3)
        x3 = F.dropout(x3, 0.1)
        added = torch.add(x1, x2)
        added = torch.add(added, x3)
        x_out = self.conv4(added)
        x_out = F.gelu(x_out)
        x_out = F.dropout(x_out, 0.1)
        return x_out

class BlockEncoderBottleneck(nn.Module):
    def __init__(self, blk, in_channels, out_channels, att_heads, dpr, padding_number=1):
        super().__init__()
        self.blk = blk
        if (self.blk == "first") or (self.blk == "bottleneck"):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number)
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=padding_number)
            self.trans = Transformer(out_channels, att_heads, dpr)
        elif (self.blk == "second") or (self.blk == "third") or (self.blk == "fourth"):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(1, in_channels, 3, 1, padding=padding_number)
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=padding_number)
            self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=padding_number)
            self.trans = Transformer(out_channels, att_heads, dpr)

    def forward(self, x, scale_img=None):
        if (self.blk == "first") or (self.blk == "bottleneck"):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = F.relu(self.conv1(x1))
            x1 = F.relu(self.conv2(x1))
            x1 = F.dropout(x1, 0.3)
            x1 = F.max_pool2d(x1, (2, 2))
            out = self.trans(x1)
            # without skip
        elif (self.blk == "second") or (self.blk == "third") or (self.blk == "fourth"):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = torch.cat((F.relu(self.conv1(scale_img)), x1), dim=1)
            x1 = F.relu(self.conv2(x1))
            x1 = F.relu(self.conv3(x1))
            x1 = F.dropout(x1, 0.3)
            x1 = F.max_pool2d(x1, (2, 2))
            out = self.trans(x1)
            # with skip
        else:
            raise ValueError("blk must be 'first', 'second', 'third', 'fourth' or 'bottleneck'")
        return out

class BlockDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, att_heads, dpr, padding_number=1):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number)
        self.conv2 = nn.Conv2d(out_channels * 2, out_channels, 3, 1, padding=padding_number)
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=padding_number)
        self.trans = Transformer(out_channels, att_heads, dpr)

    def forward(self, x, skip):
        x1 = x.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = self.upsample(x1)
        x1 = F.relu(self.conv1(x1))
        x1 = torch.cat([skip, x1], dim=1)
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        x1 = F.dropout(x1, 0.3)
        out = self.trans(x1)
        return out

class DsOut(nn.Module):
    def __init__(self, in_channels, out_channels, num_classes=2, padding_number=1):
        super().__init__()
        self.num_classes = num_classes
        self.upsample = nn.Upsample(scale_factor=2)
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, padding=padding_number)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=padding_number)
        self.conv3 = nn.Conv2d(out_channels, self.num_classes, 3, 1, padding=padding_number)

    def forward(self, x):
        x1 = self.upsample(x)
        x1 = x1.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = F.relu(self.conv1(x1))
        x1 = F.relu(self.conv2(x1))
        out = torch.sigmoid(self.conv3(x1))

        return out

class FCT(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        # attention heads and filters per block
        att_heads = [2, 2, 2, 2, 2, 2, 2, 2, 2]
        filters = [8, 16, 32, 64, 128, 64, 32, 16, 8]

        # number of blocks used in the model
        blocks = len(filters)

        stochastic_depth_rate = 0.0

        # probability for each block
        dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)]

        self.drp_out = 0.3

        # Multi-scale input
        self.scale_img = nn.AvgPool2d(2, 2)

        # model
        self.block_1 = BlockEncoderBottleneck("first", 1, filters[0], att_heads[0], dpr[0])
        self.block_2 = BlockEncoderBottleneck("second", filters[0], filters[1], att_heads[1], dpr[1])
        self.block_3 = BlockEncoderBottleneck("third", filters[1], filters[2], att_heads[2], dpr[2])
        self.block_4 = BlockEncoderBottleneck("fourth", filters[2], filters[3], att_heads[3], dpr[3])
        self.block_5 = BlockEncoderBottleneck("bottleneck", filters[3], filters[4], att_heads[4], dpr[4])
        self.block_6 = BlockDecoder(filters[4], filters[5], att_heads[5], dpr[5])
        self.block_7 = BlockDecoder(filters[5], filters[6], att_heads[6], dpr[6])
        self.block_8 = BlockDecoder(filters[6], filters[7], att_heads[7], dpr[7])
        self.block_9 = BlockDecoder(filters[7], filters[8], att_heads[8], dpr[8])

        self.ds7 = DsOut(filters[6], 4, self.num_classes)
        self.ds8 = DsOut(filters[7], 4, self.num_classes)
        self.ds9 = DsOut(filters[8], 4, self.num_classes)

    def forward(self, x):
        # Multi-scale input
        scale_img_2 = self.scale_img(x)
        scale_img_3 = self.scale_img(scale_img_2)
        scale_img_4 = self.scale_img(scale_img_3)

        x = self.block_1(x)
        # print(f"Block 1 out -> {list(x.size())}")
        skip1 = x
        x = self.block_2(x, scale_img_2)
        # print(f"Block 2 out -> {list(x.size())}")
        skip2 = x
        x = self.block_3(x, scale_img_3)
        # print(f"Block 3 out -> {list(x.size())}")
        skip3 = x
        x = self.block_4(x, scale_img_4)
        # print(f"Block 4 out -> {list(x.size())}")
        skip4 = x
        x = self.block_5(x)
        # print(f"Block 5 out -> {list(x.size())}")
        x = self.block_6(x, skip4)
        # print(f"Block 6 out -> {list(x.size())}")
        x = self.block_7(x, skip3)
        # print(f"Block 7 out -> {list(x.size())}")
        skip7 = x
        x = self.block_8(x, skip2)
        # print(f"Block 8 out -> {list(x.size())}")
        skip8 = x
        x = self.block_9(x, skip1)
        # print(f"Block 9 out -> {list(x.size())}")
        skip9 = x

        out7 = self.ds7(skip7)
        # print(f"DS 7 out -> {list(out7.size())}")
        out8 = self.ds8(skip8)
        # print(f"DS 8 out -> {list(out8.size())}")
        out9 = self.ds9(skip9)
        # print(f"DS 9 out -> {list(out9.size())}")

        return out7, out8, out9

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

if __name__ == '__main__':
    fct = FCT(num_classes=2)
    print(fct)
    data = torch.rand((2, 1, 224, 224), dtype=torch.float)
    fct(data)

Thanos-DB commented 1 year ago

Hi xiehou-design,

thank you for taking an interest in our work! Before starting, let me once again mention that the paper is based on the TensorFlow implementation. Having said that, the PyTorch implementation should also work fine for you. Do you have any metrics to get a better feeling of what happens? What is your Dice (or whatever metric you are using) for the FCT and the vanilla UNet? Is the vanilla UNet pretrained? Is your dataset freely available or private? What is the size of both models? Can you force FCT to overfit your dataset (you might consider using a smaller subset)? Also, change num_heads=1 to self.num_heads; it should help (you can do ablation studies to find what works best for you; you can start with 2s all over). Is your vanilla UNet also multi-input multi-output? Does the FCT loss decrease during training? Lastly, how many segmentation masks does your ground truth have? It seems to me that you need to change your number of classes to 1.
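
For example, a minimal sketch of the num_heads change (following the names in the code above, with example values; not an exact patch):

import torch.nn as nn

channels, num_heads, attention_bias = 8, 2, True   # example values matching the block defaults above

# inside Attention.__init__, pass the stored head count instead of the hard-coded 1:
attention = nn.MultiheadAttention(embed_dim=channels,
                                  num_heads=num_heads,   # was: num_heads=1
                                  bias=attention_bias)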

BR Thanos

xiehou-design commented 1 year ago

Thank you for your reply. My semantic segmentation experiment uses pixel-level precision, recall and F1 as metrics. I use cross-entropy loss; the loss does not decrease with FCT, but it behaves normally with the UNet. Neither FCT nor the UNet is pretrained. My dataset follows the cell dataset; it is very small and publicly available. I am sure FCT is not overfitting; it even predicts that all targets are background. I changed num_heads=1 to num_heads=self.num_heads, but FCT still does not work. The input size for both FCT and the UNet is 224.

My experiment code is in Google Drive.
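
For reference, a minimal sketch of the pixel-level precision, recall and F1 I compute (shapes and the threshold here are placeholders, not the exact code from the drive):

import torch

def pixel_prf1(pred_logits, target, threshold=0.5):
    # pixel-level precision / recall / F1 for a binary mask
    pred = (torch.sigmoid(pred_logits) > threshold).float()
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return precision.item(), recall.item(), f1.item()

# usage with dummy tensors shaped (N, 1, H, W):
p, r, f1 = pixel_prf1(torch.randn(2, 1, 224, 224), torch.randint(0, 2, (2, 1, 224, 224)).float())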

Thanos-DB commented 1 year ago

Did you try the number of classes I suggested? Also, and more importantly, your UNet works because you use no activation on the last layer, so you get logits, which is what your loss function expects. This is not the case with FCT. You should correct that.
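
Concretely, a minimal sketch of what I mean (assumed shapes; DsOut would return self.conv3(x1) directly, and the loss takes raw logits):

import torch
import torch.nn as nn

# logits as DsOut would return them if the final sigmoid were removed (shapes are assumptions)
logits = torch.randn(2, 1, 224, 224)
target = torch.randint(0, 2, (2, 1, 224, 224)).float()

loss = nn.BCEWithLogitsLoss()(logits, target)   # applies sigmoid internally, numerically stable

# for multi-class masks with integer class indices, CrossEntropyLoss also expects raw logits:
# logits_mc = torch.randn(2, 4, 224, 224)
# target_mc = torch.randint(0, 4, (2, 224, 224))
# loss_mc = nn.CrossEntropyLoss()(logits_mc, target_mc)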

xiehou-design commented 1 year ago

Yes, I tried it. I used num_classes=1 and replaced the cross-entropy loss with BCELoss, but it still does not work. I also tried replacing the last convolution of the DsOut module, changing out = F.sigmoid(self.conv3(x1)) to out = self.conv3(x1), with num_classes=2 and cross-entropy, but it still does not work. I tried to dig deeper into the problem. I think there may be an issue in the model's attention calculation, but I compared it against the TensorFlow version and found no obvious error.

Thanos-DB commented 1 year ago

I tried running your code and got a lot of errors. First there were mismatched dimensions; when I fixed that, I got errors concerning dtypes (float, long, etc.). This happens regardless of the model, because I tried both of them. My torch version is 1.13.1. I can tell you it is not the model, because I have tested it and it produces nice segmentation masks (binary or not). I have also tested it inside other people's frameworks and it works there too.

xiehou-design commented 1 year ago

I am sorry, my torch version was 1.8.1. I found some problems in that version, such as the expected shape of the query, key and value inputs to multi-head attention. I have now replaced my original torch version with 1.13.1. Thank you for the reply; I do not doubt the novelty of FCT. I will modify my code and run the experiment again.
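
For anyone else hitting this, a small sketch of the shape difference (my understanding, not a confirmed root cause): batch_first only exists on nn.MultiheadAttention from PyTorch 1.9 onward, so on 1.8.x batch-first tensors have to be transposed to (seq_len, batch, embed_dim) manually:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)   # no batch_first argument on 1.8.x
q = torch.randn(2, 196, 8)                               # (batch, seq_len, channels)

# transpose to (seq_len, batch, channels), run attention, transpose back
out, _ = attn(q.transpose(0, 1), q.transpose(0, 1), q.transpose(0, 1))
out = out.transpose(0, 1)

# on PyTorch >= 1.9 the equivalent is:
# attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
# out, _ = attn(q, q, q)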

Geekiter commented 1 year ago

I encountered the same problem: all the predicted classes are background, and the average DSC is 0. My PyTorch version is 1.13.0.

xiehou-design commented 1 year ago

I encountered the same problem: all the predicted classes are background, and the average DSC is 0. My PyTorch version is 1.13.0.

Same here. And I have not found the problem in the FCT model yet.

Thanos-DB commented 1 year ago

Hi Geekiter,

do you have any reproducible code?

Geekiter commented 1 year ago

I didn't change the model; I only added the average DSC calculation.

import datetime
import os

import matplotlib.pyplot as plt
import numpy as np  # added: np is used below (transpose, linspace, sqrt) and may not be re-exported by utils
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
from torch.functional import F
# os.chdir("/home/thanos/code")
import wandb

from utils import *

# %% Get data
batch_size = 1
epochs = 10
# model+datetime
model_path = "./model.pth"
learning_rate = 0.0001
# ---- ACDC
# training

# ---- ACDC

acdc_data, _, _ = get_acdc(acdc_data_train)
acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2))  # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2))  # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0])  # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1])  # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
train_dataloader = DataLoader(acdc_data, batch_size=batch_size)
# validation
acdc_data, _, _ = get_acdc(acdc_data_validation)

acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2))  # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2))  # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0])  # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1])  # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
validation_dataloader = DataLoader(acdc_data, batch_size=batch_size)
# testing
acdc_data, _, _ = get_acdc(acdc_data_test)

acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2))  # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2))  # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0])  # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1])  # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
test_dataloader = DataLoader(acdc_data, batch_size=batch_size)

# %% ######################################################################################
# create model

class Attention(nn.Module):
    def __init__(self,
                 channels,
                 num_heads,
                 proj_drop=0.0,
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv="same",
                 padding_q="same",
                 attention_bias=True
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.num_heads = num_heads
        self.proj_drop = proj_drop

        self.conv_q = nn.Conv2d(channels, channels, kernel_size, stride_q, padding_q, bias=attention_bias,
                                groups=channels)
        self.layernorm_q = nn.LayerNorm(channels, eps=1e-5)
        self.conv_k = nn.Conv2d(channels, channels, kernel_size, stride_kv, stride_kv, bias=attention_bias,
                                groups=channels)
        self.layernorm_k = nn.LayerNorm(channels, eps=1e-5)
        self.conv_v = nn.Conv2d(channels, channels, kernel_size, stride_kv, stride_kv, bias=attention_bias,
                                groups=channels)
        self.layernorm_v = nn.LayerNorm(channels, eps=1e-5)

        self.attention = nn.MultiheadAttention(embed_dim=channels,
                                               bias=attention_bias,
                                               batch_first=True,
                                               # dropout = 0.0,
                                               num_heads=1)  # num_heads=self.num_heads)

    def _build_projection(self, x, qkv):

        if qkv == "q":
            x1 = F.relu(self.conv_q(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_q(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "k":
            x1 = F.relu(self.conv_k(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_k(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "v":
            x1 = F.relu(self.conv_v(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_v(x1)
            proj = x1.permute(0, 3, 1, 2)

        return proj

    def forward_conv(self, x):
        q = self._build_projection(x, "q")
        k = self._build_projection(x, "k")
        v = self._build_projection(x, "v")

        return q, k, v

    def forward(self, x):
        q, k, v = self.forward_conv(x)
        q = q.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        k = k.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        v = v.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        q = q.permute(0, 2, 1)
        k = k.permute(0, 2, 1)
        v = v.permute(0, 2, 1)
        x1 = self.attention(query=q, value=v, key=k, need_weights=False)

        x1 = x1[0].permute(0, 2, 1)
        x1 = x1.view(x1.shape[0], x1.shape[1], np.sqrt(x1.shape[2]).astype(int), np.sqrt(x1.shape[2]).astype(int))
        x1 = F.dropout(x1, self.proj_drop)

        return x1

class Transformer(nn.Module):

    def __init__(self,
                 # in_channels,
                 out_channels,
                 num_heads,
                 dpr,
                 proj_drop=0.0,
                 attention_bias=True,
                 padding_q="same",
                 padding_kv="same",
                 stride_kv=1,
                 stride_q=1):
        super().__init__()

        self.attention_output = Attention(channels=out_channels,
                                          num_heads=num_heads,
                                          proj_drop=proj_drop,
                                          padding_q=padding_q,
                                          padding_kv=padding_kv,
                                          stride_kv=stride_kv,
                                          stride_q=stride_q,
                                          attention_bias=attention_bias,
                                          )

        self.conv1 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.layernorm = nn.LayerNorm(self.conv1.out_channels, eps=1e-5)
        self.wide_focus = Wide_Focus(out_channels, out_channels)

    def forward(self, x):
        x1 = self.attention_output(x)
        x1 = self.conv1(x1)
        x2 = torch.add(x1, x)
        x3 = x2.permute(0, 2, 3, 1)
        x3 = self.layernorm(x3)
        x3 = x3.permute(0, 3, 1, 2)
        x3 = self.wide_focus(x3)
        x3 = torch.add(x2, x3)
        return x3

class Wide_Focus(nn.Module):
    """
    Wide-Focus module.
    """

    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same", dilation=2)
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same", dilation=3)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.gelu(x1)
        x1 = F.dropout(x1, 0.1)
        x2 = self.conv2(x)
        x2 = F.gelu(x2)
        x2 = F.dropout(x2, 0.1)
        x3 = self.conv3(x)
        x3 = F.gelu(x3)
        x3 = F.dropout(x3, 0.1)
        added = torch.add(x1, x2)
        added = torch.add(added, x3)
        x_out = self.conv4(added)
        x_out = F.gelu(x_out)
        x_out = F.dropout(x_out, 0.1)
        return x_out

class Block_encoder_bottleneck(nn.Module):
    def __init__(self, blk, in_channels, out_channels, att_heads, dpr):
        super().__init__()
        self.blk = blk
        if ((self.blk == "first") or (self.blk == "bottleneck")):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.trans = Transformer(out_channels, att_heads, dpr)
        elif ((self.blk == "second") or (self.blk == "third") or (self.blk == "fourth")):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(1, in_channels, 3, 1, padding="same")
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.trans = Transformer(out_channels, att_heads, dpr)

    def forward(self, x, scale_img="none"):
        if ((self.blk == "first") or (self.blk == "bottleneck")):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = F.relu(self.conv1(x1))
            x1 = F.relu(self.conv2(x1))
            x1 = F.dropout(x1, 0.3)
            x1 = F.max_pool2d(x1, (2, 2))
            out = self.trans(x1)
            # without skip
        elif ((self.blk == "second") or (self.blk == "third") or (self.blk == "fourth")):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = torch.cat((F.relu(self.conv1(scale_img)), x1), axis=1)
            x1 = F.relu(self.conv2(x1))
            x1 = F.relu(self.conv3(x1))
            x1 = F.dropout(x1, 0.3)
            x1 = F.max_pool2d(x1, (2, 2))
            out = self.trans(x1)
            # with skip
        return out

class Block_decoder(nn.Module):
    def __init__(self, in_channels, out_channels, att_heads, dpr):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(out_channels * 2, out_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.trans = Transformer(out_channels, att_heads, dpr)

    def forward(self, x, skip):
        x1 = x.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = self.upsample(x1)
        x1 = F.relu(self.conv1(x1))
        x1 = torch.cat((skip, x1), axis=1)
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        x1 = F.dropout(x1, 0.3)
        out = self.trans(x1)
        return out

class DS_out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")

    def forward(self, x):
        x1 = self.upsample(x)
        x1 = x1.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = F.relu(self.conv1(x1))
        x1 = F.relu(self.conv2(x1))
        out = torch.sigmoid(self.conv3(x1))

        return out

class FCT(nn.Module):
    def __init__(self):
        super().__init__()

        # attention heads and filters per block
        att_heads = [2, 2, 2, 2, 2, 2, 2, 2, 2]
        filters = [8, 16, 32, 64, 128, 64, 32, 16, 8]

        # number of blocks used in the model
        blocks = len(filters)

        stochastic_depth_rate = 0.0

        # probability for each block
        dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)]

        self.drp_out = 0.3

        # shape
        init_sizes = torch.ones((2, 224, 224, 1))
        init_sizes = init_sizes.permute(0, 3, 1, 2)

        # Multi-scale input
        self.scale_img = nn.AvgPool2d(2, 2)

        # model
        self.block_1 = Block_encoder_bottleneck("first", 1, filters[0], att_heads[0], dpr[0])
        self.block_2 = Block_encoder_bottleneck("second", filters[0], filters[1], att_heads[1], dpr[1])
        self.block_3 = Block_encoder_bottleneck("third", filters[1], filters[2], att_heads[2], dpr[2])
        self.block_4 = Block_encoder_bottleneck("fourth", filters[2], filters[3], att_heads[3], dpr[3])
        self.block_5 = Block_encoder_bottleneck("bottleneck", filters[3], filters[4], att_heads[4], dpr[4])
        self.block_6 = Block_decoder(filters[4], filters[5], att_heads[5], dpr[5])
        self.block_7 = Block_decoder(filters[5], filters[6], att_heads[6], dpr[6])
        self.block_8 = Block_decoder(filters[6], filters[7], att_heads[7], dpr[7])
        self.block_9 = Block_decoder(filters[7], filters[8], att_heads[8], dpr[8])

        self.ds7 = DS_out(filters[6], 4)
        self.ds8 = DS_out(filters[7], 4)
        self.ds9 = DS_out(filters[8], 4)

    def forward(self, x):
        # Multi-scale input
        scale_img_2 = self.scale_img(x)
        scale_img_3 = self.scale_img(scale_img_2)
        scale_img_4 = self.scale_img(scale_img_3)

        x = self.block_1(x)
        # print(f"Block 1 out -> {list(x.size())}")
        skip1 = x
        x = self.block_2(x, scale_img_2)
        # print(f"Block 2 out -> {list(x.size())}")
        skip2 = x
        x = self.block_3(x, scale_img_3)
        # print(f"Block 3 out -> {list(x.size())}")
        skip3 = x
        x = self.block_4(x, scale_img_4)
        # print(f"Block 4 out -> {list(x.size())}")
        skip4 = x
        x = self.block_5(x)
        # print(f"Block 5 out -> {list(x.size())}")
        x = self.block_6(x, skip4)
        # print(f"Block 6 out -> {list(x.size())}")
        x = self.block_7(x, skip3)
        # print(f"Block 7 out -> {list(x.size())}")
        skip7 = x
        x = self.block_8(x, skip2)
        # print(f"Block 8 out -> {list(x.size())}")
        skip8 = x
        x = self.block_9(x, skip1)
        # print(f"Block 9 out -> {list(x.size())}")
        skip9 = x

        out7 = self.ds7(skip7)
        # print(f"DS 7 out -> {list(out7.size())}")
        out8 = self.ds8(skip8)
        # print(f"DS 8 out -> {list(out8.size())}")
        out9 = self.ds9(skip9)
        # print(f"DS 9 out -> {list(out9.size())}")

        return out7, out8, out9

def init_weights(m):
    """
    Initialize the weights
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

wandb.login(key="xxx")
wandb.init(project="xxx")
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = FCT().to(device)
# model.apply(init_weights)
model_path = "E:\\dev\\py\\CASCADE\\ac\\model\\fct\\model20230313-17_M.pth"
model.load_state_dict(torch.load(model_path))
# %% Training
# initialize the loss function
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    total_pixels = 0

    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        pred = model(X)
        loss = loss_fn(pred[2], y)
        loss.backward()
        optimizer.step()

        correct += (pred[2].argmax(1) == y).type(torch.float).sum().item()
        total_pixels += y.numel()
        train_loss += loss.item()

        # print statistics
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)

            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    correct /= total_pixels
    train_loss /= len(dataloader)
    print(f"Train Error: \n Accuracy: {(correct):>0.6f}, Avg loss: {train_loss:>8f} \n")
    wandb.log({"epoch/train loss ": train_loss})

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    total_pixels = 0
    dice_coeff_sum1 = 0.0
    dice_coeff_sum2 = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred[2], y).item()
            correct += (pred[2].argmax(1) == y).type(torch.float).sum().item()
            total_pixels += y.numel()

            pred_label = pred[2].argmax(dim=1).cpu().numpy()
            y_label = y.cpu().numpy()
            dice_coeff = dc(pred_label, y_label)
            dice_coeff_sum1 += dice_coeff

            dice_coeff = dc((pred[2] > 0.1).cpu().numpy(), y_label)
            dice_coeff_sum2 += dice_coeff

    test_loss /= len(dataloader)
    correct /= total_pixels

    print(
        f"Test Error: \n Accuracy: {(correct):>0.6f}, Avg loss: {test_loss:>8f}, Avg dsc1: {dice_coeff_sum1 / len(dataloader):>4f}, Avg dsc2: {dice_coeff_sum2 / len(dataloader):>4f}\n")
    wandb.log({"epoch/val loss ": test_loss, "epoch/val dsc1": dice_coeff_sum1 / len(dataloader),
               "epoch/val dsc2": dice_coeff_sum2 / len(dataloader)})

for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(validation_dataloader, model, loss_fn)
print("Done!")

lee-wt commented 1 year ago

Thanks for your work! I tried the PyTorch implementation and encountered the same problem: all the predicted classes are background. I trained the model for 200 epochs and tested the checkpoint from epoch 180. My PyTorch version is 1.9.1. I didn't change any settings of the FCT model or add any metrics. Why does this happen? Looking forward to your reply.

Thanos-DB commented 1 year ago

I'll have to look into it. I'll come back with an update.

lee-wt commented 1 year ago

Thank you! Looking forward to the latest version.

Thanos-DB commented 1 year ago

Hi everyone, it seems to work fine for me. One question: did you check that your masks are all OK before starting training? Before creating the dataloaders, plot all 4 masks (background, RV, LV, MYO) and make sure everything looks fine. I have the feeling that you may have been using the convert_masks function as seen in the TensorFlow implementation, which is wrong here, because TensorFlow uses channels-last while PyTorch uses channels-first. In any case, below is an example that should work for you without problems. Notice that I am training on the validation set because it is smaller and I did not want to wait for the training set to finish. Also, at the end you can see a prediction. The model is fairly small, so change it based on your resources and needs. If you want SOTA results, I would suggest using a learning-rate scheduler plus the other techniques from the paper (deep supervision, data generators, etc.). Here is the code that predicts all classes nicely: issue#3.txt
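
For example, a rough sketch of such a sanity check (random data stands in for your converted channels-first masks, and the channel order here is only an assumption):

import numpy as np
import matplotlib.pyplot as plt

masks = np.random.randint(0, 2, (1, 4, 224, 224))   # placeholder for one converted mask batch, shape (N, 4, H, W)
titles = ["background", "RV", "MYO", "LV"]           # assumed channel order
fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for c, ax in enumerate(axes):
    ax.imshow(masks[0, c], cmap="gray")
    ax.set_title(titles[c])
    ax.axis("off")
plt.show()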

Here is one prediction image