Algy / fast-slic

20x Real-time superpixel SLIC Implementation with CPU
MIT License

Function for single step of SLIC #8

Open vishwa91 opened 4 years ago

vishwa91 commented 4 years ago

Would it be possible to get a function handle that performs one single step of SLIC update?

Specifically, given the RGB (or LAB) image and the centroids as input, the function should output the label map. This would be very useful when the centroids are constrained to follow a certain structure.

Thank you!

Algy commented 4 years ago

You can modify the centroid coordinates by assigning a new list to slic.slic_model.clusters, where slic is an instance of the Slic or SlicAvx2 class. The value is a list of dicts, and each dict stores the centroid position under the key 'yx'. Note that the value under the key 'rgb' is actually the LAB color tuple multiplied by 2. Also, you should assign a new list to the property rather than modifying the existing value in place.
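A minimal sketch of the "assign a new list, don't modify in place" rule, assuming slic_model.clusters is a list of dicts as described above. The helper name with_new_centroids is hypothetical, not part of fast-slic; it copies each dict, replaces only 'yx', and leaves 'rgb' (LAB x 2) and any other keys untouched.

```python
from typing import Any, Dict, List, Sequence, Tuple


def with_new_centroids(
    clusters: Sequence[Dict[str, Any]],
    centroids_yx: Sequence[Tuple[float, float]],
) -> List[Dict[str, Any]]:
    """Return a NEW list of NEW dicts with each cluster's 'yx' replaced.

    The input list and its dicts are left unmodified, so the result can be
    assigned to the clusters property as a fresh object.
    """
    if len(centroids_yx) != len(clusters):
        raise ValueError("need exactly one (y, x) pair per existing cluster")
    new_clusters = []
    for cluster, (y, x) in zip(clusters, centroids_yx):
        updated = dict(cluster)      # shallow copy of this cluster's dict
        updated["yx"] = (y, x)       # position only; 'rgb' etc. are kept as-is
        new_clusters.append(updated)
    return new_clusters
```

Hypothetical usage: `slic.slic_model.clusters = with_new_centroids(slic.slic_model.clusters, my_yx_list)`.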

Algy commented 4 years ago

You can iterate only once with slic.iterate(img, 1). However, there is one caveat: fast-slic subsamples the image on a per-row basis. That is, when subsample_stride is 3, only one third of the image rows are used in each iteration, so a single iteration may not use all of the information in the image. If that matters to you, consider setting subsample_stride to 1 in the Slic constructor, or increase the number of iterations to 3. For more information on subsampling, please refer to my undergraduate paper uploaded in issue #7.
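To make the row-subsampling arithmetic concrete, the sketch below assumes that iteration i visits the rows whose index is congruent to i modulo subsample_stride. That rotating-offset scheme is an assumption for illustration only; fast-slic's actual row selection may differ. Under it, one iteration with stride 3 touches about a third of the rows, and three iterations cover every row.

```python
def sampled_rows(height: int, stride: int, iteration: int) -> list:
    """Rows visited in one iteration under the ASSUMED rotating-offset scheme."""
    offset = iteration % stride
    return list(range(offset, height, stride))


def rows_covered(height: int, stride: int, n_iter: int) -> set:
    """Union of rows visited over n_iter iterations (same assumption)."""
    covered = set()
    for i in range(n_iter):
        covered.update(sampled_rows(height, stride, i))
    return covered


print(sorted(rows_covered(10, 3, 1)))            # [0, 3, 6, 9]
print(rows_covered(10, 3, 3) == set(range(10)))  # True
```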

vishwa91 commented 4 years ago

If I understand correctly, these are the steps I need to follow:

  1. Assign the new centroids to slic.slic_model.clusters (as a new list, not modified in place).
  2. Iterate three times (if subsample_stride == 3) with the same slic model as above.

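The two steps above can be sketched as a small helper. This is a hedged sketch: single_slic_step is a hypothetical name, and it relies only on the slic_model.clusters property and the iterate(image, n) call described in this thread. As the next comment reports, this approach did not reproduce user-defined centroids in at least one attempt.

```python
from typing import Any, Sequence, Tuple


def single_slic_step(slic: Any, image: Any,
                     centroids_yx: Sequence[Tuple[float, float]],
                     subsample_stride: int = 3):
    """Hypothetical helper: set centroids, then run one pass per row phase."""
    # Step 1: assign a brand-new list of brand-new dicts (no in-place edits).
    slic.slic_model.clusters = [
        dict(cluster, yx=tuple(yx))
        for cluster, yx in zip(slic.slic_model.clusters, centroids_yx)
    ]
    # Step 2: iterate subsample_stride times so every image row is used once.
    return slic.iterate(image, subsample_stride)
```

With fast-slic installed, `single_slic_step(Slic(num_components=12), img, yx_list)` would be the intended call, assuming the number of centroids matches the number of existing clusters.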
weiwenchuan commented 3 years ago

Hi @vishwa91, I have the same problem and would like to ask whether you have solved it. I tried the solution suggested by @Algy, but the result is not what I expected; please see the example below. I defined a few centroids (not evenly distributed), set subsample_stride=1, and ran only one iteration. The centroids I set are labelled in picture 1, and the SLIC labels are shown in picture 2. As you can see, the segmentation does not follow the defined centroids.

[Screenshot 2021-09-18, 1:26:11 PM]

I also tried setting the number of iterations to 0; the result is below.

[Screenshot 2021-09-18, 1:26:48 PM]

I'm not sure whether the function still runs several iterations when the count is set to 1 or 0, or whether I did something wrong. @vishwa91, could you please let me know if you know what the problem is? Any insight is appreciated. My code is as follows:

from fast_slic import Slic
from fast_slic.avx2 import SlicAvx2
from skimage import segmentation, color
import matplotlib.pyplot as plt
import cv2
import imutils

image = cv2.imread("fish.jpeg")
image = imutils.resize(image, width=300)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

slic = Slic(num_components=12, compactness=10, subsample_stride=1)
n_clusters = len(slic.slic_model.clusters)

# use pre-defined centroids on a 2 x 6 grid (rows y=40 and y=80)
centers_w = [20, 70, 120, 170, 220, 270] * 2
centers_h = [40] * 6 + [80] * 6

# build a new list of new dicts instead of mutating the model's clusters in place
new_clusters = []
for k, cluster in enumerate(slic.slic_model.clusters):
    cluster = dict(cluster)
    cluster['yx'] = (centers_h[k], centers_w[k])
    cluster['number'] = k
    new_clusters.append(cluster)
slic.slic_model.clusters = new_clusters
print('after center assignment', slic.slic_model.clusters)

# iterate only once
slic_result = slic.iterate(image, 1)

# plot results
fig, ax_arr = plt.subplots(1, 2)
ax1, ax2 = ax_arr.ravel()

# show centers in image
for cluster in new_clusters:
    label = cluster['number']
    center = cluster['yx']
    cv2.circle(image, (int(center[1]), int(center[0])), 3, [255, 0, 255], -1)
    cv2.putText(image, str(label), (int(center[1]), int(center[0])), cv2.FONT_HERSHEY_PLAIN, 2, [255, 0, 255], 2)
ax1.imshow(segmentation.mark_boundaries(image, slic_result))

ax2.imshow(slic_result)
plt.show()
vishwa91 commented 3 years ago

@weiwenchuan I could not get @Algy's approach to work with my own centroids, so I wrote a partial SLIC update function:

import numpy as np
import cv2

import cassi_cp  # compiled Cython extension (cassi_cp.pyx + setup.py below)


def slic_update(imrgb, mask, compactness=10.0):
    '''
        Performs one single step of SLIC to update membership

        Inputs:
            imrgb: RGB image (H x W x 3, uint8)
            mask: Binary mask with 1 at each superpixel centroid (the
                sparse sampling pattern after sanitizing)
            compactness: SLIC compactness parameter

        Outputs:
            L: superpixel membership map
            N: Total number of superpixels
    '''

    H, W, _ = imrgb.shape
    ch, cw = np.where(mask == 1)

    # Create LabXY image: 3 color channels followed by x and y coordinates
    [Y, X] = np.mgrid[:H, :W]

    imlabxy = np.zeros((H, W, 5), dtype=np.float32)

    imlabxy[:, :, :3] = cv2.cvtColor(imrgb, cv2.COLOR_RGB2Lab)
    imlabxy[:, :, 3] = X
    imlabxy[:, :, 4] = Y

    # Each centroid's LabXY feature is read from the pixel at its location
    centroids_labxy = imlabxy[ch, cw, :].astype(np.float32)
    N = ch.size
    # Grid interval S; each centroid only competes for pixels within +/- 2S
    S = max(1, int(np.sqrt(H*W/N)))

    dist_matrix = np.full((H, W), np.inf, dtype=np.float32)
    L = np.ones((H, W), dtype=np.uint16)

    # Inefficient Python loop over centroids; per-pixel work is in Cython
    for idx in range(N):
        hmin = max(0, ch[idx] - 2*S); hmax = min(H, ch[idx] + 2*S)
        wmin = max(0, cw[idx] - 2*S); wmax = min(W, cw[idx] + 2*S)

        imlabxy_patch = imlabxy[hmin:hmax, wmin:wmax, :]
        dist_patch_old = dist_matrix[hmin:hmax, wmin:wmax]
        dist_patch = cassi_cp._get_dist_cp(centroids_labxy[idx, :], imlabxy_patch,
                                           np.float32(compactness), S)

        # L_patch is a view into L: relabel pixels closer to this centroid
        L_patch = L[hmin:hmax, wmin:wmax]
        L_patch[dist_patch < dist_patch_old] = idx

        dist_matrix[hmin:hmax, wmin:wmax] = np.minimum(dist_patch,
                                                       dist_patch_old)

    return L, N

The function cassi_cp._get_dist_cp() was written in Cython:

import numpy as np
from cython.parallel import prange

# Compile time optimizations
cimport numpy as np
cimport cython

# NumPy dtype aliases (the distance computation itself uses float32)
DTYPE_UINT8 = np.uint8
DTYPE_UINT16 = np.uint16
DTYPE_FLOAT32 = np.float32
DTYPE_INT16 = np.int16

ctypedef np.uint8_t DTYPE_UINT8_t
ctypedef np.uint16_t DTYPE_UINT16_t
ctypedef np.float32_t DTYPE_FLOAT32_t
ctypedef np.int16_t DTYPE_INT16_t

@cython.boundscheck(False)
@cython.wraparound(False)
def _get_dist_cp(np.ndarray[DTYPE_FLOAT32_t, ndim=1] centroid_xy,
                 np.ndarray[DTYPE_FLOAT32_t, ndim=3] imlabxy_patch,
                 float compactness, int S):
    '''
        Function to rapidly compute distance from a centroid over a patch
    '''
    # Declare all variables ahead
    cdef int H
    cdef int W
    cdef int h
    cdef int w
    cdef float C

    H = imlabxy_patch.shape[0]
    W = imlabxy_patch.shape[1]
    C = (compactness/S)**2  # (m/S)^2: spatial weight in the squared SLIC distance

    # Create new matrix to store distances
    dist = np.zeros((H, W), dtype=DTYPE_FLOAT32)

    # Creating a data view will make all operations much faster
    cdef DTYPE_FLOAT32_t[:, :] dist_view = dist

    # Now run through all variables
    for h in prange(H, nogil=True):
        for w in range(W):
            # Unroll the whole computation
            dist_view[h, w] = ((imlabxy_patch[h, w, 0] - centroid_xy[0])**2 + \
                               (imlabxy_patch[h, w, 1] - centroid_xy[1])**2 + \
                               (imlabxy_patch[h, w, 2] - centroid_xy[2])**2 + \
                             C*(imlabxy_patch[h, w, 3] - centroid_xy[3])**2 + \
                             C*(imlabxy_patch[h, w, 4] - centroid_xy[4])**2)

    return dist
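For reference, the squared distance between a pixel i and a centroid j in standard SLIC combines a color term and a spatial term, where m is the compactness and S the grid interval (the compactness and S arguments of _get_dist_cp):

```latex
D_{ij}^2 = \underbrace{(l_i - l_j)^2 + (a_i - a_j)^2 + (b_i - b_j)^2}_{d_c^2}
         + \left(\frac{m}{S}\right)^2 \underbrace{\left[(x_i - x_j)^2 + (y_i - y_j)^2\right]}_{d_s^2}
```

Because the assignment only needs the ordering of distances, comparing D^2 instead of D yields the same labels and avoids the square root.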

The relevant setup.py file for compiling:

# Compilation tools (setuptools; distutils was removed in Python 3.12)
from setuptools import Extension, setup
from Cython.Build import cythonize

# Scientific computing
import numpy as np

ext_modules = [
    Extension(
        "cassi_cp",
        ["cassi_cp.pyx"],
        extra_compile_args=['-fopenmp', '-march=native', '-O3', '-ffast-math'],
        extra_link_args=['-fopenmp'],
        include_dirs=[np.get_include()]
    )
]

setup(
    name='cassi_cp',
    ext_modules=cythonize(ext_modules)
)

Hope that helps!
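For readers who would rather not compile a Cython extension, the same single assignment step can be sketched in NumPy alone. This is a hedged sketch, not vishwa91's code: slic_assign_numpy is a hypothetical name, it takes an already-converted color image (e.g. CIELAB) instead of calling OpenCV, and it uses the (m/S)^2 spatial weight of the standard SLIC distance. It vectorizes only within each centroid's window, so its speed relative to the Cython version is untested.

```python
import numpy as np


def slic_assign_numpy(features, centroids_yx, compactness=10.0):
    """One SLIC assignment step in NumPy only (no Cython).

    features: (H, W, C) array of per-pixel color values, e.g. CIELAB.
    centroids_yx: (N, 2) array of integer (row, col) centroid positions.
    Returns an (H, W) int32 label map; pixels outside every centroid's
    +/- 2S search window keep the label -1.
    """
    features = np.asarray(features, dtype=np.float32)
    centroids_yx = np.asarray(centroids_yx, dtype=np.int64).reshape(-1, 2)
    H, W = features.shape[:2]
    N = len(centroids_yx)
    if N == 0:
        raise ValueError("need at least one centroid")
    S = max(1, int(np.sqrt(H * W / N)))      # grid interval, as in slic_update
    spatial_weight = (compactness / S) ** 2  # (m/S)^2 from the SLIC distance

    best = np.full((H, W), np.inf, dtype=np.float32)
    labels = np.full((H, W), -1, dtype=np.int32)

    for idx, (cy, cx) in enumerate(centroids_yx):
        h0, h1 = max(0, cy - 2 * S), min(H, cy + 2 * S)
        w0, w1 = max(0, cx - 2 * S), min(W, cx + 2 * S)
        patch = features[h0:h1, w0:w1]
        # squared color distance to the color sampled at the centroid pixel
        d_color = np.sum((patch - features[cy, cx]) ** 2, axis=-1)
        yy, xx = np.mgrid[h0:h1, w0:w1]
        d_space = (yy - cy) ** 2 + (xx - cx) ** 2
        dist = (d_color + spatial_weight * d_space).astype(np.float32)

        # views into the full arrays, so the writes below update them in place
        window_best = best[h0:h1, w0:w1]
        window_labels = labels[h0:h1, w0:w1]
        closer = dist < window_best
        window_best[closer] = dist[closer]
        window_labels[closer] = idx
    return labels
```

With OpenCV available, `slic_assign_numpy(cv2.cvtColor(imrgb, cv2.COLOR_RGB2Lab), np.argwhere(mask == 1))` mirrors the inputs of slic_update; np.argwhere returns (row, col) pairs in the same order as np.where.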

weiwenchuan commented 3 years ago

Hi @vishwa91, thanks for the prompt reply! So you wrote your own code instead of using fast-slic? I haven't used Cython before, but I'll definitely try your code. I also tried writing my own single-step SLIC code in plain Python, but it ran too slowly, which is why I tried fast-slic. Did you measure the running time of your update function?

vishwa91 commented 3 years ago

@weiwenchuan -- yes, I used my own code. To compile the Cython code, you may need to run:

python setup.py build_ext --inplace

Regarding efficiency -- no, I did not profile my code, but I expect it to be fairly fast, since the costliest step is done in Cython (practically C).

Hope that helps!

weiwenchuan commented 3 years ago

Hi @vishwa91, thank you. I just tried it and it works (the running time is higher than fast-slic's, but the segmentation result is really clear). Thank you for sharing your code!