tdmorello closed this 2 years ago.
I agree; closely related code appears here, but I think it needs some refactoring: https://github.com/peng-lab/PyBaSiC/blob/01d8fe1ae86f2a09ced2cd710bd3e043036deea0/pybasic/_background.py#L78
It seems that the original algorithm calls SVD only once.
Tim is correct here. I'll profile the code to see how much this slows it down.
Based on the profiler output, I don't think the SVD is much of a concern: it takes only 0.6% of the total time. Most of the time goes into full-size array arithmetic, with the _shrinkageOperator call alone at 30%, so we should look at speeding up these basic operations with in-place or accelerated functions where possible. This also suggests we should get a reasonable speed boost when moving to GPU.
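A minimal sketch of the in-place idea, assuming the residual-style lines in the profile (the temp_W, E1_hat, and Z1 updates) can be rewritten to write into buffers allocated once before the loop. The function and buffer names are hypothetical. Each NumPy ufunc call below writes into its out argument, so the loop body stops allocating new full-size temporaries on every iteration.

```python
import numpy as np


def residual_inplace(images, a_hat, e_hat, y, mu, out, scratch):
    """Compute images - a_hat - e_hat + y / mu without new full-size arrays.

    out and scratch are preallocated arrays with the same shape as images.
    y may be a scalar (the solver starts with Y1 = 0) or an array.
    """
    np.subtract(images, a_hat, out=out)    # out = images - a_hat
    np.subtract(out, e_hat, out=out)       # out = images - a_hat - e_hat
    np.multiply(y, 1.0 / mu, out=scratch)  # scratch = y / mu (a scalar y is broadcast)
    np.add(out, scratch, out=out)          # out += y / mu
    return out


# Usage: allocate the buffers once, outside the iteration loop, and reuse them.
rng = np.random.default_rng(0)
images = rng.random((64 * 64, 10))  # hypothetical (m, n) image matrix
a_hat = rng.random(images.shape)
e_hat = rng.random(images.shape)
out = np.empty_like(images)
scratch = np.empty_like(images)
z = residual_inplace(images, a_hat, e_hat, 0, 1.5, out, scratch)
```

How much this helps depends on array sizes and memory bandwidth, so it is worth re-running the profiler after a change like this.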
Here is the output from the profiler:
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    35                                           @profile
    36                                           def inexact_alm_rspca_l1(
    37                                               images: np.ndarray,
    38                                               lambda_darkfield: float,
    39                                               lambda_flatfield: float,
    40                                               get_darkfield: bool,
    41                                               optimization_tol: float,
    42                                               max_iterations: int,
    43                                               weight: np.ndarray = None,
    44                                           ):
    45
    46         2         10.0      5.0      0.0      p = images.shape[0]
    47         2          5.0      2.5      0.0      q = images.shape[1]
    48         2          6.0      3.0      0.0      m = p * q
    49         2          4.0      2.0      0.0      n = images.shape[2]
    50         2      12386.0   6193.0      0.2      images = np.reshape(images, (m, n), order="F")
    51
    52         2          7.0      3.5      0.0      if weight is not None:
    53         2       6079.0   3039.5      0.1          weight = np.reshape(weight, (m, n), order="F")
    54                                               else:
    55                                                   weight = np.ones_like(images)
    56         4      39594.0   9898.5      0.6      svd = np.linalg.svd(
    57         2          4.0      2.0      0.0          images, False, False
    58                                               ) # TODO: Is there a more efficient implementation of SVD?
    59         2         20.0     10.0      0.0      norm_two = svd[0]
    60         2          6.0      3.0      0.0      Y1 = 0
    61         2          5.0      2.5      0.0      ent1 = 1
    62         2          5.0      2.5      0.0      ent2 = 10
    63
    64         2       1445.0    722.5      0.0      A1_hat = np.zeros_like(images)
    65         2         65.0     32.5      0.0      A1_coeff = np.ones((1, images.shape[1]))
    66
    67         2       3376.0   1688.0      0.1      E1_hat = np.zeros_like(images)
    68         2       1578.0    789.0      0.0      W_hat = dct2d(np.zeros((p, q)).T)
    69         2         10.0      5.0      0.0      mu = 12.5 / norm_two
    70         2         10.0      5.0      0.0      mu_bar = mu * 1e7
    71         2          6.0      3.0      0.0      rho = 1.5
    72         2        983.0    491.5      0.0      d_norm = np.linalg.norm(images, ord="fro")
    73
    74         2         31.0     15.5      0.0      A_offset = np.zeros((m, 1))
    75         2       1715.0    857.5      0.0      B1_uplimit = np.min(images)
    76         2          7.0      3.5      0.0      B1_offset = 0
    77
    78         2         28.0     14.0      0.0      A_inmask = np.zeros((p, q))
    79         4         37.0      9.2      0.0      A_inmask[
    80         4        209.0     52.2      0.0          int(np.round(p / 6) - 1) : int(np.round(p * 5 / 6)),
    81         2        104.0     52.0      0.0          int(np.round(q / 6) - 1) : int(np.round(q * 5 / 6)),
    82         2          6.0      3.0      0.0      ] = 1
    83
    84                                               # main iteration loop starts
    85         2          6.0      3.0      0.0      iter = 0
    86         2          5.0      2.5      0.0      converged = False
    87
    88        72        220.0      3.1      0.0      while not converged:
    89        70        261.0      3.7      0.0          iter += 1
    90
    91        70        429.0      6.1      0.0          if len(A1_coeff.shape) == 1:
    92        68       4378.0     64.4      0.1              A1_coeff = np.expand_dims(A1_coeff, 0)
    93        70        299.0      4.3      0.0          if len(A_offset.shape) == 1:
    94                                                       A_offset = np.expand_dims(A_offset, 1)
    95        70      44338.0    633.4      0.7          W_idct_hat = idct2d(W_hat.T)
    96        70     560000.0   8000.0      8.4          A1_hat = np.dot(np.reshape(W_idct_hat, (-1, 1), order="F"), A1_coeff) + A_offset
    97
    98        70     844410.0  12063.0     12.6          temp_W = (images - A1_hat - E1_hat + (1 / mu) * Y1) / ent1
    99        70       2955.0     42.2      0.0          temp_W = np.reshape(temp_W, (p, q, n), order="F")
   100        70      88308.0   1261.5      1.3          temp_W = np.mean(temp_W, axis=2)
   101        70      46870.0    669.6      0.7          W_hat = W_hat + dct2d(temp_W.T)
   102       140      17213.0    123.0      0.3          W_hat = np.maximum(W_hat - lambda_flatfield / (ent1 * mu), 0) + np.minimum(
   103        70       1783.0     25.5      0.0              W_hat + lambda_flatfield / (ent1 * mu), 0
   104                                                   )
   105        70      53415.0    763.1      0.8          W_idct_hat = idct2d(W_hat.T)
   106        70        367.0      5.2      0.0          if len(A1_coeff.shape) == 1:
   107                                                       A1_coeff = np.expand_dims(A1_coeff, 0)
   108        70        262.0      3.7      0.0          if len(A_offset.shape) == 1:
   109                                                       A_offset = np.expand_dims(A_offset, 1)
   110        70     600801.0   8582.9      9.0          A1_hat = np.dot(np.reshape(W_idct_hat, (-1, 1), order="F"), A1_coeff) + A_offset
   111        70     754335.0  10776.2     11.3          E1_hat = images - A1_hat + (1 / mu) * Y1 / ent1
   112        70    2003176.0  28616.8     30.0          E1_hat = _shrinkageOperator(E1_hat, weight / (ent1 * mu))
   113        70     395243.0   5646.3      5.9          R1 = images - E1_hat
   114        70     194674.0   2781.1      2.9          A1_coeff = np.mean(R1, 0) / np.mean(R1)
   115        70       1259.0     18.0      0.0          A1_coeff[A1_coeff < 0] = 0
   116
   117        70        236.0      3.4      0.0          if get_darkfield:
   118                                                       validA1coeff_idx = np.where(A1_coeff < 1)
   119
   120                                                       B1_coeff = (
   121                                                           np.mean(
   122                                                               R1[
   123                                                                   np.reshape(W_idct_hat, -1, order="F")
   124                                                                   > np.mean(W_idct_hat) - 1e-6
   125                                                               ][:, validA1coeff_idx[0]],
   126                                                               0,
   127                                                           )
   128                                                           - np.mean(
   129                                                               R1[
   130                                                                   np.reshape(W_idct_hat, -1, order="F")
   131                                                                   < np.mean(W_idct_hat) + 1e-6
   132                                                               ][:, validA1coeff_idx[0]],
   133                                                               0,
   134                                                           )
   135                                                       ) / np.mean(R1)
   136                                                       k = np.array(validA1coeff_idx).shape[1]
   137                                                       temp1 = np.sum(A1_coeff[validA1coeff_idx[0]] ** 2)
   138                                                       temp2 = np.sum(A1_coeff[validA1coeff_idx[0]])
   139                                                       temp3 = np.sum(B1_coeff)
   140                                                       temp4 = np.sum(A1_coeff[validA1coeff_idx[0]] * B1_coeff)
   141                                                       temp5 = temp2 * temp3 - temp4 * k
   142                                                       if temp5 == 0:
   143                                                           B1_offset = 0
   144                                                       else:
   145                                                           B1_offset = (temp1 * temp3 - temp2 * temp4) / temp5
   146                                                       # limit B1_offset: 0<B1_offset<B1_uplimit
   147
   148                                                       B1_offset = np.maximum(B1_offset, 0)
   149                                                       B1_offset = np.minimum(B1_offset, B1_uplimit / np.mean(W_idct_hat))
   150
   151                                                       B_offset = B1_offset * np.reshape(W_idct_hat, -1, order="F") * (-1)
   152
   153                                                       B_offset = B_offset + np.ones_like(B_offset) * B1_offset * np.mean(
   154                                                           W_idct_hat
   155                                                       )
   156                                                       A1_offset = np.mean(R1[:, validA1coeff_idx[0]], axis=1) - np.mean(
   157                                                           A1_coeff[validA1coeff_idx[0]]
   158                                                       ) * np.reshape(W_idct_hat, -1, order="F")
   159                                                       A1_offset = A1_offset - np.mean(A1_offset)
   160                                                       A_offset = A1_offset - np.mean(A1_offset) - B_offset
   161
   162                                                       # smooth A_offset
   163                                                       W_offset = dct2d(np.reshape(A_offset, (p, q), order="F").T)
   164                                                       W_offset = np.maximum(
   165                                                           W_offset - lambda_darkfield / (ent2 * mu), 0
   166                                                       ) + np.minimum(W_offset + lambda_darkfield / (ent2 * mu), 0)
   167                                                       A_offset = idct2d(W_offset.T)
   168                                                       A_offset = np.reshape(A_offset, -1, order="F")
   169
   170                                                       # encourage sparse A_offset
   171                                                       A_offset = np.maximum(
   172                                                           A_offset - lambda_darkfield / (ent2 * mu), 0
   173                                                       ) + np.minimum(A_offset + lambda_darkfield / (ent2 * mu), 0)
   174                                                       A_offset = A_offset + B_offset
   175
   176        70     506362.0   7233.7      7.6          Z1 = images - A1_hat - E1_hat
   177        70     253366.0   3619.5      3.8          Y1 = Y1 + mu * Z1
   178        70       1786.0     25.5      0.0          mu = np.minimum(mu * rho, mu_bar)
   179
   180                                                   # Stop Criterion
   181        70     241313.0   3447.3      3.6          stopCriterion = np.linalg.norm(Z1, ord="fro") / d_norm
   182        70        444.0      6.3      0.0          if stopCriterion < optimization_tol:
   183         2          7.0      3.5      0.0              converged = True
   184
   185        70        266.0      3.8      0.0          if not converged and iter >= max_iterations:
   186                                                       print("Maximum iterations reached")
   187                                                       converged = True
   188
   189         2         40.0     20.0      0.0      A_offset = np.squeeze(A_offset)
   190         2        239.0    119.5      0.0      A_offset = A_offset + B1_offset * np.reshape(W_idct_hat, -1, order="F")
   191
   192         2          7.0      3.5      0.0      return A1_hat, E1_hat, A_offset
That's surprising, but exciting. It looks like we could get a big boost from using jit?
In my experience, jit mainly helps with things Python is intrinsically bad at, like for loops. Vectorized NumPy operations already run in optimized compiled libraries under the hood, so jit generally doesn't help them much.
We actually removed jit from a new implementation of the flow-field calculations in cellpose because we found a way to vectorize the operations, so vectors no longer had to be traced in a for loop. The resulting large matrix operations ran faster than the jit version.
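To illustrate the point, here is the same soft-thresholding (shrinkage) step written two ways. The vectorized form mirrors the W_hat update in the profiled code; the explicit Python loop is the kind of code where a jit compiler pays off, while the vectorized version already runs inside NumPy's compiled routines. Both function names are hypothetical, and this is not necessarily how _shrinkageOperator is implemented.

```python
import numpy as np


def shrink_loop(x, t):
    """Soft-threshold element by element in pure Python (the kind of code jit speeds up)."""
    flat = x.ravel()
    out = np.empty_like(flat)
    for i in range(flat.size):
        v = flat[i]
        out[i] = max(v - t, 0.0) + min(v + t, 0.0)
    return out.reshape(x.shape)


def shrink_vectorized(x, t):
    """Same operation as whole-array NumPy calls, as in the W_hat update."""
    return np.maximum(x - t, 0) + np.minimum(x + t, 0)


x = np.array([-2.0, -0.1, 0.0, 0.5, 3.0])
shrunk = shrink_vectorized(x, 0.5)  # each entry moves 0.5 toward zero, clipped at zero
```

Both functions compute the same result; the difference is only where the per-element work happens.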
It seems svd is not necessary; np.linalg.norm can provide the spectral norm directly.
https://github.com/peng-lab/PyBaSiC/blob/01d8fe1ae86f2a09ced2cd710bd3e043036deea0/pybasic/tools/inexact_alm_rspca_l1.py#L43
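A quick check of that suggestion, assuming images has already been reshaped to the 2-D (m, n) matrix used in the solver: the first singular value that the code stores in norm_two equals np.linalg.norm(images, ord=2). NumPy computes the matrix 2-norm from the singular values internally, so this swap simplifies the code rather than making it cheaper. The helper name and the random test matrix are hypothetical.

```python
import numpy as np


def spectral_norm(images_2d):
    """Largest singular value of a 2-D array, i.e. the value the solver stores in norm_two."""
    return np.linalg.norm(images_2d, ord=2)


rng = np.random.default_rng(0)
stack = rng.random((32 * 32, 8))  # hypothetical stand-in for the reshaped (m, n) images
via_svd = np.linalg.svd(stack, False, False)[0]  # what the profiled code does
via_norm = spectral_norm(stack)
```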
I might be missing something, but it looks like this only needs to run once per call to basic. Its input, images, does not change during the run. If that is right, we could compute the singular values once instead of in every reweighting iteration.
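A sketch of that restructuring, assuming the outer reweighting loop in basic calls the solver repeatedly with the same images. Every name below is hypothetical, and the solver is reduced to a stand-in that only uses the precomputed norm to set its initial mu, as the profiled code does with 12.5 / norm_two. The counter exists only to show that the decomposition runs once per call rather than once per reweighting iteration.

```python
import numpy as np

svd_calls = 0  # counts decompositions, for illustration only


def largest_singular_value(images_2d):
    """Largest singular value, counting how often it is computed."""
    global svd_calls
    svd_calls += 1
    return np.linalg.svd(images_2d, full_matrices=False, compute_uv=False)[0]


def solve_with_norm(images_2d, weight, norm_two):
    """Hypothetical stand-in for inexact_alm_rspca_l1 that receives norm_two instead of computing it."""
    mu = 12.5 / norm_two  # same initial penalty as in the profiled code
    return mu


def run_reweighting(images_2d, n_reweight=5):
    """Hypothetical outer loop: images never change, so the norm is hoisted out of it."""
    norm_two = largest_singular_value(images_2d)
    weight = np.ones_like(images_2d)
    mus = []
    for _ in range(n_reweight):
        mus.append(solve_with_norm(images_2d, weight, norm_two))
        weight = 1.0 / (weight + 1.0)  # placeholder update, not BaSiC's actual reweighting rule
    return mus


mus = run_reweighting(np.random.default_rng(1).random((100, 6)), n_reweight=4)
```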