I'm wondering if it's best to start namespacing these new features instead of putting them in the main Nx module? NumPy has its numpy.linalg namespace for all these functions.
I agree with @maxastyler. Jax also seems to have that structure:
https://github.com/google/jax/blob/master/jax/_src/lax/linalg.py
Also, is this potentially blocked by #118? For instance, the Cholesky decomposition is supposed to be LL*, where L* is the conjugate transpose.
Or, do we want to just deal with the real-valued cases first and come back to the general case as part of #118?
We can start with only the real-valued ones. We will have to revisit a bunch of stuff once we add complex numbers.
also, we can always tackle namespacing later too!
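For the real-valued case, the conjugate transpose is just a plain transpose, so the property to check is A == L · Lᵀ. A minimal sketch (assuming the function lands as `Nx.LinAlg.cholesky/1` - the namespace is still being discussed here):

```elixir
# Minimal sketch, assuming Nx.LinAlg.cholesky/1 (naming still under discussion).
# For real-valued input, A == L . transpose(L).
a = Nx.tensor([[4.0, 2.0], [2.0, 3.0]])
l = Nx.LinAlg.cholesky(a)
Nx.dot(l, Nx.transpose(l))
# ~> [[4.0, 2.0], [2.0, 3.0]], i.e. back to a
```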
For implementation of things like LU and QR decomposition, should contributors try to write it by hand? I ask this because I looked into the Numpy source code and it basically defers most things to the LAPACK interface.
@polvalente Yes, you'll have to write the implementation by hand. The implementation of Cholesky decomposition is probably something you'll want to look at. Check out: https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/binary_backend.ex to see how we do the work on binaries.
Ok! Any tips on how to tackle iterating and assigning values to a given column of a matrix?
Yesterday I played around a bit but ended up needing a combination of `concatenate`, `transpose`, and `reshape` to convert the list of Q columns and R columns in a Gram-Schmidt decomposition, which ended up looking ugly haha
@polvalente My advice is to do something similar to Cholesky and use the `for ..., reduce: ...` pattern: https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/binary_backend.ex#L978
You can also see there how to "index" into the binary based on byte offset. I would try to avoid using `concatenate`, `transpose`, and `reshape` as much as possible. Most of those use weighted shapes anyway, so you might just want to take a look at: https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/binary_backend.ex#L1554
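In case it helps, here is a toy illustration of the `for ..., reduce: ...` pattern itself (not the actual binary backend code, which works on raw binaries with byte offsets):

```elixir
# Toy example of the for ..., reduce: ... pattern: fold an accumulator across rows.
rows = [[1, 2], [3, 4], [5, 6]]

row_sums =
  for row <- rows, reduce: [] do
    acc -> acc ++ [Enum.sum(row)]
  end

# row_sums == [3, 7, 11]
```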
To be clear, there are two ways to implement them: on top of existing Nx operations (`norm` was recently implemented - so that's an example) or as backend-specific functions. It seems QR can be implemented with Nx (that's what JAX does) but we might be missing one or two ops.
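For reference, a hedged sketch of what the first approach looks like - built purely on existing Nx operations, using a Frobenius-style norm as a stand-in rather than the actual norm code:

```elixir
# Frobenius-style norm written purely with existing Nx operations -
# illustrative only, not the actual norm implementation.
t = Nx.tensor([[3.0, 4.0], [0.0, 0.0]])

frobenius =
  t
  |> Nx.multiply(t)
  |> Nx.sum()
  |> Nx.sqrt()

# frobenius ~> 5.0
```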
@josevalim @seanmor5 thanks for the tips! Could I perhaps try to tackle QR decomp then?
All yours!
@josevalim I would like to work on the Triangular Solve issue proposed in the list above, any tips before I move forward?
I believe the intention would be to have something akin to numpy's `solve` function, correct?
https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
@AlexJuca You'll want to look at: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.solve_triangular.html
Don't worry about the `lower`, `unit_diagonal`, `overwrite_b`, or `check_finite` options for now. I would start by just focusing on solving the system with 2 inputs. The algorithm you're looking for is called forward substitution (for lower triangular matrices) and backward substitution (for upper triangular matrices).
You can implement either forward or backward substitution for now; we can worry about extending it in another PR. A small sketch follows below.
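In case a concrete starting point helps, here is a minimal forward substitution sketch on plain Elixir lists - a hypothetical helper for illustration, not the binary representation the Nx backend actually uses:

```elixir
defmodule ForwardSubstitution do
  # Solves l . x = b where l is lower triangular, given as a list of rows.
  # x_i = (b_i - sum_{j < i} l[i][j] * x[j]) / l[i][i]
  def solve(l, b) do
    Enum.zip(l, b)
    |> Enum.reduce([], fn {row, b_i}, xs ->
      i = length(xs)

      # Dot product of the already-solved entries with the row prefix
      acc =
        row
        |> Enum.take(i)
        |> Enum.zip(xs)
        |> Enum.reduce(0, fn {l_ij, x_j}, sum -> sum + l_ij * x_j end)

      xs ++ [(b_i - acc) / Enum.at(row, i)]
    end)
  end
end

ForwardSubstitution.solve([[2, 0], [3, 4]], [2, 11])
#=> [1.0, 2.0]
```

Backward substitution is the mirror image, iterating from the last row up.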
@seanmor5 Great tips, I will work on this. Thanks
@seanmor5 since #265 is on the way, which of the remaining items should I tackle? Or maybe there's another issue in which I could be useful?
@polvalente feel free to pick any of the ones above still. Notice “norm” is missing nuclear and -2 for 2D matrices.
Alternatively, consider implementing gradients and/or the EXLA version of QR decomposition. :) Thank you for all the help!
Actually, ignore my second suggestion for now. We need to implement Expr for QR and that is not straightforward because it returns a tuple. I will have to think about it.
@josevalim all things considered, I think I'll try to tackle SVD. The reason is that both nuclear and -2 norms for 2D matrices depend on having the singular values for said matrix. What do you think?
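To spell that out: assuming a future `Nx.LinAlg.svd/1` that returns `{u, s, vt}` (the shape NumPy/JAX use - it isn't implemented yet at this point in the thread), both missing norms fall straight out of the singular values:

```elixir
# Hedged sketch - svd/1 is hypothetical here and not yet implemented.
t = Nx.tensor([[1.0, 0.0], [0.0, 2.0]])
{_u, s, _vt} = Nx.LinAlg.svd(t)

nuclear_norm = Nx.sum(s)           # nuclear norm: sum of singular values
minus_two_norm = Nx.reduce_min(s)  # -2 norm: smallest singular value
```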
I've researched a bit and couldn't find a good description of a numerical SVD algorithm. Does anyone know of one we could implement?
I found a reference implementation in the XLA source code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/lib/svd.cc#L784 - the algorithm is always broken into 2 steps, but it seems the second step can be implemented with higher-level operations using an iterative approach.
I have also looked for "SVD in Python from scratch" and found some examples - I haven't verified how precise those algorithms are, though. So I am not sure if any of this helps.
Thanks! I'll take a look at those later today.
@josevalim @seanmor5 can I take on LU decomposition? I could also try eigendecomposition, but I'm not as familiar with the algorithms.
@polvalente Feel free! :D
I'm starting to look at Eigen decomposition, as long as no one else is already taking that.
Go ahead! Don't worry about complex numbers, because we don't support them yet. Most of this issue will need a "part 2" after complex number support is added anyway
@justincjohnson Alrighty!
I'm clearly not moving fast on Eigen decomposition. If anyone is waiting for this, please feel free to take it from me. If it sits longer, I'll jump back in when I have time. Thanks.
I'm going to start tackling Eigen decomposition
I'd propose calculating determinants as a required feature.
@turion sorry it took so long to respond. I think determinants could be a useful tool for Nx.LinAlg, although I can't think of any applications outside of calculating eigenvalues or solving equation systems. Do you have any others in mind?
Determinants are ubiquitous in linear algebra. I had an application where I was doing PCA and then calculating the density of a Gaussian multivariate distribution. Often one can calculate the determinant from one of the listed decompositions (e.g. Cholesky), so it seems like a low-hanging fruit that would be very easy to implement.
It is low hanging indeed. I think it could even be implemented from Nx directly instead of needing backend-specific functions (I'm not sure about this though).
@josevalim WDYT about adding this?
> I think it could even be implemented from Nx directly
Not sure about that. In some special cases you can use an existing decomposition (e.g. positive semidefinite -> Cholesky), but I think none of the listed decompositions serves as a way to calculate it in every case. Calculating it via the Leibniz formula in Elixir is prohibitively slow, I believe, since it sums over n! permutations.
iex(1)> t = Nx.tensor([[6, 1, 1], [4, -2, 5], [2, 8, 7]])
#Nx.Tensor<
s64[3][3]
[
[6, 1, 1],
[4, -2, 5],
[2, 8, 7]
]
>
iex(2)> expected_det = 6 * (-2) * 7 + 1 * 5 * 2 + 1 * 4 * 8 - 1 * 2 * (-2) - 1 * 4 * 7 - 6 * 5 * 8
-306
iex(3)> {l, u} = Nx.LinAlg.lu(t)
** (MatchError) no match of right hand side value: {#Nx.Tensor<
s64[3][3]
[
[1, 0, 0],
[0, 0, 1],
[0, 1, 0]
]
>, #Nx.Tensor<
f32[3][3]
[
[1.0, 0.0, 0.0],
[0.3333333432674408, 1.0, 0.0],
[0.6666666865348816, -0.3478260934352875, 1.0]
]
>, #Nx.Tensor<
f32[3][3]
[
[6.0, 1.0, 1.0],
[0.0, 7.666666507720947, 6.666666507720947],
[0.0, 0.0, 6.65217399597168]
]
>}
iex(3)> {p, l, u} = Nx.LinAlg.lu(t)
{#Nx.Tensor<
s64[3][3]
[
[1, 0, 0],
[0, 0, 1],
[0, 1, 0]
]
>, #Nx.Tensor<
f32[3][3]
[
[1.0, 0.0, 0.0],
[0.3333333432674408, 1.0, 0.0],
[0.6666666865348816, -0.3478260934352875, 1.0]
]
>, #Nx.Tensor<
f32[3][3]
[
[6.0, 1.0, 1.0],
[0.0, 7.666666507720947, 6.666666507720947],
[0.0, 0.0, 6.65217399597168]
]
>}
iex(4)> t
#Nx.Tensor<
s64[3][3]
[
[6, 1, 1],
[4, -2, 5],
[2, 8, 7]
]
>
iex(5)> diag1 = Nx.select(Nx.eye(t), l, Nx.broadcast(0, t))
#Nx.Tensor<
f32[3][3]
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]
]
>
iex(6)> diag2 = Nx.select(Nx.eye(t), u, Nx.broadcast(0, t))
#Nx.Tensor<
f32[3][3]
[
[6.0, 0.0, 0.0],
[0.0, 7.666666507720947, 0.0],
[0.0, 0.0, 6.65217399597168]
]
>
iex(7)> Nx.product(diag1)
#Nx.Tensor<
f32
0.0
>
iex(8)> Nx.sum(diag1, axes: [0])
#Nx.Tensor<
f32[3]
[1.0, 1.0, 1.0]
>
iex(9)> Nx.sum(diag1, axes: [0]) |> Nx.product()
#Nx.Tensor<
f32
1.0
>
iex(10)> Nx.sum(diag1, axes: [0]) |> Nx.product() |> Nx.multiply(Nx.sum(diag2, axes: [0]) |> Nx.product())
#Nx.Tensor<
f32
306.0
>
Since the determinant of a product is the product of the determinants, we can almost use LU to calculate it. This also uses the fact that the determinant of a triangular matrix is the product of its diagonal elements. Note that the sign is wrong due to the permutation matrix. Perhaps there's an easier way to determine which sign the matrix has using another process (perhaps a dot product or something could be useful here).
edit: This is slower than calculating it in a backend-specific operation, but has the advantage of having the gradient ready for us. Anyway, even if we decide to go the backend-specific route, this implementation can be useful for the binary backend.
edit2: From this reference, the determinant of a permutation matrix is `(-1)**num_permutations`, so if we can find a way to count the permutations in `p`, we get the missing factor.
The way I typically answer these questions is: how does Jax implement this functionality? It most likely implements it on top of existing operations, so we should just tag along.
We can allow backends to optimize it later (it is in our backlog to allow "backend overrides").
@polvalente I have a solution for your problem. To calculate `(-1)**num_permutations` you don't need the number of permutations itself, but simply its parity. This is closely connected to the number of inversions in a sequence: every time you swap two elements, the parity of the number of inversions changes, so you only need to count the inversions. For example, a sequence with 2 inversions gives `(-1)**num_permutations = (-1)**2 = 1`. The efficient way to count inversions is an algorithm similar to merge sort, O(N log N), where N is the number of rows in the matrix. I have an implementation of the algorithm in C++, so I can send it to you (it's very useful in competitive programming).
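Since the C++ version isn't pasted here, a hedged Elixir sketch of the same merge-sort-style inversion count (the permutation list could come from `p` via `Nx.argmax(p, axis: 1) |> Nx.to_flat_list()`):

```elixir
defmodule Inversions do
  # Counts inversions in O(N log N) with a merge-sort-style pass.
  # Returns {sorted_list, inversion_count}; rem(count, 2) gives the parity,
  # i.e. the sign factor (-1)**num_permutations.
  def count([]), do: {[], 0}
  def count([x]), do: {[x], 0}

  def count(list) do
    {left, right} = Enum.split(list, div(length(list), 2))
    {ls, lc} = count(left)
    {rs, rc} = count(right)
    {merged, mc} = merge(ls, rs, [], 0)
    {merged, lc + rc + mc}
  end

  defp merge([], rs, acc, inv), do: {Enum.reverse(acc) ++ rs, inv}
  defp merge(ls, [], acc, inv), do: {Enum.reverse(acc) ++ ls, inv}

  defp merge([l | lt] = ls, [r | rt], acc, inv) do
    if l <= r do
      merge(lt, [r | rt], [l | acc], inv)
    else
      # r is smaller than every remaining left element: that many new inversions
      merge(ls, rt, [r | acc], inv + length(ls))
    end
  end
end

{_sorted, inversions} = Inversions.count([2, 1, 4, 3])
sign = if rem(inversions, 2) == 0, do: 1, else: -1
# inversions == 2, sign == 1
```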
@msluszniak thanks for the comment! I ended up finding just this solution in an article and then in the Jax code and added it in #597 and #603
For now, at least: