Open vishwa91 opened 4 years ago
You can modify the coordinates of the centroids by assigning a new list to slic.slic_model.clusters, where slic is an instance of the Slic or SlicAvx2 class. The value is a list of dicts, and each dict stores the centroid position under the key 'yx'. Note that the value under the key 'rgb' is actually the LAB color tuple multiplied by 2. Also, assign a new list to the property rather than modifying the value in place.
You can run a single iteration with slic.iterate(img, 1).
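A minimal sketch of the "assign a new list" step described above. The helper name with_centroids is hypothetical and not part of fast-slic; it assumes only what this comment states, namely that each cluster is a dict holding its position under the key 'yx'. It builds fresh dicts inside a fresh list, so nothing is modified in place.

```python
def with_centroids(clusters, centers_yx):
    """Return a new cluster list with each 'yx' replaced (hypothetical helper)."""
    if len(clusters) != len(centers_yx):
        raise ValueError("need exactly one (y, x) pair per cluster")
    new_clusters = []
    for cluster, (y, x) in zip(clusters, centers_yx):
        updated = dict(cluster)      # copy, so the original dict is untouched
        updated["yx"] = (y, x)
        new_clusters.append(updated)
    return new_clusters


# Stand-in data shaped like the dicts described above (values are made up).
old = [{"yx": (0.0, 0.0), "rgb": (100, 0, 0)},
       {"yx": (5.0, 5.0), "rgb": (50, 10, 10)}]
new = with_centroids(old, [(40, 20), (80, 70)])
```

With a real model, the result would then be assigned back in one step, for example slic.slic_model.clusters = with_centroids(slic.slic_model.clusters, centers), followed by slic.iterate(img, 1).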
However, there is one caveat to consider: fast-slic subsamples the image row by row. That is, when subsample_stride is 3, only one third of the image rows are used in each iteration, so a single iteration might not use all of the information in the image. If that matters to you, consider setting subsample_stride to 1 in the constructor of the Slic class, or increase the number of iterations to 3.
For more information on subsampling, please refer to my undergraduate paper uploaded in issue #7.
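To make the caveat concrete, here is a small illustration of how much of the image a given number of iterations touches. It assumes a round-robin schedule in which iteration i visits rows i % stride, i % stride + stride, and so on; that schedule is an assumption made for this sketch, and fast-slic's actual row order may differ.

```python
def rows_visited(height, stride, iteration):
    """Rows used in one iteration under the assumed round-robin schedule."""
    return list(range(iteration % stride, height, stride))


def coverage(height, stride, n_iterations):
    """Fraction of image rows touched after n_iterations."""
    seen = set()
    for it in range(n_iterations):
        seen.update(rows_visited(height, stride, it))
    return len(seen) / height


one_pass = coverage(300, 3, 1)       # a single iteration with stride 3
three_passes = coverage(300, 3, 3)   # three iterations with stride 3
```

Under this assumption a single pass with stride 3 sees one third of the rows, and three passes see all of them, which matches the two workarounds suggested above.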
If I understand correctly, these are the steps I need to follow:
Hi @vishwa91, I have the same problem and would like to ask whether you have solved it. I tried the solution mentioned by @Algy; however, the result is not what I expected. Please see the example below. I defined a few centroids (not evenly distributed), set subsample_stride=1, and ran only one iteration. The centroids I set are labelled in picture 1. The SLIC labels are in picture 2. You can see that the segmentation does not follow the defined centroids.
I also tried setting the number of iterations to 0, and the result is below.
I'm not sure whether the function still runs several iterations when we pass 1 or 0, or whether I did something wrong. @vishwa91, could you please let me know if you know what the problem is? Any insight is appreciated. My code is as follows.
from fast_slic import Slic
from fast_slic.avx2 import SlicAvx2
from skimage import segmentation, color
import matplotlib.pyplot as plt
import cv2
import imutils
image = cv2.imread("fish.jpeg")
image = imutils.resize(image, width=300)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
slic = Slic(num_components=12, compactness=10, subsample_stride=1)
n_clusters = len(slic.slic_model.clusters)
# use pre-defined centroids
centers_w = [20, 70, 120, 170, 220, 270] * 2
centers_h = [40] * 6 + [80] * 6
new_clusters = slic.slic_model.clusters.copy()
for k in range(0, len(new_clusters)):
    cluster = new_clusters[k]
    cluster['yx'] = (centers_h[k], centers_w[k])
    cluster['number'] = k
slic.slic_model.clusters = new_clusters
print('after center assignment', slic.slic_model.clusters)
# iterate only once
slic_result = slic.iterate(image, 1)
# plot results
fig, ax_arr = plt.subplots(1, 2)
ax1, ax2 = ax_arr.ravel()
# show centers in image
for cluster in new_clusters:
    label = cluster['number']
    center = cluster['yx']
    cv2.circle(image, (int(center[1]), int(center[0])), 3, [255, 0, 255], -1)
    cv2.putText(image, str(label), (int(center[1]), int(center[0])), cv2.FONT_HERSHEY_PLAIN, 2, [255, 0, 255], 2)
ax1.imshow(segmentation.mark_boundaries(image, slic_result))
ax2.imshow(slic_result)
plt.show()
@weiwenchuan I could not get @Algy's code to work with my own centroids, so I wrote the partial SLIC update code below:
import numpy as np
import cv2
import cassi_cp  # compiled Cython module, listed after this function

def slic_update(imrgb, mask, compactness=10.0):
    '''
    Performs one single step of SLIC to update membership
    Inputs:
        imrgb: RGB image
        mask: Sparse sampling mask / centroids of superpixels after
            sanitizing
        compactness: SLIC compactness parameter
    Outputs:
        L: superpixel membership map
        N: Total number of super pixels
    '''
    H, W, _ = imrgb.shape
    # Centroid rows and columns, in row-major order
    ch, cw = np.where(mask == 1)
    # Create LabXY image
    [Y, X] = np.mgrid[:H, :W]
    imlabxy = np.zeros((H, W, 5), dtype=np.float32)
    imlabxy[:, :, :3] = cv2.cvtColor(imrgb, cv2.COLOR_RGB2Lab)
    imlabxy[:, :, 3] = X
    imlabxy[:, :, 4] = Y
    # Reshape to a matrix
    imlabxymat = imlabxy.reshape(H*W, 5).astype(np.float32)
    centroids_labxy = imlabxy[ch, cw, :].astype(np.float32)
    N = ch.size
    # Expected superpixel spacing
    S = int(np.sqrt(H*W/N))
    nmembers = np.zeros(N)
    dist_matrix = np.ones((H, W), dtype=np.float32)*float('inf')
    L = np.ones((H, W), dtype=np.uint16)
    # Inefficient, but just do it
    for idx in range(N):
        # Search window of size 4S x 4S around the centroid, clipped to the image
        hmin = max(0, ch[idx] - 2*S); hmax = min(H, ch[idx] + 2*S)
        wmin = max(0, cw[idx] - 2*S); wmax = min(W, cw[idx] + 2*S)
        imlabxy_patch = imlabxy[hmin:hmax, wmin:wmax, :]
        dist_patch_old = dist_matrix[hmin:hmax, wmin:wmax]
        dist_patch = cassi_cp._get_dist_cp(centroids_labxy[idx, :], imlabxy_patch,
                                           np.float32(compactness), S)
        # Claim every pixel that is closer to this centroid than to any earlier one
        L_patch = L[hmin:hmax, wmin:wmax]
        L_patch[dist_patch < dist_patch_old] = idx
        L[hmin:hmax, wmin:wmax] = L_patch
        dist_matrix[hmin:hmax, wmin:wmax] = np.minimum(dist_patch,
                                                       dist_patch_old)
    return L.astype(np.uint16), N
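Since slic_update takes its centroids as a binary mask, a small helper for building that mask from coordinate lists may be useful. This is a sketch only: centroids_to_mask is a hypothetical name, and the (200, 300) image shape is an assumption made for the example. Duplicate coordinates collapse into a single mask entry.

```python
import numpy as np


def centroids_to_mask(shape_hw, centers_h, centers_w):
    """Binary mask with 1 at each (row, col) centroid (hypothetical helper)."""
    mask = np.zeros(shape_hw, dtype=np.uint8)
    mask[np.asarray(centers_h), np.asarray(centers_w)] = 1
    return mask


# Centroid grid from the earlier comment in this thread.
centers_w = [20, 70, 120, 170, 220, 270] * 2
centers_h = [40] * 6 + [80] * 6
mask = centroids_to_mask((200, 300), centers_h, centers_w)

# slic_update reads the centroids back with np.where, i.e. in row-major order,
# so label idx follows this ordering rather than the order of the input lists.
ch, cw = np.where(mask == 1)
```

The mask can then be passed as the second argument of slic_update. If the input lists are not already sorted row by row, the returned labels will be numbered differently from the list positions.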
The function cassi_cp._get_dist_cp() was written in Cython:
import numpy as np
import cv2
from cython.parallel import parallel, prange
# Compile time optimizations
cimport numpy as np
cimport cython
# We will mostly use UINT8
DTYPE_UINT8 = np.uint8
DTYPE_UINT16 = np.uint16
DTYPE_FLOAT32 = np.float32
DTYPE_INT16 = np.int16
ctypedef np.uint8_t DTYPE_UINT8_t
ctypedef np.uint16_t DTYPE_UINT16_t
ctypedef np.float32_t DTYPE_FLOAT32_t
ctypedef np.int16_t DTYPE_INT16_t

@cython.boundscheck(False)
@cython.wraparound(False)
def _get_dist_cp(np.ndarray[DTYPE_FLOAT32_t, ndim=1] centroid_xy,
                 np.ndarray[DTYPE_FLOAT32_t, ndim=3] imlabxy_patch,
                 float compactness, int S):
    '''
    Function to rapidly compute distance from a centroid over a patch
    '''
    # Declare all variables ahead
    cdef int H
    cdef int W
    cdef int h
    cdef int w
    cdef float C
    H = imlabxy_patch.shape[0]
    W = imlabxy_patch.shape[1]
    C = compactness/S
    # Create new matrix to store distances
    dist = np.zeros((H, W), dtype=DTYPE_FLOAT32)
    # Creating a data view will make all operations much faster
    cdef DTYPE_FLOAT32_t[:, :] dist_view = dist
    # Now run through all variables
    for h in prange(H, nogil=True):
        for w in range(W):
            # Unroll the whole computation
            dist_view[h, w] = ((imlabxy_patch[h, w, 0] - centroid_xy[0])**2 +
                               (imlabxy_patch[h, w, 1] - centroid_xy[1])**2 +
                               (imlabxy_patch[h, w, 2] - centroid_xy[2])**2 +
                               C*(imlabxy_patch[h, w, 3] - centroid_xy[3])**2 +
                               C*(imlabxy_patch[h, w, 4] - centroid_xy[4])**2)
    return dist
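For anyone who cannot compile the extension, the same distance can be computed with plain NumPy broadcasting. This is a sketch rather than a drop-in replacement: get_dist_np is a hypothetical name, it mirrors the formula in the Cython loop (squared Lab distance plus C = compactness/S times the squared XY distance), and it has not been benchmarked against the compiled version.

```python
import numpy as np


def get_dist_np(centroid_labxy, imlabxy_patch, compactness, S):
    """NumPy counterpart of _get_dist_cp for an (H, W, 5) LabXY patch."""
    C = np.float32(compactness) / np.float32(S)
    diff2 = (imlabxy_patch - centroid_labxy) ** 2   # broadcasts over (H, W, 5)
    color = diff2[..., :3].sum(axis=-1)             # L, a, b channels
    spatial = diff2[..., 3:].sum(axis=-1)           # x, y channels
    return (color + C * spatial).astype(np.float32)


# Tiny worked example: one pixel with Lab = (1, 2, 2), x = 3, y = 4.
patch = np.zeros((2, 2, 5), dtype=np.float32)
patch[1, 1] = [1, 2, 2, 3, 4]
centroid = np.zeros(5, dtype=np.float32)
dist = get_dist_np(centroid, patch, compactness=10.0, S=5)
```

Here C = 10 / 5 = 2, so the distance at pixel (1, 1) is (1 + 4 + 4) + 2 * (9 + 16) = 59, and every other pixel is at distance 0.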
The relevant setup.py file for compiling:
# Compilation tools
# (distutils was removed in Python 3.12; setuptools provides the same names)
from setuptools import Extension, setup
from Cython.Build import cythonize
# Scientific computing
import numpy as np
ext_modules = [
    Extension(
        "cassi_cp",
        ["cassi_cp.pyx"],
        # OpenMP is needed for prange; these flags assume a GCC-style compiler
        extra_compile_args=['-fopenmp', '-march=native', '-O3', '-ffast-math'],
        extra_link_args=['-fopenmp'],
        include_dirs=[np.get_include()]
    )
]
setup(
    name='cassi_cp',
    ext_modules=cythonize(ext_modules)
)
Hope that helps!
Hi @vishwa91, thanks for the prompt reply! So do you mean you wrote your own code and didn't use the fast-slic method? I haven't used Cython before, but I'll definitely try your code. I also tried to write my own code (in Python) for a single step of SLIC, but it runs too slowly, which is why I tried fast-slic. Did you test the efficiency (running time) of your own update function?
@weiwenchuan, yes, I used my own code. To compile the Cython code, you may need to run:
python setup.py build_ext --inplace
Regarding efficiency: no, I did not profile my code, but I am hoping it's fairly fast, as the costliest step is done in Cython (practically C).
Hope that helps!
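Since the code was not profiled, a simple wall-clock comparison is one way to check. The sketch below uses only the standard library; best_time is a hypothetical helper, and the call on sum is a stand-in so the snippet runs without the image or the compiled module.

```python
import time


def best_time(fn, *args, repeats=5, **kwargs):
    """Smallest wall-clock time in seconds over several calls (hypothetical helper)."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in workload; replace with the functions being compared.
elapsed = best_time(sum, range(100_000))
```

In this thread the two calls to compare would be best_time(slic_update, imrgb, mask) and best_time(slic.iterate, image, 1), run on the same image.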
Hi @vishwa91, thank you. I just tried it and it works (the running time is higher than fast-slic's, but the segmentation result is really clear). Thank you for sharing your code!
Would it be possible to get a function handle that performs one single step of SLIC update?
Specifically: Given the RGB (or LAB) image, and the centroids as input, the function should output the label map. This will be very useful if the centroids are constrained to be in a certain structure.
Thank you!