from scipy import ndimage
# 3. Build a distance-based weight map from channel 0 of the first mask in
# the batch.
# NOTE(review): `masks`, `np` and `plt` come from elsewhere in the file;
# presumably `masks` is indexed as (batch, H, W, C) -- confirm against caller.


def _show_map(img):
    """Display *img* in a new figure with a colorbar."""
    plt.figure()
    plt.imshow(img)
    plt.colorbar()
    plt.show()


def _distance_weights(distance, near_threshold=3, scale=10, decay=30):
    """Turn a distance map into exponentially decaying weights.

    Every element whose distance is <= ``near_threshold`` is first replaced by
    the largest distance in the map. This includes the mask pixels themselves,
    whose distance is 0. Those elements therefore receive the smallest weight.
    Each element then gets ``scale * exp(-d / decay)``.

    Distances are non-negative, so the returned weights are at most ``scale``.

    Args:
        distance: array of non-negative distances, e.g. the output of
            ``ndimage.distance_transform_edt``.
        near_threshold: elements at or below this distance are demoted to the
            maximum distance.
        scale: multiplier, which is also the upper bound of the weights.
        decay: length scale of the exponential fall-off.

    Returns:
        A float array with the same shape as ``distance``.
    """
    far = np.where(distance <= near_threshold, distance.max(), distance)
    return scale * np.exp(-far / decay)


mask_img = masks[0][:, :, 0]

# distance_transform_edt gives each non-zero input element its Euclidean
# distance to the nearest zero element. Passing `mask_img == 0` makes the
# background the non-zero input. As a result, every background pixel gets its
# distance to the nearest mask (non-zero) pixel, and mask pixels get 0.
weight_distance = ndimage.distance_transform_edt(mask_img == 0)
_show_map(weight_distance)

# Pixels inside or within 3 px of a mask are pushed to the lowest weight. The
# original intent was to avoid summing this weight with an area-based weight
# on the same pixels. The author recorded poor results with this setting
# (0.56 after 13 epochs).
# These weights are combined with other weights, so they are scaled to at
# most 10.
weight_distance = _distance_weights(weight_distance)
_show_map(weight_distance)