FedericoV opened this issue 2 years ago
Do you mean something like this? https://github.com/google/jax/pull/2220
Yes, that's exactly it. I was hit by a NaN when backpropagating through a least-squares solution and noticed the old comment string, so I assumed it was low-hanging fruit that still needed fixing.
My bad for opening this issue: I think the previous issue didn't mention leastsq explicitly, so I missed it when searching for related issues.
My only suggestion: let's remove the old, stranded comment string, since it no longer applies.
(Sent by email in reply to clemisch, Tue, May 24, 2022: "Do you mean something like this? #2220")
I think the issue is still open, i.e., we still need a custom gradient for least squares. Differentiating through the SVD is problematic (its derivative contains terms like 1/(s_i² - s_j²), which blow up when singular values coincide). If it were possible to express that gradient as a custom linear solve, that would probably be a good idea.
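To make the "gradient as a linear solve" idea concrete, here is a minimal NumPy sketch, assuming A (m x n, m >= n) has full column rank. It is not JAX's implementation, and `lstsq_vjp` is a hypothetical name. Differentiating the normal equations `A^T A x = A^T b` gives, for a cotangent `g = dL/dx`: solve `(A^T A) w = g`, then `dL/db = A w` and `dL/dA = r w^T - (A w) x^T` with residual `r = b - A x`. No SVD derivative is involved.

```python
import numpy as np

def lstsq_vjp(A, b, g):
    """Hypothetical helper: gradients of L = g . x, where x = argmin ||b - A x||^2.
    Assumes A (m x n, m >= n) has full column rank, so A^T A is invertible."""
    x = np.linalg.lstsq(A, b, rcond=None)[0]
    r = b - A @ x                              # residual
    w = np.linalg.solve(A.T @ A, g)            # one linear solve with the normal matrix
    Aw = A @ w
    grad_b = Aw                                # dL/db = A w
    grad_A = np.outer(r, w) - np.outer(Aw, x)  # dL/dA = r w^T - (A w) x^T
    return x, grad_A, grad_b

# Compare against central finite differences on a random, well-conditioned problem
rng = np.random.default_rng(0)
A = rng.normal(size=(20, 3))
b = rng.normal(size=20)
g = rng.normal(size=3)
x, grad_A, grad_b = lstsq_vjp(A, b, g)

def loss(A_, b_):
    return g @ np.linalg.lstsq(A_, b_, rcond=None)[0]

h = 1e-6
fd_b = np.array([(loss(A, b + h * e) - loss(A, b - h * e)) / (2 * h) for e in np.eye(b.size)])
E = rng.normal(size=A.shape)                   # random direction in A-space
fd_A = (loss(A + h * E, b) - loss(A - h * E, b)) / (2 * h)
print(np.max(np.abs(fd_b - grad_b)), abs(fd_A - np.sum(grad_A * E)))
```

Forming `A^T A` squares the condition number, so a QR-based solve would likely be the more robust variant; in JAX such a rule could presumably be registered with `jax.custom_vjp`.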
We would welcome a PR!
Am I correct that this is still ongoing? If not, you can simply ignore this comment.
So, if I understand correctly, you need the derivative, with respect to A and b, of the vector x that solves the least-squares problem
$\min_x \lVert b - A x \rVert_2^2$.
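When A has full column rank, this minimizer can be computed without forming any (pseudo)inverse. A minimal sketch using a thin QR factorization `A = Q R`, so that `R x = Q^T b` (random data, for illustration only):

```python
import numpy as np

rng = np.random.default_rng(6)
A = rng.normal(size=(25, 3))
b = rng.normal(size=25)

Q, R = np.linalg.qr(A)              # reduced QR: Q is 25x3, R is 3x3 upper-triangular
x_qr = np.linalg.solve(R, Q.T @ b)  # R x = Q^T b (a dedicated triangular solver would also do)
x_ref = np.linalg.lstsq(A, b, rcond=None)[0]
print(np.allclose(x_qr, x_ref))
```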
I recently needed something like this to compute the Jacobian for a constraint that restricts each component x_j of the solution individually to be positive or negative, and I found something here (I think the formula has typos: there should be identity matrices instead of ones in the first two equations).
The solution is x = A^+ @ b, where A^+ denotes the Moore-Penrose pseudoinverse of A, which can be computed via the SVD. Computing it is expensive, and some websites call it a cardinal sin to compute it if one is only interested in x. But as far as I can see from the source code, `_lstsq` in JAX already computes the SVD of A with `u, s, vt = svd(a, full_matrices=False)`, so the pseudoinverse is essentially precomputed and only some intermediate results would need to be reused.
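To illustrate the reuse, here is a minimal NumPy sketch (not JAX's code) that assembles `A^+ = V diag(1/s) U^T` from a thin SVD. `pinv_from_svd` is a hypothetical helper, and the relative cutoff `eps * max(m, n)` for treating singular values as zero is an assumption borrowed from the default convention of `np.linalg.lstsq`:

```python
import numpy as np

def pinv_from_svd(u, s, vt, rcond=None):
    """Hypothetical helper: A^+ = V diag(1/s) U^T from a thin SVD,
    zeroing out singular values below a relative cutoff (rank deficiency)."""
    if rcond is None:
        rcond = np.finfo(s.dtype).eps * max(u.shape[0], vt.shape[1])
    keep = s > rcond * s[0]                      # s is sorted in descending order
    s_inv = np.where(keep, 1.0 / np.where(keep, s, 1.0), 0.0)
    return (vt.T * s_inv) @ u.T                  # V diag(s_inv), then times U^T

rng = np.random.default_rng(1)
A = rng.normal(size=(30, 4))
b = rng.normal(size=30)
u, s, vt = np.linalg.svd(A, full_matrices=False)
P = pinv_from_svd(u, s, vt)
x = P @ b                                        # same x that lstsq returns
print(np.allclose(P, np.linalg.pinv(A)), np.allclose(x, np.linalg.lstsq(A, b, rcond=None)[0]))
```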
Once you have A^+ at hand, it is "quite simple" to compute the derivatives of x with respect to A and b, provided the rank of A can be considered constant: dx/db = A^+ and dx/dA = (dA^+/dA) @ b, with dA^+/dA given by the equations from the link above.
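For reference, here is the pseudoinverse differential written out in LaTeX (the Golub-Pereyra formula, valid while the rank of A stays constant; it is the same expression that appears in the comments of the script below), together with what it implies for x = A^+ b, writing r = b - A x for the residual:

```latex
\mathrm{d}A^{+} = -A^{+}\,\mathrm{d}A\,A^{+}
  + A^{+}A^{+\top}\,\mathrm{d}A^{\top}\,(I - A A^{+})
  + (I - A^{+}A)\,\mathrm{d}A^{\top}A^{+\top}A^{+}

\mathrm{d}x = \mathrm{d}A^{+}\,b + A^{+}\,\mathrm{d}b
  = -A^{+}\,\mathrm{d}A\,x
  + A^{+}A^{+\top}\,\mathrm{d}A^{\top}\,r
  + (I - A^{+}A)\,\mathrm{d}A^{\top}A^{+\top}x
  + A^{+}\,\mathrm{d}b
```

When A has full column rank, A^+ A = I and the last dA term vanishes.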
I wrote a dummy test script in NumPy that compares the numerical derivatives with the analytical ones. It works reasonably well, and I've already used it successfully in mixed linear/nonlinear fits that call `lstsq` at each step of the nonlinear iteration to find the linear coefficients.
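As an illustration of that use case, a minimal sketch with a hypothetical model `b(t) ~ x0 + x1 * exp(-k t)`: the coefficients x are found by least squares for a given nonlinear parameter k, and dx/dk follows from the pseudoinverse differential with `dA = dA/dk` (rank assumed constant). The model, data, and helper names are made up for illustration:

```python
import numpy as np

def design(k, t):
    # Hypothetical model b(t) ~ x0 + x1 * exp(-k t): linear in x, nonlinear in k
    return np.column_stack([np.ones_like(t), np.exp(-k * t)])

def d_design_dk(k, t):
    return np.column_stack([np.zeros_like(t), -t * np.exp(-k * t)])

def dx_along(A, dA, b):
    """Derivative of x = A^+ b along the perturbation dA (rank assumed constant)."""
    P = np.linalg.pinv(A)
    x = P @ b
    r = b - A @ x                                   # residual
    n = A.shape[1]
    return (-P @ (dA @ x)
            + P @ (P.T @ (dA.T @ r))
            + (np.eye(n) - P @ A) @ (dA.T @ (P.T @ x)))

rng = np.random.default_rng(2)
t = np.linspace(0.0, 5.0, 50)
b = 1.5 + 2.0 * np.exp(-0.7 * t) + rng.normal(scale=0.05, size=t.size)

k = 0.6
analytic = dx_along(design(k, t), d_design_dk(k, t), b)

h = 1e-6                                            # central difference in k
x_plus = np.linalg.lstsq(design(k + h, t), b, rcond=None)[0]
x_minus = np.linalg.lstsq(design(k - h, t), b, rcond=None)[0]
numeric = (x_plus - x_minus) / (2 * h)
print(analytic, numeric)
```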
For the numerical derivatives, I added a small summand called `shift` and then divided the difference in the solutions by it (a forward finite difference). I covered different cases: shifting all of b, shifting some entries of b, shifting some entries of A, and even a column of A that is a function of an altered parameter.
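The script below probes directional derivatives (all or several entries of b shifted at once). As a complementary sketch, perturbing one entry at a time recovers the full Jacobian dx/db column by column, which should reproduce A^+ itself. It uses central differences, a swapped-in choice rather than the forward differences of the script, and `fd_jacobian` is a hypothetical helper:

```python
import numpy as np

def fd_jacobian(f, v, h=1e-6):
    """Central-difference Jacobian of f at v, one column per perturbed entry of v."""
    cols = []
    for i in range(v.size):
        e = np.zeros_like(v)
        e[i] = h
        cols.append((f(v + e) - f(v - e)) / (2 * h))
    return np.column_stack(cols)

rng = np.random.default_rng(3)
A = rng.normal(size=(40, 3))
b = rng.normal(size=40)

def solve_for_b(b_):
    return np.linalg.lstsq(A, b_, rcond=None)[0]

J = fd_jacobian(solve_for_b, b)                 # shape (3, 40)
print(np.max(np.abs(J - np.linalg.pinv(A))))
```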
```python
import numpy as np

# setting up a signal b(a) = 2 + 3 * a - 0.4 * a² plus Gaussian noise
# the columns of A are 1, a and a²
a_mat = np.linspace(0.001, 10., 101).reshape((-1, 1)) ** np.arange(0, 3, 1).reshape((1, -1))
np.random.seed(0)
b_vect = a_mat @ np.array([2., 3., -0.4]) + np.random.normal(loc=0., scale=0.5, size=(a_mat.shape[0],))

# calculating the coefficient vector as x = A^+ @ b
orig_pinv = np.linalg.pinv(a_mat)
orig_coeff_vect = orig_pinv @ b_vect  # returns x = [ 2.54618329 2.71138515 -0.37206503]

# ------------------------------------------------------------------------------------------------
# Derivative with respect to b
# the derivative dx/db = A^+, which can be shown by shifting b by a very small amount
shift = 1e-8
shift_b_vect = b_vect + shift
shift_b_coeff_vect = orig_pinv @ shift_b_vect  # returns x = [ 2.54647197 2.7113004 -0.37205944]
approx_deriv_shift_b = (shift_b_coeff_vect - orig_coeff_vect) / shift
# returns dx/db ~= [ 9.99999905e-01 -4.44089210e-08 0.00000000e+00]
# this is expected, since shifting every entry of b gives the row-wise sums of A^+
true_deriv_shift_b = np.sum(orig_pinv, axis=1)  # returns dx/db = [1.00000000e+00 1.66533454e-16 3.07913417e-17]
print(shift_b_coeff_vect)
print(approx_deriv_shift_b)
print(true_deriv_shift_b)

# ................................................................................................
# if only a few entries of b are shifted, the derivatives are the row-wise sums of the respective columns of A^+
shift_b_vect = b_vect.copy()
shift_b_idxs = [20, 22, 30, 51, 76, 90]
shift_b_vect[shift_b_idxs] += shift
shift_b_coeff_vect = orig_pinv @ shift_b_vect  # returns x = [ 2.54647196 2.7113004 -0.37205944]
approx_deriv_shift_b = (shift_b_coeff_vect - orig_coeff_vect) / shift
# returns dx/db ~= [ 0.04490786 0.01138494 -0.00126651]
true_deriv_shift_b = np.sum(np.take(orig_pinv, shift_b_idxs, axis=1), axis=1)
# returns dx/db = [ 0.04490779 0.01138497 -0.00126651]
print(shift_b_coeff_vect)
print(approx_deriv_shift_b)
print(true_deriv_shift_b)

# ------------------------------------------------------------------------------------------------
# Derivative with respect to A
# this case is more complicated: since x = A^+ @ b, dx/dA = dA^+/dA @ b
# let us denote A^+ by P and A = f(a)
# dP/da is given by -P @ (dA/da) @ P + P @ P.T @ (dA.T/da) @ (I - A @ P) + (I - P @ A) @ (dA.T/da) @ P.T @ P
# as an example, some entries of A are shifted (their dA/da = 1 and dA.T/da = (dA/da).T)
shift_a_mat = a_mat.copy()
shift_idx_list = [[20, 0], [51, 2], [77, 1]]
for row, col in shift_idx_list:
    shift_a_mat[row, col] += shift
# the new solution is now obtained and the derivative is approximated
shift_a_coeff_vect = np.linalg.pinv(shift_a_mat) @ b_vect  # returns x = [ 2.54647196 2.7113004 -0.37205944]
approx_deriv_shift_a = (shift_a_coeff_vect - orig_coeff_vect) / shift
# returns dx/dA ~= [-0.15975807 0.0430922 -0.00350816]
# (for this particular combination of shifted entries A[i, j]!)
# for comparison, the derivative is computed exactly
# this needs dA/da, which is 1 for the shifted entries and 0 otherwise
deriv_of_a = np.zeros_like(a_mat)
for row, col in shift_idx_list:
    deriv_of_a[row, col] = 1.
# these derivatives can now be used to compute the derivative of the pseudoinverse
shift_a_deriv_pinv = - orig_pinv @ deriv_of_a @ orig_pinv + \
    orig_pinv @ orig_pinv.T @ deriv_of_a.T @ (np.eye(b_vect.size) - a_mat @ orig_pinv) + \
    (np.eye(a_mat.shape[1]) - orig_pinv @ a_mat) @ deriv_of_a.T @ orig_pinv.T @ orig_pinv
true_deriv_shift_a = shift_a_deriv_pinv @ b_vect  # returns dx/dA = [-0.15976061 0.04309242 -0.00350813]
# (for this particular combination of shifted entries A[i, j]!)
print(shift_a_coeff_vect)
print(approx_deriv_shift_a)
print(true_deriv_shift_a)

# ................................................................................................
# since the solution works, let's go a step further and compute the derivative of x with respect to
# the exponent p of the third column of A, which is shifted from a^2 to a^(2.00000001)
# the derivative of A[:, 2] = a^p with respect to p is ln(a) * a^p, i.e. ln(a) * a² at p = 2
shift_a_mat = a_mat.copy()
shift_a_mat[:, 2] **= (2. + shift) / 2.  # (a²)^((2 + shift) / 2) = a^(2 + shift)
# the new solution is now obtained and the derivative is approximated
shift_a_coeff_vect = np.linalg.pinv(shift_a_mat) @ b_vect  # returns x = [ 2.54647198 2.71130037 -0.37205943]
approx_deriv_shift_a = (shift_a_coeff_vect - orig_coeff_vect) / shift
# returns dx/dp ~= [ 1.80316184 -2.9622504 1.12283381]
# (for the altered column of A!)
# for comparison, the derivative is computed exactly
# this needs dA/dp, which is ln(a) * a² for the altered column and 0 otherwise
deriv_of_a = np.zeros_like(a_mat)
deriv_of_a[:, 2] = np.log(a_mat[:, 1]) * a_mat[:, 2]
# these derivatives can now be used to compute the derivative of the pseudoinverse
shift_a_deriv_pinv = - orig_pinv @ deriv_of_a @ orig_pinv + \
    orig_pinv @ orig_pinv.T @ deriv_of_a.T @ (np.eye(b_vect.size) - a_mat @ orig_pinv) + \
    (np.eye(a_mat.shape[1]) - orig_pinv @ a_mat) @ deriv_of_a.T @ orig_pinv.T @ orig_pinv
true_deriv_shift_a = shift_a_deriv_pinv @ b_vect  # returns dx/dp = [ 1.80315434 -2.96225005 1.12283389]
# (for the altered column of A!)
print(shift_a_coeff_vect)
print(approx_deriv_shift_a)
print(true_deriv_shift_a)
```
Note: this can be extended in a variety of ways, e.g., rows of A as functions of an altered parameter. It should also be mentioned that dA can be a very sparse matrix: if only a single column or row is affected, the zero entries can simply be skipped in the matrix products.
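A minimal sketch of that sparsity trick for the single-column case, assuming dA = outer(d, e_j), i.e. only column j of A changes, with direction d. Each term of the pseudoinverse differential then collapses to matrix-vector products, so no m x m identity or projector is ever formed. `dx_single_column` is a hypothetical helper, checked here against the dense formula from the script:

```python
import numpy as np

def dx_single_column(A, P, b, j, d):
    """Hypothetical helper: derivative of x = A^+ b when only column j of A
    moves along d, i.e. dA = outer(d, e_j). Avoids any m x m matrix."""
    x = P @ b
    r = b - A @ x                                   # residual
    e_j = np.zeros(A.shape[1])
    e_j[j] = 1.0
    Pd = P @ d
    term1 = -Pd * x[j]                              # -A^+ dA x
    term2 = (P @ P[j, :]) * (d @ r)                 # A^+ A^+T dA^T r
    term3 = (e_j - P @ A[:, j]) * (Pd @ x)          # (I - A^+ A) dA^T A^+T x
    return term1 + term2 + term3

rng = np.random.default_rng(4)
A = rng.normal(size=(60, 4))
b = rng.normal(size=60)
P = np.linalg.pinv(A)
j, d = 2, rng.normal(size=60)

# dense reference, same expression as in the script above
dA = np.zeros_like(A)
dA[:, j] = d
m, n = A.shape
dense = (-P @ dA @ P + P @ P.T @ dA.T @ (np.eye(m) - A @ P)
         + (np.eye(n) - P @ A) @ dA.T @ P.T @ P) @ b
fast = dx_single_column(A, P, b, j, d)
print(np.allclose(fast, dense))
```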
Since using inverse matrices was already suggested, and this is basically "differentiating through the SVD", I hope this helps! Best regards
Currently, gradients through `lstsq` are quite brittle because no custom derivative is attached; a comment in the code has said so ever since @jakevdp first implemented it here: https://github.com/google/jax/blob/ffe881cd5862480a4110a6f9076bd89dc995f012/jax/_src/numpy/linalg.py#L549
The easiest approach would be to derive the gradient using the implicit function theorem. For a system like
`A@x + b = 0`, the gradients would end up being (I think; please double-check my algebra):
Of course, directly using the inverse is probably not the most numerically stable thing we can do, but is this a good starting point?
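For the square system `A@x + b = 0` (assuming A is invertible), implicit differentiation gives `dA x + A dx + db = 0`, i.e. `dx = -A^{-1} (dA x + db)`. A minimal NumPy sketch of how the reverse-mode version might avoid the explicit inverse: one transposed solve `A^T lam = g` per cotangent g, then `dL/db = -lam` and `dL/dA = -outer(lam, x)`. `ift_vjp` is a hypothetical name, and this is my derivation rather than the formulas referenced above:

```python
import numpy as np

def ift_vjp(A, b, g):
    """Hypothetical helper: gradients of L = g . x where x solves A x + b = 0.
    From dA x + A dx + db = 0: dL/db = -A^{-T} g and dL/dA = outer(dL/db, x)."""
    x = np.linalg.solve(A, -b)
    lam = np.linalg.solve(A.T, g)                # one transposed solve, no explicit inverse
    return x, -np.outer(lam, x), -lam

rng = np.random.default_rng(5)
A = rng.normal(size=(5, 5)) + 10.0 * np.eye(5)   # well-conditioned test matrix
b = rng.normal(size=5)
g = rng.normal(size=5)
x, grad_A, grad_b = ift_vjp(A, b, g)

def loss(A_, b_):
    return g @ np.linalg.solve(A_, -b_)

h = 1e-6                                         # central finite differences
E = rng.normal(size=A.shape)                     # random direction in A-space
fd_A = (loss(A + h * E, b) - loss(A - h * E, b)) / (2 * h)
fd_b = np.array([(loss(A, b + h * e) - loss(A, b - h * e)) / (2 * h) for e in np.eye(5)])
print(fd_A, np.sum(grad_A * E))
print(np.max(np.abs(fd_b - grad_b)))
```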