Closed xiehou-design closed 1 year ago
Hi xiehou-design,
thank you for taking an interest in our work! Before starting, let me once again mention that the paper is based on the TensorFlow implementation. Having said that, the PyTorch implementation should also work fine for you. Do you have any metrics to get a better feeling for what happens? What is your Dice (or whatever metric you are using) for the FCT and the vanilla UNet? Is the vanilla UNet pretrained? Is your dataset freely available or a private one? What is the size of both models? Can you force FCT to overfit your dataset (you might consider using a smaller set)? Also, change num_heads=1 to self.num_heads; it should help (you can do ablation studies to find what works best for you, and you can start with 2 everywhere). Is your vanilla UNet also multi-input multi-output? Does the loss for the FCT decrease during training? Lastly, how many segmentation masks does your ground truth have? It seems to me that you need to change your number of classes to 1.
BR Thanos
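The overfitting check suggested above can be sketched as follows. This is a minimal sketch, assuming PyTorch is installed; `overfit_tiny_batch` and the one-layer stand-in model are hypothetical names used only for illustration and are not part of the FCT code. If a network cannot drive the loss down on a handful of fixed samples, the cause is more likely the data pipeline or the loss/activation pairing than a lack of capacity.

```python
import torch
import torch.nn as nn


def overfit_tiny_batch(model, x, y, loss_fn, steps=200, lr=1e-2):
    """Train on one fixed batch and return the first and last loss values."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses[0], losses[-1]


torch.manual_seed(0)
# Hypothetical stand-in model: swap in the real segmentation network here.
toy_model = nn.Conv2d(1, 2, kernel_size=3, padding=1)
x = torch.randn(4, 1, 16, 16)
# Targets are derived from the input so that the toy task is learnable.
y = (x[:, 0] > 0).long()
first, last = overfit_tiny_batch(toy_model, x, y, nn.CrossEntropyLoss())
print(f"first loss {first:.4f}, last loss {last:.4f}")
```

With the real model, the same helper can be called on a single batch taken from the dataloader; a loss that stays flat here points to a setup problem rather than to insufficient training time.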
Thank you for your reply.
My semantic segmentation experiment uses pixel precision, recall and F1 as the metrics. I use cross-entropy loss; the loss does not decrease with FCT, but it behaves normally with UNet. Neither FCT nor UNet is pretrained.
My experimental dataset is a cell dataset; it is very small and freely available.
I am sure that FCT does not overfit; it even predicts every target as background.
I changed num_heads=1 to num_heads=self.num_heads, but FCT still does not work.
The input size for both FCT and UNet is 224.
My experiment code is on Google Drive.
Did you try the number of classes I suggested? Also, and more importantly, your UNet works because you use no activation on the last layer, so you get logits, which is what your loss function expects. This is not the case with FCT. You should correct this.
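The point about the last-layer activation can be checked numerically. In PyTorch, nn.CrossEntropyLoss and nn.BCEWithLogitsLoss expect raw logits, while nn.BCELoss expects probabilities in [0, 1]. The following is a minimal sketch with random tensors; the shapes are arbitrary and only serve as an illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8) * 2            # raw scores from a final conv layer
target = (torch.rand(2, 1, 8, 8) > 0.5).float()

# Correct pairings: these two are mathematically equivalent.
loss_a = nn.BCEWithLogitsLoss()(logits, target)
loss_b = nn.BCELoss()(torch.sigmoid(logits), target)

# Mismatch: sigmoid output fed to a loss that applies its own sigmoid.
# Its inputs lie in (0, 1), so every per-pixel term is confined to roughly
# (0.31, 1.32) and the loss can never approach zero.
loss_mismatch = nn.BCEWithLogitsLoss()(torch.sigmoid(logits), target)

print(loss_a.item(), loss_b.item(), loss_mismatch.item())
```

The same reasoning applies to nn.CrossEntropyLoss: passing already-squashed outputs compresses the range of the scores it sees, which can stall training.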
Yes, I tried it. I used classes=1 and replaced the cross-entropy loss with BCELoss, but it still does not work.
I also tried replacing the last convolution of the DsOut module, out = F.sigmoid(self.conv3(x1)), with out = self.conv3(x1), using classes=2 and cross-entropy, but it still does not work.
I tried to dig deeper into the problem. I think there may be an issue with the model's attention computation, but I compared it against the TensorFlow version of the code and found no obvious error.
I tried running your code and I get a lot of errors. First, there were mismatched dimensions. When I fixed that, I got errors concerning dtypes (float, long, etc.). This happens regardless of the model, because I tried both of them. This is my torch version: 1.13.1. I can tell you it is not the model, because I have tested it and it produces nice segmentation masks (binary or not). I have also tested it inside frameworks from other people and it also works.
I am sorry, my torch version is 1.8.1. I found some problems in this version, such as the shapes of the query, key and value inputs to MultiheadAttention. I have now replaced it with version 1.13.1.
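For reference, the shape convention that nn.MultiheadAttention uses with batch_first=True can be checked directly. As far as I know, the batch_first argument is not available in older releases such as 1.8.1, which would be consistent with the errors described above; treat that as an assumption and check the documentation for your version. The sizes below are arbitrary.

```python
import torch
import torch.nn as nn

batch, channels, height, width = 2, 8, 4, 4
feature_map = torch.randn(batch, channels, height, width)

# Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C).
tokens = feature_map.flatten(2).permute(0, 2, 1)

# embed_dim must be divisible by num_heads.
attention = nn.MultiheadAttention(embed_dim=channels, num_heads=2, batch_first=True)
out, _ = attention(query=tokens, key=tokens, value=tokens, need_weights=False)

# Restore the spatial layout: (B, H*W, C) -> (B, C, H, W).
restored = out.permute(0, 2, 1).reshape(batch, channels, height, width)
print(tokens.shape, out.shape, restored.shape)
```

Without batch_first=True the module expects (sequence, batch, embedding) instead, so the same tensors would need an extra transpose before and after the call.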
Thank you for your reply. I don't doubt the novelty of the FCT. I will modify my code and rerun the experiment.
I encountered the same problem: all the predicted classes are background, and the average DSC is 0. My PyTorch version is 1.13.0.
Same here. I could not find the problem in the FCT model.
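The symptom reported above (everything predicted as background, Dice of 0) is easy to miss if only pixel accuracy is monitored. The following is a minimal NumPy sketch on a toy label map; `dice_per_class` is a hypothetical helper written for this illustration, not the `dc` function used elsewhere in this thread.

```python
import numpy as np


def dice_per_class(pred_labels, true_labels, num_classes, eps=1e-7):
    """Dice score for each class, computed from integer label maps."""
    scores = []
    for c in range(num_classes):
        pred_c = pred_labels == c
        true_c = true_labels == c
        intersection = np.logical_and(pred_c, true_c).sum()
        denom = pred_c.sum() + true_c.sum()
        scores.append(float(2.0 * intersection / (denom + eps)))
    return scores


# Toy ground truth: an 8x8 image with a 2x2 foreground object (class 1).
true_labels = np.zeros((8, 8), dtype=np.int64)
true_labels[3:5, 3:5] = 1
# A collapsed model that predicts background everywhere.
pred_labels = np.zeros_like(true_labels)

accuracy = float((pred_labels == true_labels).mean())
scores = dice_per_class(pred_labels, true_labels, num_classes=2)
print(f"pixel accuracy: {accuracy:.4f}")
print(f"dice per class: {scores}")
```

Pixel accuracy stays high because background dominates, while the foreground Dice is exactly zero, so reporting Dice per class makes this failure mode visible immediately.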
Hi Geekiter,
do you have any reproducible code?
I didn't change the model; I only added the average DSC calculation:
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import TensorDataset, DataLoader

# os.chdir("/home/thanos/code")
# utils is expected to provide get_acdc, convert_masks, dc and the acdc_data_* paths
from utils import *
# %% Get data
batch_size = 1
epochs = 10
# model+datetime
model_path = "./model.pth"
learning_rate = 0.0001
# ---- ACDC
# training
# ---- ACDC
acdc_data, _, _ = get_acdc(acdc_data_train)
acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2)) # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2)) # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0]) # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1]) # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
train_dataloader = DataLoader(acdc_data, batch_size=batch_size)
# validation
acdc_data, _, _ = get_acdc(acdc_data_validation)
acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2)) # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2)) # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0]) # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1]) # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
validation_dataloader = DataLoader(acdc_data, batch_size=batch_size)
# testing
acdc_data, _, _ = get_acdc(acdc_data_test)
acdc_data[1] = convert_masks(acdc_data[1])
acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2)) # for the channels
acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2)) # for the channels
acdc_data[0] = torch.Tensor(acdc_data[0]) # convert to tensors
acdc_data[1] = torch.Tensor(acdc_data[1]) # convert to tensors
acdc_data = TensorDataset(acdc_data[0], acdc_data[1])
test_dataloader = DataLoader(acdc_data, batch_size=batch_size)
# %% ######################################################################################
# create model
class Attention(nn.Module):
    def __init__(self,
                 channels,
                 num_heads,
                 proj_drop=0.0,
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv="same",
                 padding_q="same",
                 attention_bias=True
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.num_heads = num_heads
        self.proj_drop = proj_drop
        # depthwise convolutional projections for query, key and value
        self.conv_q = nn.Conv2d(channels, channels, kernel_size, stride_q, padding_q, bias=attention_bias,
                                groups=channels)
        self.layernorm_q = nn.LayerNorm(channels, eps=1e-5)
        self.conv_k = nn.Conv2d(channels, channels, kernel_size, stride_kv, stride_kv, bias=attention_bias,
                                groups=channels)
        self.layernorm_k = nn.LayerNorm(channels, eps=1e-5)
        self.conv_v = nn.Conv2d(channels, channels, kernel_size, stride_kv, stride_kv, bias=attention_bias,
                                groups=channels)
        self.layernorm_v = nn.LayerNorm(channels, eps=1e-5)
        self.attention = nn.MultiheadAttention(embed_dim=channels,
                                               bias=attention_bias,
                                               batch_first=True,
                                               # dropout = 0.0,
                                               num_heads=1)  # num_heads=self.num_heads)

    def _build_projection(self, x, qkv):
        # LayerNorm normalizes the last dimension, so move channels last and back
        if qkv == "q":
            x1 = F.relu(self.conv_q(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_q(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "k":
            x1 = F.relu(self.conv_k(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_k(x1)
            proj = x1.permute(0, 3, 1, 2)
        elif qkv == "v":
            x1 = F.relu(self.conv_v(x))
            x1 = x1.permute(0, 2, 3, 1)
            x1 = self.layernorm_v(x1)
            proj = x1.permute(0, 3, 1, 2)
        return proj

    def forward_conv(self, x):
        q = self._build_projection(x, "q")
        k = self._build_projection(x, "k")
        v = self._build_projection(x, "v")
        return q, k, v

    def forward(self, x):
        q, k, v = self.forward_conv(x)
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        q = q.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        k = k.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        v = v.view(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        q = q.permute(0, 2, 1)
        k = k.permute(0, 2, 1)
        v = v.permute(0, 2, 1)
        x1 = self.attention(query=q, value=v, key=k, need_weights=False)
        x1 = x1[0].permute(0, 2, 1)
        # back to a square feature map (B, C, H, W)
        x1 = x1.view(x1.shape[0], x1.shape[1], np.sqrt(x1.shape[2]).astype(int), np.sqrt(x1.shape[2]).astype(int))
        # F.dropout is active by default; tie it to the module's train/eval mode
        x1 = F.dropout(x1, self.proj_drop, training=self.training)
        return x1
class Transformer(nn.Module):
    def __init__(self,
                 # in_channels,
                 out_channels,
                 num_heads,
                 dpr,
                 proj_drop=0.0,
                 attention_bias=True,
                 padding_q="same",
                 padding_kv="same",
                 stride_kv=1,
                 stride_q=1):
        super().__init__()
        self.attention_output = Attention(channels=out_channels,
                                          num_heads=num_heads,
                                          proj_drop=proj_drop,
                                          padding_q=padding_q,
                                          padding_kv=padding_kv,
                                          stride_kv=stride_kv,
                                          stride_q=stride_q,
                                          attention_bias=attention_bias,
                                          )
        self.conv1 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.layernorm = nn.LayerNorm(self.conv1.out_channels, eps=1e-5)
        self.wide_focus = Wide_Focus(out_channels, out_channels)

    def forward(self, x):
        x1 = self.attention_output(x)
        x1 = self.conv1(x1)
        x2 = torch.add(x1, x)  # residual connection around the attention
        x3 = x2.permute(0, 2, 3, 1)
        x3 = self.layernorm(x3)
        x3 = x3.permute(0, 3, 1, 2)
        x3 = self.wide_focus(x3)
        x3 = torch.add(x2, x3)  # residual connection around Wide-Focus
        return x3
class Wide_Focus(nn.Module):
    """
    Wide-Focus module.
    """

    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same", dilation=2)
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same", dilation=3)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")

    def forward(self, x):
        # F.dropout is active by default; tie it to the module's train/eval mode
        x1 = self.conv1(x)
        x1 = F.gelu(x1)
        x1 = F.dropout(x1, 0.1, training=self.training)
        x2 = self.conv2(x)
        x2 = F.gelu(x2)
        x2 = F.dropout(x2, 0.1, training=self.training)
        x3 = self.conv3(x)
        x3 = F.gelu(x3)
        x3 = F.dropout(x3, 0.1, training=self.training)
        added = torch.add(x1, x2)
        added = torch.add(added, x3)
        x_out = self.conv4(added)
        x_out = F.gelu(x_out)
        x_out = F.dropout(x_out, 0.1, training=self.training)
        return x_out
class Block_encoder_bottleneck(nn.Module):
    def __init__(self, blk, in_channels, out_channels, att_heads, dpr):
        super().__init__()
        self.blk = blk
        if ((self.blk == "first") or (self.blk == "bottleneck")):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.trans = Transformer(out_channels, att_heads, dpr)
        elif ((self.blk == "second") or (self.blk == "third") or (self.blk == "fourth")):
            self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
            self.conv1 = nn.Conv2d(1, in_channels, 3, 1, padding="same")
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
            self.trans = Transformer(out_channels, att_heads, dpr)

    def forward(self, x, scale_img="none"):
        if ((self.blk == "first") or (self.blk == "bottleneck")):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            x1 = F.relu(self.conv1(x1))
            x1 = F.relu(self.conv2(x1))
            # F.dropout is active by default; tie it to the module's train/eval mode
            x1 = F.dropout(x1, 0.3, training=self.training)
            x1 = F.max_pool2d(x1, (2, 2))
            out = self.trans(x1)
            # without skip
        elif ((self.blk == "second") or (self.blk == "third") or (self.blk == "fourth")):
            x1 = x.permute(0, 2, 3, 1)
            x1 = self.layernorm(x1)
            x1 = x1.permute(0, 3, 1, 2)
            # concatenate features of the downscaled input image (multi-scale input)
            x1 = torch.cat((F.relu(self.conv1(scale_img)), x1), axis=1)
            x1 = F.relu(self.conv2(x1))
            x1 = F.relu(self.conv3(x1))
            x1 = F.dropout(x1, 0.3, training=self.training)
            x1 = F.max_pool2d(x1, (2, 2))
            out = self.trans(x1)
            # with skip
        return out
class Block_decoder(nn.Module):
    def __init__(self, in_channels, out_channels, att_heads, dpr):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(out_channels * 2, out_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.trans = Transformer(out_channels, att_heads, dpr)

    def forward(self, x, skip):
        x1 = x.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = self.upsample(x1)
        x1 = F.relu(self.conv1(x1))
        x1 = torch.cat((skip, x1), axis=1)
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        # F.dropout is active by default; tie it to the module's train/eval mode
        x1 = F.dropout(x1, 0.3, training=self.training)
        out = self.trans(x1)
        return out
class DS_out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.layernorm = nn.LayerNorm(in_channels, eps=1e-5)
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")

    def forward(self, x):
        x1 = self.upsample(x)
        x1 = x1.permute(0, 2, 3, 1)
        x1 = self.layernorm(x1)
        x1 = x1.permute(0, 3, 1, 2)
        x1 = F.relu(self.conv1(x1))
        x1 = F.relu(self.conv2(x1))
        out = torch.sigmoid(self.conv3(x1))
        return out
class FCT(nn.Module):
    def __init__(self):
        super().__init__()
        # attention heads and filters per block
        att_heads = [2, 2, 2, 2, 2, 2, 2, 2, 2]
        filters = [8, 16, 32, 64, 128, 64, 32, 16, 8]
        # number of blocks used in the model
        blocks = len(filters)
        stochastic_depth_rate = 0.0
        # probability for each block
        dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)]
        self.drp_out = 0.3
        # shape
        init_sizes = torch.ones((2, 224, 224, 1))
        init_sizes = init_sizes.permute(0, 3, 1, 2)
        # Multi-scale input
        self.scale_img = nn.AvgPool2d(2, 2)
        # model
        self.block_1 = Block_encoder_bottleneck("first", 1, filters[0], att_heads[0], dpr[0])
        self.block_2 = Block_encoder_bottleneck("second", filters[0], filters[1], att_heads[1], dpr[1])
        self.block_3 = Block_encoder_bottleneck("third", filters[1], filters[2], att_heads[2], dpr[2])
        self.block_4 = Block_encoder_bottleneck("fourth", filters[2], filters[3], att_heads[3], dpr[3])
        self.block_5 = Block_encoder_bottleneck("bottleneck", filters[3], filters[4], att_heads[4], dpr[4])
        self.block_6 = Block_decoder(filters[4], filters[5], att_heads[5], dpr[5])
        self.block_7 = Block_decoder(filters[5], filters[6], att_heads[6], dpr[6])
        self.block_8 = Block_decoder(filters[6], filters[7], att_heads[7], dpr[7])
        self.block_9 = Block_decoder(filters[7], filters[8], att_heads[8], dpr[8])
        self.ds7 = DS_out(filters[6], 4)
        self.ds8 = DS_out(filters[7], 4)
        self.ds9 = DS_out(filters[8], 4)

    def forward(self, x):
        # Multi-scale input
        scale_img_2 = self.scale_img(x)
        scale_img_3 = self.scale_img(scale_img_2)
        scale_img_4 = self.scale_img(scale_img_3)
        x = self.block_1(x)
        # print(f"Block 1 out -> {list(x.size())}")
        skip1 = x
        x = self.block_2(x, scale_img_2)
        # print(f"Block 2 out -> {list(x.size())}")
        skip2 = x
        x = self.block_3(x, scale_img_3)
        # print(f"Block 3 out -> {list(x.size())}")
        skip3 = x
        x = self.block_4(x, scale_img_4)
        # print(f"Block 4 out -> {list(x.size())}")
        skip4 = x
        x = self.block_5(x)
        # print(f"Block 5 out -> {list(x.size())}")
        x = self.block_6(x, skip4)
        # print(f"Block 6 out -> {list(x.size())}")
        x = self.block_7(x, skip3)
        # print(f"Block 7 out -> {list(x.size())}")
        skip7 = x
        x = self.block_8(x, skip2)
        # print(f"Block 8 out -> {list(x.size())}")
        skip8 = x
        x = self.block_9(x, skip1)
        # print(f"Block 9 out -> {list(x.size())}")
        skip9 = x
        out7 = self.ds7(skip7)
        # print(f"DS 7 out -> {list(out7.size())}")
        out8 = self.ds8(skip8)
        # print(f"DS 8 out -> {list(out8.size())}")
        out9 = self.ds9(skip9)
        # print(f"DS 9 out -> {list(out9.size())}")
        return out7, out8, out9
def init_weights(m):
    """
    Initialize the weights
    """
    if isinstance(m, nn.Conv2d):
        # kaiming_normal (no trailing underscore) is deprecated
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
wandb.login(key="xxx")
wandb.init(project="xxx")
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = FCT().to(device)
# model.apply(init_weights)
model_path = "E:\\dev\\py\\CASCADE\\ac\\model\\fct\\model20230313-17_M.pth"
model.load_state_dict(torch.load(model_path))
# %% Training
# initialize the loss function
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    total_pixels = 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        pred = model(X)
        loss = loss_fn(pred[2], y)
        loss.backward()
        optimizer.step()
        # y is one-hot with shape (N, C, H, W), so compare class indices per pixel
        y_label = y.argmax(1)
        correct += (pred[2].argmax(1) == y_label).type(torch.float).sum().item()
        total_pixels += y_label.numel()
        train_loss += loss.item()
        # print statistics
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    correct /= total_pixels
    train_loss /= len(dataloader)
    print(f"Train Error: \n Accuracy: {(correct):>0.6f}, Avg loss: {train_loss:>8f} \n")
    wandb.log({"epoch/train loss ": train_loss})
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    total_pixels = 0
    dice_coeff_sum1 = 0.0
    dice_coeff_sum2 = 0.0
    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred[2], y).item()
            # y is one-hot with shape (N, C, H, W), so compare class indices per pixel
            y_label = y.argmax(dim=1)
            pred_label = pred[2].argmax(dim=1)
            correct += (pred_label == y_label).type(torch.float).sum().item()
            total_pixels += y_label.numel()
            # dsc1: label maps of equal shape; this assumes dc treats non-zero
            # entries as foreground, so it measures foreground vs background
            dice_coeff = dc(pred_label.cpu().numpy(), y_label.cpu().numpy())
            dice_coeff_sum1 += dice_coeff
            # dsc2: thresholded probabilities against the one-hot masks
            dice_coeff = dc((pred[2] > 0.1).cpu().numpy(), y.cpu().numpy())
            dice_coeff_sum2 += dice_coeff
    test_loss /= len(dataloader)
    correct /= total_pixels
    print(
        f"Test Error: \n Accuracy: {(correct):>0.6f}, Avg loss: {test_loss:>8f}, Avg dsc1: {dice_coeff_sum1 / len(dataloader):>4f}, Avg dsc2: {dice_coeff_sum2 / len(dataloader):>4f}\n")
    wandb.log({"epoch/val loss ": test_loss, "epoch/val dsc1": dice_coeff_sum1 / len(dataloader),
               "epoch/val dsc2": dice_coeff_sum2 / len(dataloader)})
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(validation_dataloader, model, loss_fn)
print("Done!")
Thanks for your work! I tried the PyTorch implementation and encountered the same problem: all the predicted classes are background. I trained the model for 200 epochs and tested the checkpoint from epoch 180. My PyTorch version is 1.9.1. I didn't change any settings of the FCT model and did not add any metrics. So why did it happen? Looking forward to your reply.
I'll have to look into it. I'll come back with an update.
Thank you! Looking forward to the latest version.
Hi everyone,
it seems that it works fine for me. One question: did you check that your masks are all OK before you started training? Before creating the dataloaders, plot all 4 masks (background, RV, LV, MYO) and make sure everything looks fine. I have the feeling that maybe you were using the convert_masks function as seen in the TensorFlow implementation, which is wrong here because TensorFlow uses channels last and PyTorch uses channels first. In any case, I have an example below that should work for you without problems. Notice that I am training on the validation set because it is smaller and I did not want to wait for the training set to finish. Also, at the end you can see a prediction. The model is fairly small, so change it based on your resources and needs. If you want SOTA results, I would suggest using a learning-rate scheduler plus the other techniques described in the paper (deep supervision, data generators etc.).
Here is the code that predicts all classes nicely: issue#3.txt
Here is one prediction
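The channels-first concern raised above can be illustrated with a small NumPy sketch. The contents of convert_masks are not shown in this thread, so `one_hot_channels_first` below is a hypothetical stand-in, and it assumes the ground truth arrives as integer label maps with values 0..3.

```python
import numpy as np


def one_hot_channels_first(label_maps, num_classes):
    """Convert integer labels (N, H, W) to one-hot masks (N, num_classes, H, W)."""
    # Indexing an identity matrix yields channels-last masks: (N, H, W, C).
    one_hot = np.eye(num_classes, dtype=np.float32)[label_maps]
    # PyTorch layers expect channels first: (N, C, H, W).
    return np.transpose(one_hot, (0, 3, 1, 2))


labels = np.array([[[0, 1], [2, 3]]])  # one 2x2 image with four classes
masks = one_hot_channels_first(labels, num_classes=4)
print(masks.shape)
# Each channel should be a binary mask for exactly one class.
for c in range(4):
    print(c, masks[0, c].tolist())
```

A quick sanity check before building the dataloaders is that the masks sum to one across the channel axis at every pixel and that taking the argmax over that axis recovers the original label map.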
Thanks for your work! But I found some problems in the PyTorch model code. I used your code for a binary semantic segmentation experiment, and the results were very bad: the model did not converge at all and predicted everything as background. To make sure that my data processing was correct, I swapped in the simplest UNet model, and the UNet gives correct results.
Here is my FCT model: