yngvem / group-lasso

Group Lasso implementation following the scikit-learn API
MIT License
105 stars 32 forks source link

How to do stain unmixing with group lasso #31

Closed SikangSHU closed 2 years ago

SikangSHU commented 2 years ago

Hello! Do you know how to do stain unmixing for RGB images when there are more than three stains, in the way introduced by the paper "Group sparsity model for stain unmixing in brightfield multiplex immunohistochemistry images" (PDF attached)? The following code performs stain unmixing for four stains, but the results are not good.

import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from numpy import linalg
from skimage.color.colorconv import _prepare_colorarray
from sklearn.metrics import mean_squared_error

# Load the image and move it into log space for unmixing.
# _prepare_colorarray is a private skimage helper; presumably it validates the
# input and returns a float copy of the RGB array (force_copy=True) — TODO confirm.
img = io.imread('E:\\test1.png')
img1 = _prepare_colorarray(img, force_copy=True)
# Clamp in place to a floor of 1E-6 so that np.log below never sees zero.
np.clip(img1, 1E-6, None, out=img1)
y = np.log(img1)

# Candidate stain bases. Each row is a colour vector in the same (log) space as y;
# the trailing comment names the stains the rows stand for. For X3, X4 and X5 the
# third row appears to be the normalised cross product of the first two rows, i.e. a
# residual direction that completes the basis rather than a stain.
X1 = np.array([[0.571, 0.584, 0.577], [0.095, 0.258, 0.961], [0.767, 0.576, 0.284]])  # CD8,PanCK,Hema
X2 = np.array([[0.095, 0.258, 0.961], [0.105, 0.758, 0.644], [0.767, 0.576, 0.284]])  # PanCK,PD-L1,Hema
X3 = np.array([[0.571, 0.584, 0.577], [0.767, 0.576, 0.284], [-0.48, 0.808, -0.343]])  # CD8,Hema
X4 = np.array([[0.095, 0.258, 0.961], [0.767, 0.576, 0.284], [-0.553, 0.817, -0.165]])  # PanCK,Hema
X5 = np.array([[0.105, 0.758, 0.644], [0.767, 0.576, 0.284], [-0.218, 0.649, -0.729]])  # PDL1,Hema

_stain_bases = (X1, X2, X3, X4, X5)

# For every basis: invert it, solve for the per-pixel stain coefficients
# (B_k = y @ inv(X_k)), and map the coefficients back to get the reconstruction y_k.
X1_inv, X2_inv, X3_inv, X4_inv, X5_inv = [linalg.inv(basis) for basis in _stain_bases]
B1, B2, B3, B4, B5 = [y @ inverse for inverse in (X1_inv, X2_inv, X3_inv, X4_inv, X5_inv)]
y1, y2, y3, y4, y5 = [
    coeffs @ basis for coeffs, basis in zip((B1, B2, B3, B4, B5), _stain_bases)
]

# NOTE(review): every basis is square and invertible, so each y_k reproduces y up to
# floating-point rounding; any later comparison of y against y_k only sees noise.

# Per-pixel model selection and per-stain rendering, vectorised over the image.
#
# For each pixel, pick the candidate model whose reconstruction y_k is closest to
# the observed log-image y. Closeness is the mean squared error over the colour
# channels, the same quantity sklearn's mean_squared_error gives for one pixel.
# Each stain of the winning model is then rendered into its own RGB image: the
# stain's contribution is its coefficient times its stain vector, mapped back
# out of log space with exp. Pixels a stain image never receives stay white (1).
#
# np.argmin picks the lowest-numbered model on exact ties, so a tie no longer
# silently falls through to model 5 when model 5 is not among the best.
#
# NOTE(review): every X_k is a square, invertible 3x3 matrix, so y_k equals y up to
# rounding and these errors are essentially floating-point noise. The model choice
# is therefore not meaningful. Unmixing more than three stains from RGB needs a
# sparsity prior (e.g. the group-sparsity model discussed in this issue) rather
# than an exact-fit comparison.
rgb_CD8 = np.ones_like(y)
rgb_PanCK = np.ones_like(y)
rgb_Hema = np.ones_like(y)
rgb_PDL1 = np.ones_like(y)

# Each entry is (stain matrix, coefficients, reconstruction, output image per
# rendered stain row). For X3/X4/X5 only the first two rows are rendered; their
# third coefficient is never used.
models = (
    (X1, B1, y1, (rgb_CD8, rgb_PanCK, rgb_Hema)),
    (X2, B2, y2, (rgb_PanCK, rgb_PDL1, rgb_Hema)),
    (X3, B3, y3, (rgb_CD8, rgb_Hema)),
    (X4, B4, y4, (rgb_PanCK, rgb_Hema)),
    (X5, B5, y5, (rgb_PDL1, rgb_Hema)),
)

# errors[k, i, j] = mean over channels of (y - y_k)^2 at pixel (i, j).
errors = np.stack([np.mean((y - recon) ** 2, axis=-1) for _, _, recon, _ in models])
best_model = np.argmin(errors, axis=0)

for model_index, (stain_matrix, coeffs, _, targets) in enumerate(models):
    selected = best_model == model_index  # boolean mask over pixels
    picked = coeffs[selected]  # coefficients of the selected pixels, one row each
    for stain_index, target in enumerate(targets):
        # coefficient * stain vector == [0, .., c, .., 0] @ stain_matrix
        target[selected] = np.exp(picked[:, stain_index, None] * stain_matrix[stain_index])

# Show the original image and the four per-stain reconstructions in a 3x2 grid.
# Panel 1 is deliberately left empty.
fig, axes = plt.subplots(3, 2, figsize=(8, 7), sharex=True, sharey=True)
ax = axes.ravel()

ax[0].imshow(img)
ax[0].set_title("Original image")

panels = (
    (2, rgb_Hema, "Hematoxylin"),
    (3, rgb_PanCK, "PanCK"),
    (4, rgb_PDL1, "PDL1"),
    (5, rgb_CD8, "CD8"),
)
for slot, stain_rgb, label in panels:
    # Clamp to the displayable [0, 1] range in place before drawing.
    np.clip(stain_rgb, 0, 1, out=stain_rgb)
    ax[slot].imshow(stain_rgb)
    ax[slot].set_title(label)

for panel_ax in ax:
    panel_ax.axis('off')

fig.tight_layout()
plt.show()
SikangSHU commented 2 years ago

More specifically, the data is three-dimensional: each pixel has only three values (R, G, B), which is fewer than the number of stains.

yngvem commented 2 years ago

This is not really an issue with this codebase, so I'm closing the issue.