Closed: elk-april closed this issue 4 years ago.
I modified the structure of your code (split it into separate train and test scripts) and deleted some code, as shown below.

train.py:
```python
model_test.train()  # make sure the model is in training mode for this epoch
for x, y in tqdm(train_loader):
    x, y = x.to(device), y.to(device)
    input_images(x, y, i, n_iter, k)
    opt.zero_grad()
    y_pred = model_test(x)
    lossT = calc_loss(y_pred, y)  # Dice loss used
    train_loss += lossT.item() * x.size(0)
    lossT.backward()
    opt.step()
    x_size = lossT.item() * x.size(0)
    k = 2

model_test.eval()
# torch.no_grad() only has an effect as a context manager; a bare call does nothing
with torch.no_grad():
    for x1, y1 in tqdm(valid_loader):
        x1, y1 = x1.to(device), y1.to(device)
        y_pred1 = model_test(x1)
        lossL = calc_loss(y_pred1, y1)  # Dice loss used
        valid_loss += lossL.item() * x1.size(0)
        x_size1 = lossL.item() * x1.size(0)

# print loss
train_loss = train_loss / len(train_idx)
valid_loss = valid_loss / len(valid_idx)

# (i + 1) % 1 == 0 is always true, and valid_loss is never compared with
# valid_loss_min, so a checkpoint is saved after every epoch
if (i + 1) % 1 == 0:
    print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss))
    torch.save(model_test.state_dict(), read_model_path + str(i + 1) + ".pth")
    loss_num.append(str(valid_loss)[0:5])
    print(read_model_path + str(i + 1) + ".pth saved")
```
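The loop above calls `calc_loss` with a Dice loss, whose body is not shown in the thread. As a rough sketch only (assuming a single binary mask and raw logits as input; `soft_dice_loss` and `smooth` are hypothetical names, and NumPy stands in for torch for readability), a soft Dice loss is one minus twice the soft overlap divided by the total mass of prediction and target:

```python
import numpy as np

def soft_dice_loss(logits, target, smooth=1.0):
    """Soft Dice loss for a binary mask: 1 - (2*sum(P*T) + s) / (sum(P) + sum(T) + s)."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid turns logits into probabilities
    probs = probs.reshape(-1)
    target = target.reshape(-1).astype(np.float64)
    intersection = (probs * target).sum()
    dice = (2.0 * intersection + smooth) / (probs.sum() + target.sum() + smooth)
    return 1.0 - dice

mask = np.array([[0, 1], [1, 0]])
good = np.where(mask == 1, 10.0, -10.0)  # logits that agree with the mask
bad = -good                              # logits that contradict it
print(soft_dice_loss(good, mask), soft_dice_loss(bad, mask))
```

The `smooth` term keeps the ratio defined when both prediction and mask are empty; it also means a completely wrong prediction scores somewhat below 1.0 on very small masks.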
test.py:

```python
i = 0  # index into x_sort_test
model_test.eval()
for _ in tqdm(read_test_folder):
    # print(x_sort_test[i])
    im = Image.open(x_sort_test[i])
    # Note: im_n / im_n_flat are not used below, so this binarization
    # does not affect what is fed to the model
    im_n = np.array(im)
    im_n_flat = im_n.reshape(-1, 1)
    for j in range(im_n_flat.shape[0]):
        if im_n_flat[j] != 0:
            im_n_flat[j] = 255
    s = data_transform(im)
    with torch.no_grad():
        pred = model_test(s.unsqueeze(0).cuda()).cpu()
        pred = torch.sigmoid(pred)  # F.sigmoid is deprecated
    pred = pred.numpy()[0][0]
    # Min-max scale to 0..255 (divides by zero if pred is constant)
    img4 = (pred - pred.min()) * (255 / (pred.max() - pred.min()))
    img = img4.astype(np.uint8)
    threshold_predictions_v(img, args.thr)  # return value unused: relies on in-place thresholding
    cv2.imwrite(res + x_sort_test[i].split('/')[-1], img)
    i = i + 1
imou = func.compute_mIoU(res, test_folderL)
print("checkpoints {} : miou {}".format(args.num, imou))
```
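`func.compute_mIoU` is a helper from the repository and its exact definition is not shown here. For a two-class problem like this one, a mean IoU could plausibly be computed as below; this is a hypothetical sketch (`binary_miou` is not the repository's function) that averages the IoU of background and lesion and treats any nonzero pixel as lesion, so it accepts both 0/1 and 0/255 masks:

```python
import numpy as np

def binary_miou(pred, label):
    """Mean IoU over the two classes (background = 0, lesion = nonzero)."""
    pred = np.asarray(pred).astype(bool)
    label = np.asarray(label).astype(bool)
    ious = []
    for cls in (False, True):
        p = pred == cls
        t = label == cls
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue  # class absent from both masks; skip it
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious))

pred = np.array([[0, 1, 1], [0, 0, 0]])
label = np.array([[0, 1, 0], [0, 0, 0]])
print(binary_miou(pred, label))  # lesion IoU 1/2, background IoU 4/5
```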
Hello, what is the size of your dataset, and what is the segmentation task? If your result is far below the state of the art (SOTA), there is probably a problem somewhere.
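Thanks for your reply! Each image in my dataset has shape (1, 384, 416). I converted them to (3, 384, 416) with this code:

```python
import os

import cv2
import numpy as np

def expand(i_dir, o_dir):
    """Replicate each single-channel image into three identical channels."""
    for name in os.listdir(i_dir):
        print(name)
        image_path = os.path.join(i_dir, name)
        img = cv2.imread(image_path, -1)  # -1 == cv2.IMREAD_UNCHANGED: keeps uint16 depth
        image = np.expand_dims(img, axis=2)                     # (H, W) -> (H, W, 1)
        image = np.concatenate((image, image, image), axis=-1)  # -> (H, W, 3)
        base_name = name.split('.')[0]
        cv2.imwrite(os.path.join(o_dir, base_name + '.png'), image)
```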
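As a quick shape check (a NumPy-only sketch, with a dummy zero array standing in for a real slice), the expand-dims/concatenate steps produce a channel-last (384, 416, 3) array, which is the layout OpenCV reads and writes. Getting a (3, 384, 416) array for PyTorch needs an explicit transpose:

```python
import numpy as np

# Dummy uint16 slice with the shape from the thread (H=384, W=416)
img = np.zeros((384, 416), dtype=np.uint16)

image = np.expand_dims(img, axis=2)
image = np.concatenate((image, image, image), axis=-1)
print(image.shape, image.dtype)  # channel-last, dtype unchanged

chw = np.transpose(image, (2, 0, 1))  # reorder to (C, H, W) for PyTorch
print(chw.shape)

# np.repeat builds the same three-channel array in one call
assert np.array_equal(image, np.repeat(img[:, :, None], 3, axis=-1))
```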
The task is brain-lesion segmentation. Here are some examples:
Original image:
Converted (3-channel) image:
Label:
My result:
Are all the images black?
Yes. Original image: dtype=uint16 (example below). Label: dtype=uint8, with background = 0 and lesion = 1 (example below).
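Images like these likely look black in an ordinary viewer because their values occupy only a small part of the displayable range: a label that is only 0 or 1 is almost indistinguishable from black on a 0..255 scale, and a uint16 scan that uses little of 0..65535 looks dark as well. A small sketch for inspection only, not for training (the helper name `to_display_uint8` is hypothetical), that stretches an array to 0..255:

```python
import numpy as np

def to_display_uint8(arr):
    """Min-max scale an array to 0..255 uint8 so it is visible in an image viewer."""
    arr = arr.astype(np.float64)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros(arr.shape, dtype=np.uint8)  # constant image: nothing to stretch
    return ((arr - lo) / (hi - lo) * 255).astype(np.uint8)

label = np.array([[0, 1], [1, 0]], dtype=np.uint8)  # 0 = background, 1 = lesion
print(to_display_uint8(label))  # lesion pixels become 255 (white)
```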
So I decided to rewrite the data loader like this:

```python
def __getitem__(self, i):
    # -1 == cv2.IMREAD_UNCHANGED: keep the uint16 image and uint8 label as stored
    i1 = cv2.imread(self.images_dir + self.images[i], -1)
    l1 = cv2.imread(self.labels_dir + self.labels[i], -1)

    seed = np.random.randint(0, 2 ** 32)  # make a seed with numpy generator
    # apply this seed to image transforms
    random.seed(seed)
    torch.manual_seed(seed)
    # apply the same seed to target/label transforms
    # (seeding only matters if random transforms are applied; none are here)
    random.seed(seed)
    torch.manual_seed(seed)

    img = torch.Tensor(i1 / 1.0)    # float32, shape (H, W), raw uint16 intensity range
    label = torch.Tensor(l1 / 1.0)  # float32, 0 = background, 1 = lesion
    img = img.unsqueeze(0)          # add channel axis -> (1, H, W)
    label = label.unsqueeze(0)
    return img, label
```
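The loader above passes raw uint16 intensities (up to 65535) straight to the network. One common option, assumed here rather than taken from the repository, is to scale the image to [0, 1] and binarize the label to 0/1 floats, each with a leading channel axis. A NumPy sketch with a hypothetical helper `prepare_pair`:

```python
import numpy as np

def prepare_pair(image_u16, label_u8):
    """Return (image, label) as float32 arrays of shape (1, H, W).

    The image is scaled from the full uint16 range to [0, 1];
    the label is binarized so any nonzero pixel counts as lesion (1.0).
    """
    image = image_u16.astype(np.float32) / np.iinfo(np.uint16).max
    label = (label_u8 != 0).astype(np.float32)
    return image[np.newaxis], label[np.newaxis]

img, lab = prepare_pair(np.array([[0, 65535]], dtype=np.uint16),
                        np.array([[0, 1]], dtype=np.uint8))
print(img.shape, img.max(), lab.tolist())
```

Inside `__getitem__`, `torch.from_numpy(...)` on these arrays could replace the two `torch.Tensor(... / 1.0)` lines. Dividing by the fixed uint16 maximum keeps every slice on the same scale; per-image min-max or z-score normalization are alternatives worth comparing.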
If you need the data, I can send you a pair of samples by email.
That's fine. Did you check how the images are being trained? Is the attention focused on the right regions, and are the activations in the correct places?
Sorry, I don't understand. Could you explain what you mean?
I have tried removing `torchvision.transforms.CenterCrop(96),` but it does not help.
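For a sense of scale: on a 384 × 416 slice, a 96 × 96 center crop keeps under 6% of the pixels, so any lesion away from the middle of the slice is cut out. A NumPy sketch of roughly what `CenterCrop(96)` does for an image larger than the crop (the helper `center_crop` is hypothetical; torchvision rounds the offsets, which can differ by one pixel when the margin is odd):

```python
import numpy as np

def center_crop(arr, size):
    """Keep the central size x size window of an (H, W) array."""
    h, w = arr.shape[:2]
    top = (h - size) // 2
    left = (w - size) // 2
    return arr[top:top + size, left:left + size]

slice_ = np.zeros((384, 416))
crop = center_crop(slice_, 96)
print(crop.shape, crop.size / slice_.size)  # fraction of pixels kept
```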
When I checked the code, I found this (pytorch_run.py, line 482):

```python
for j in range(im_n_flat.shape[0]):
    if im_n_flat[j] != 0:
        im_n_flat[j] = 255
```
The labels in my dataset are 0 or 1, so maybe I should set `im_n_flat[j] = 1` instead?
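Whether the foreground value should be 1 or 255 depends on what the rest of the pipeline (for example the mIoU code) expects; with 0/1 labels, using 1 keeps predictions and labels on the same scale. Either way, the per-pixel loop can be replaced by a single vectorized NumPy expression. A small sketch with made-up pixel values:

```python
import numpy as np

im_n = np.array([[0, 3, 0], [7, 0, 1]], dtype=np.uint8)

# Loop version from the thread, with the foreground value set to 1
looped = im_n.reshape(-1, 1).copy()
for j in range(looped.shape[0]):
    if looped[j] != 0:
        looped[j] = 1

# Vectorized equivalent: every nonzero pixel becomes 1, zeros stay 0
vectorized = (im_n != 0).astype(np.uint8)

assert np.array_equal(looped.reshape(im_n.shape), vectorized)
print(vectorized)
```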
Thank you very much! I think I found my mistake: with the new data loader, the IoU for the "lesion" class now reaches 0.70 (UNet). Now I can move on to trying other models. Thanks for your reply!
Hello, I am training the model on my own data, but tuning the learning rate barely improves the IoU (the best result so far is 60%, with lr = 0.1 / 0.05 / 0.01 and weight_decay = 1e-8). Is there any other way to improve performance? Looking forward to your reply!