tjpulkl / CDGNet

49 stars 5 forks source link

example of inference #1

Closed xddlj closed 2 years ago

xddlj commented 2 years ago

Can you provide a simple script to get the visualization result? Thank you!

tjpulkl commented 2 years ago

def mask_to_onehot(mask, num_classes):
    """Convert a label mask into a stack of per-class binary masks.

    Args:
        mask: Array-like of class indices, e.g. a segmentation mask of
            shape (H, W). Anything ``np.asarray`` accepts is allowed.
        num_classes: Number of classes K; class ids ``0 .. K-1`` are encoded.

    Returns:
        ``np.uint8`` array of shape ``(K,) + mask.shape`` (so (K, H, W) for a
        2-D mask). The one-hot vector runs along the *first* axis:
        ``out[k]`` is 1 where ``mask == k`` and 0 elsewhere. Labels outside
        ``0 .. K-1`` produce an all-zero one-hot vector.
    """
    # Without this, a plain list would make `mask == k` a single bool.
    mask = np.asarray(mask)
    # Shape (K, 1, ..., 1) so the comparison broadcasts over every pixel.
    class_ids = np.arange(num_classes).reshape((-1,) + (1,) * mask.ndim)
    return (mask[np.newaxis] == class_ids).astype(np.uint8)

def generate_hw_gt(seg_onehot, class_num=20):
    """Build height-, width- and class-distribution targets from a one-hot mask.

    Args:
        seg_onehot: One-hot segmentation of shape (C, H, W), given as a NumPy
            array (a torch tensor is accepted as well).
        class_num: Unused; kept so existing callers keep working. The number
            of classes is taken from the first dimension of ``seg_onehot``.

    Returns:
        Tuple ``(hgt, wgt, hwgt)`` of float tensors:
            hgt:  (C, H) per-class pixel count of every row, divided by that
                  class's largest row count.
            wgt:  (C, W) per-class pixel count of every column, divided by
                  that class's largest column count.
            hwgt: (C,) per-class pixel count divided by the total count.

    Raises:
        ValueError: If ``seg_onehot`` is not 3-dimensional.
    """
    seg = torch.as_tensor(seg_onehot)
    if seg.dim() != 3:
        raise ValueError(
            "seg_onehot must have shape (C, H, W), got %s" % (tuple(seg.shape),)
        )

    # Keeps the divisions finite for classes that do not occur in the mask.
    eps = 1e-5

    # Height distribution ground truth: collapse the width axis.
    row_counts = torch.sum(seg, dim=2).float()                  # (C, H)
    row_peak = torch.max(row_counts, dim=1, keepdim=True)[0]    # (C, 1)
    hgt = row_counts / (row_peak + eps)

    # Width distribution ground truth: collapse the height axis.
    col_counts = torch.sum(seg, dim=1).float()                  # (C, W)
    col_peak = torch.max(col_counts, dim=1, keepdim=True)[0]    # (C, 1)
    wgt = col_counts / (col_peak + eps)

    # Class distribution: each class's share of all counted pixels.
    class_counts = torch.sum(seg, dim=[1, 2]).float()           # (C,)
    hwgt = class_counts / (torch.sum(class_counts) + eps)

    return hgt, wgt, hwgt

def VisualizeHWD(gt, num_classes=20):
    """Show a label map next to its height and width distribution maps.

    Relies on names imported by the surrounding script: ``F`` (presumably
    ``torch.nn.functional``) and ``plt`` (presumably ``matplotlib.pyplot``),
    plus the sibling helpers ``mask_to_onehot`` and ``generate_hw_gt``.

    Args:
        gt: 2-D array of class indices with shape (H, W).
        num_classes: Number of classes used for the one-hot encoding.
            Defaults to 20, the value that was previously hard-coded.

    Returns:
        None. Opens a matplotlib figure and blocks in ``plt.show()``.
    """
    h, w = gt.shape
    onehot_gt = mask_to_onehot(gt, num_classes)
    hgt, wgt, _ = generate_hw_gt(onehot_gt, num_classes)

    # (C, H) -> (1, 1, H, C): rows stay rows, the class axis is stretched
    # across the image width by the bilinear resize.
    hgt_map = F.interpolate(
        hgt.transpose(0, 1)[None, None],
        size=(h, w),
        mode='bilinear',
        align_corners=False,  # the implicit default, spelled out
    )[0, 0]

    # (C, W) -> (1, 1, C, W): columns stay columns, the class axis is
    # stretched across the image height.
    wgt_map = F.interpolate(
        wgt[None, None],
        size=(h, w),
        mode='bilinear',
        align_corners=False,
    )[0, 0]

    plt.style.use('classic')
    plt.figure()
    # 2x2 grid: label map, height distribution, width distribution.
    for position, panel in enumerate((gt, hgt_map, wgt_map), start=1):
        plt.subplot(2, 2, position)
        plt.imshow(panel)
        plt.axis('off')
        plt.xticks([])
        plt.yticks([])
    plt.show()

# Example driver: load an image (presumably a single-channel label PNG, since
# VisualizeHWD unpacks a 2-D shape) and display its distributions.
# Image.open is lazy, so the pixels are converted inside the `with` block;
# the file handle is then closed before plotting.
with Image.open("D:/testPy/Basic Code/data/imgVal/imag/305175_425003.png") as pilImg:
    npImg = np.asarray(pilImg)
VisualizeHWD(npImg)