elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Support common LinAlg primitives #157

Closed seanmor5 closed 2 years ago

seanmor5 commented 3 years ago

For now, at least:

maxastyler commented 3 years ago

I'm wondering whether it would be best to start namespacing these new features instead of putting them in the main Nx module. NumPy has its numpy.linalg namespace for all of these functions.

billylanchantin commented 3 years ago

I agree with @maxastyler. JAX also seems to have that structure:

https://github.com/google/jax/blob/master/jax/_src/lax/linalg.py

Also, is this potentially blocked by #118? For instance, the Cholesky decomposition is supposed to be LL* where L* is the conjugate transpose.

Or, do we want to just deal with the real-valued cases first and come back to the general case as part of #118?

josevalim commented 3 years ago

We can start with only the real-valued ones. We will have to revisit a bunch of stuff once we add complex numbers.

Also, we can always tackle namespacing later too!

polvalente commented 3 years ago

For implementing things like LU and QR decomposition, should contributors write them by hand? I ask because I looked into the NumPy source code and it defers most of this to the LAPACK interface.

seanmor5 commented 3 years ago

@polvalente Yes, you'll have to write the implementation by hand. The implementation of Cholesky decomposition is probably something you'll want to look at. Check out: https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/binary_backend.ex to see how we do the work on binaries.

polvalente commented 3 years ago

Ok! Any tips on how to tackle iterating and assigning values to a given column of a matrix?

Yesterday I played around a bit but ended up needing a combination of concatenate, transpose and reshape to assemble the lists of Q columns and R columns from a Gram-Schmidt decomposition into matrices, which ended up looking ugly, haha
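For reference, here is a minimal sketch of modified Gram-Schmidt QR that builds Q column by column and collects the R entries in a map, so no concatenate/transpose/reshape juggling is needed. It works on plain Elixir nested lists, not Nx tensors or the binary backend, assumes an m x n input with m >= n and full column rank, and needs Elixir 1.12+ for `Enum.zip_with/3`. The module name `GramSchmidtQR` is hypothetical.

```elixir
defmodule GramSchmidtQR do
  # Hypothetical module: modified Gram-Schmidt QR on plain nested lists.
  # Assumes `a` is an m x n list of rows, m >= n, with full column rank.
  def qr(a) do
    cols = transpose(a)
    n = length(cols)

    {q_cols, r_entries} =
      Enum.reduce(0..(n - 1), {[], %{}}, fn j, {qs, r} ->
        v = Enum.at(cols, j)

        # Remove the component along each previous q, one at a time,
        # always projecting the *updated* v (the "modified" variant).
        {v, r} =
          qs
          |> Enum.with_index()
          |> Enum.reduce({v, r}, fn {q, i}, {v, r} ->
            rij = dot(q, v)
            {sub(v, scale(q, rij)), Map.put(r, {i, j}, rij)}
          end)

        rjj = :math.sqrt(dot(v, v))
        {qs ++ [scale(v, 1 / rjj)], Map.put(r, {j, j}, rjj)}
      end)

    q = transpose(q_cols)
    # R is upper triangular; missing (lower) entries are zero
    r = for i <- 0..(n - 1), do: for(j <- 0..(n - 1), do: Map.get(r_entries, {i, j}, 0.0))
    {q, r}
  end

  defp transpose(m), do: m |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
  defp dot(u, v), do: u |> Enum.zip_with(v, &(&1 * &2)) |> Enum.sum()
  defp scale(v, s), do: Enum.map(v, &(&1 * s))
  defp sub(u, v), do: Enum.zip_with(u, v, &(&1 - &2))
end
```

For example, `GramSchmidtQR.qr([[3.0, 1.0], [4.0, 2.0]])` gives Q with columns close to [0.6, 0.8] and [-0.8, 0.6], and R close to [[5.0, 2.2], [0.0, 0.4]], up to floating-point rounding.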

seanmor5 commented 3 years ago

@polvalente My advice is to do something similar to Cholesky and use the for ..., reduce: ... pattern: https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/binary_backend.ex#L978

You can also see there how to "index" into the binary based on byte offset. I would try to avoid using concatenate, transpose and reshape as much as possible. Most of those use weighted shapes anyway, so you might just want to take a look at: https://github.com/elixir-nx/nx/blob/main/nx/lib/nx/binary_backend.ex#L1554
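A minimal sketch of "indexing" into a flat binary by byte offset with Elixir binary pattern matching. It assumes a row-major matrix of native-endian 32-bit floats: for shape {rows, cols} the per-axis weights are {cols, 1}, so element {i, j} starts at byte (i * cols + j) * 4. `BinaryIndex` and its functions are hypothetical illustrations, not the Nx.BinaryBackend API.

```elixir
defmodule BinaryIndex do
  # Hypothetical module: row-major matrix stored as native-endian f32 values.
  @elem_bytes 4

  # Read element {i, j}: skip `offset` bytes, read one f32, ignore the rest.
  def at(bin, {_rows, cols}, {i, j}) do
    offset = (i * cols + j) * @elem_bytes
    <<_::binary-size(offset), x::float-32-native, _::binary>> = bin
    x
  end

  # Build such a binary from a list of rows.
  def from_rows(rows) do
    for row <- rows, x <- row, into: <<>>, do: <<x::float-32-native>>
  end
end
```

For example, with `bin = BinaryIndex.from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])`, `BinaryIndex.at(bin, {2, 3}, {1, 2})` reads the last element, 6.0, without building any intermediate lists.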

josevalim commented 3 years ago

To be clear, there are two ways to implement them:

  1. You can implement it fully using only Nx functions (that's how norm was recently implemented, so it's a good example)
  2. The Nx implementation simply invokes a backend which has to implement it at the low-level

It seems QR can be implemented with Nx (that's what JAX does), but we might be missing one or two ops.

polvalente commented 3 years ago

@josevalim @seanmor5 thanks for the tips! Could I perhaps try to tackle QR decomp then?

josevalim commented 3 years ago

All yours!

AlexJuca commented 3 years ago

@josevalim I would like to work on the Triangular Solve item proposed in the list above. Any tips before I move forward?

I believe the intention would be to have something akin to numpy's solve function, correct? https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html

seanmor5 commented 3 years ago

@AlexJuca You'll want to look at: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.solve_triangular.html

Don't worry about the lower, unit_diagonal, overwrite_b, or check_finite options for now. I would start by focusing on solving the system with 2 inputs. The algorithms you're looking for are forward substitution (for lower triangular matrices) and backward substitution (for upper triangular matrices).
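A minimal forward-substitution sketch on plain nested lists (not Nx tensors), assuming a square lower-triangular L with a non-zero diagonal; `ForwardSubstitution` is a hypothetical module name. Each x_i = (b_i - sum over k < i of L_ik * x_k) / L_ii uses only entries that are already solved. Needs Elixir 1.12+ for `Enum.zip_with/3`.

```elixir
defmodule ForwardSubstitution do
  # Hypothetical module: solve L x = b for lower-triangular L (nested lists).
  def solve(l, b) do
    l
    |> Enum.zip(b)
    |> Enum.with_index()
    |> Enum.reduce([], fn {{row, bi}, i}, x ->
      # x holds the already-solved x_0 .. x_(i-1)
      s = row |> Enum.take(i) |> Enum.zip_with(x, &(&1 * &2)) |> Enum.sum()
      x ++ [(bi - s) / Enum.at(row, i)]
    end)
  end
end
```

For example, `ForwardSubstitution.solve([[2, 0, 0], [1, 1, 0], [3, 2, 1]], [2, 3, 10])` returns `[1.0, 2.0, 3.0]`. Backward substitution is the same loop run from the last row upward on an upper-triangular matrix.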

seanmor5 commented 3 years ago

You can implement either forward or backward substitution for now, we can worry about extending it in another PR

AlexJuca commented 3 years ago

@seanmor5 Great tips, I will work on this. Thanks

polvalente commented 3 years ago

@seanmor5 since #265 is on the way, which of the remaining items should I tackle? Or maybe there's another issue in which I could be useful?

josevalim commented 3 years ago

@polvalente feel free to pick any of the ones above still. Notice “norm” is missing nuclear and -2 for 2D matrices.

Alternatively, consider implementing gradients and/or EXLA support for QR decomposition. :) Thank you for all the help!

josevalim commented 3 years ago

Actually, ignore my second suggestion for now. We need to implement Expr for QR and that is not straight-forward because it returns a tuple. I will have to think about it.

polvalente commented 3 years ago

@josevalim all things considered, I think I'll try to tackle SVD. The reason is that both nuclear and -2 norms for 2D matrices depend on having the singular values for said matrix. What do you think?
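For reference, the standard definitions in terms of the singular values σ_i of A (these match the numpy.linalg.norm conventions for ord='nuc', ord=-2 and ord=2), which is why both missing norms reduce to computing singular values:

```latex
\|A\|_{*} = \sum_{i} \sigma_i(A), \qquad
\|A\|_{-2} = \sigma_{\min}(A), \qquad
\|A\|_{2} = \sigma_{\max}(A)
```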

polvalente commented 3 years ago

I've researched a bit and couldn't find a good description of a numerical SVD algorithm. Does anyone know of one we could implement?

josevalim commented 3 years ago

I found a reference implementation in the XLA source code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/lib/svd.cc#L784. The algorithm is always split into 2 steps, but it seems the second step can be implemented with higher-level operations using an iterative approach.

I have also looked for "SVD in Python from scratch" and I found some examples - I haven't verified how precise those algorithms are though. So I am not sure if any of this helps.
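As one example of an iterative approach, here is a sketch of one-sided (Hestenes) Jacobi, which repeatedly rotates pairs of columns until they are mutually orthogonal; the column norms are then the singular values. This is a swapped-in technique, not the XLA algorithm linked above, and it only computes singular values (no U or V). It works on plain nested lists of reals; `JacobiSVD`, the tolerance and the sweep limit are hypothetical choices. Needs Elixir 1.12+.

```elixir
defmodule JacobiSVD do
  # Hypothetical module: singular values via one-sided Jacobi rotations.
  @tol 1.0e-12
  @max_sweeps 50

  def singular_values(a) do
    a
    # work on the columns of `a`
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
    |> sweep(0)
    # once the columns are orthogonal, their norms are the singular values
    |> Enum.map(&:math.sqrt(dot(&1, &1)))
    |> Enum.sort(:desc)
  end

  defp sweep(cols, n) when n >= @max_sweeps, do: cols

  defp sweep(cols, n) do
    k = length(cols)
    pairs = for i <- 0..(k - 1), j <- 0..(k - 1), i < j, do: {i, j}
    {cols, rotated?} = Enum.reduce(pairs, {cols, false}, &rotate/2)
    # keep sweeping until no pair needed a rotation
    if rotated?, do: sweep(cols, n + 1), else: cols
  end

  # Rotate columns i and j so that they become orthogonal.
  defp rotate({i, j}, {cols, rotated?}) do
    ai = Enum.at(cols, i)
    aj = Enum.at(cols, j)
    alpha = dot(ai, ai)
    beta = dot(aj, aj)
    gamma = dot(ai, aj)

    if abs(gamma) <= @tol * :math.sqrt(alpha * beta) do
      {cols, rotated?}
    else
      zeta = (beta - alpha) / (2 * gamma)
      t = sign(zeta) / (abs(zeta) + :math.sqrt(1 + zeta * zeta))
      c = 1 / :math.sqrt(1 + t * t)
      s = c * t
      new_ai = Enum.zip_with(ai, aj, fn x, y -> c * x - s * y end)
      new_aj = Enum.zip_with(ai, aj, fn x, y -> s * x + c * y end)
      {cols |> List.replace_at(i, new_ai) |> List.replace_at(j, new_aj), true}
    end
  end

  defp sign(x) when x >= 0, do: 1.0
  defp sign(_x), do: -1.0

  defp dot(u, v), do: u |> Enum.zip_with(v, &(&1 * &2)) |> Enum.sum()
end
```

With the singular values sorted in descending order, the nuclear norm is `Enum.sum(sv)` and the -2 norm is `List.last(sv)`.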

polvalente commented 3 years ago

Thanks! I'll take a look at those later today.

polvalente commented 3 years ago

@josevalim @seanmor5 can I take on LU decomposition? I could also try eigendecomposition, but I'm not as familiar with the algorithms.

seanmor5 commented 3 years ago

@polvalente Feel free! :D

justindotpub commented 3 years ago

I'm starting to look at Eigen decomposition, as long as no one else is already taking that.

polvalente commented 3 years ago

Go ahead! Don't worry about complex numbers, because we don't support them yet. Most of this issue will need a "part 2" after complex number support is added anyway

AlexJuca commented 3 years ago

@justincjohnson Alrighty!

justindotpub commented 3 years ago

I'm clearly not moving fast on Eigen decomposition. If anyone is waiting for this, please feel free to take it from me. If it sits longer, I'll jump back in when I have time. Thanks.

polvalente commented 3 years ago

I'm going to start tackling Eigen decomposition

turion commented 3 years ago

I'd propose calculating determinants as a required feature.

polvalente commented 2 years ago

@turion sorry this took so long to respond. I think that determinants could be a useful tool for Nx.LinAlg, although I can't think of any applications outside of calculating eigenvalues or solving systems of equations. Do you have any others in mind?

turion commented 2 years ago

Determinants are ubiquitous in linear algebra. I had an application where I was doing PCA and then calculating the density of a Gaussian multivariate distribution. Often one can calculate the determinant from one of the listed decompositions (e.g. Cholesky), so it seems like a low-hanging fruit that would be very easy to implement.
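A minimal sketch of the Cholesky route for the symmetric positive-definite case: since A = L Lᵀ and L is triangular, det(A) is the squared product of L's diagonal, and the log-determinant (what a multivariate Gaussian log-density needs) is twice the sum of the logs of that diagonal. Plain nested lists, no input validation; `CholeskyDet` is a hypothetical module name, and the `0..(j - 1)//1` range syntax needs Elixir 1.12+.

```elixir
defmodule CholeskyDet do
  # Hypothetical module. Assumes `a` is real, symmetric positive-definite.

  # Returns the lower-triangular factor L (A = L L^T) as a map %{{i, j} => value}.
  def cholesky(a) do
    n = length(a)

    Enum.reduce(0..(n - 1), %{}, fn i, l ->
      Enum.reduce(0..i, l, fn j, l ->
        # s = sum over k < j of L[i][k] * L[j][k]
        s = Enum.reduce(0..(j - 1)//1, 0.0, fn k, acc -> acc + l[{i, k}] * l[{j, k}] end)
        aij = a |> Enum.at(i) |> Enum.at(j)
        value = if i == j, do: :math.sqrt(aij - s), else: (aij - s) / l[{j, j}]
        Map.put(l, {i, j}, value)
      end)
    end)
  end

  # det(A) = det(L) * det(L^T) = (product of the diagonal of L)^2
  def det(a) do
    l = cholesky(a)
    d = Enum.reduce(0..(length(a) - 1), 1.0, fn i, acc -> acc * l[{i, i}] end)
    d * d
  end

  # log det(A) = 2 * sum(log L[i][i]), which avoids overflow/underflow
  def log_det(a) do
    l = cholesky(a)
    2 * Enum.reduce(0..(length(a) - 1), 0.0, fn i, acc -> acc + :math.log(l[{i, i}]) end)
  end
end
```

For example, for A = [[4.0, 2.0], [2.0, 3.0]] the factor is L = [[2, 0], [1, √2]], so det(A) = (2√2)² = 8, matching 4 * 3 - 2 * 2.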

polvalente commented 2 years ago

It is low hanging indeed. I think it could even be implemented from Nx directly instead of needing backend-specific functions (I'm not sure about this though).

@josevalim WDYT about adding this?

turion commented 2 years ago

Quoting @polvalente: "I think it could even be implemented from Nx directly"

Not sure about that. In some special cases you can use an existing decomposition (e.g. positive semidefinite -> Cholesky), but I don't think any of the listed decompositions gives the determinant in every case. Calculating it via the Leibniz formula in Elixir would be prohibitively slow, I believe.
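For context, the Leibniz formula sums over all n! permutations σ of {1, ..., n}, each term a signed product of n entries, so its cost grows roughly like n · n! (already 3,628,800 terms for n = 10), versus O(n³) for an approach based on LU factorization:

```latex
\det(A) = \sum_{\sigma \in S_n} \operatorname{sgn}(\sigma) \prod_{i=1}^{n} a_{i,\sigma(i)}
```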

polvalente commented 2 years ago

iex(1)> t = Nx.tensor([[6, 1, 1], [4, -2, 5], [2, 8, 7]])                                         
#Nx.Tensor<
  s64[3][3]
  [
    [6, 1, 1],
    [4, -2, 5],
    [2, 8, 7]
  ]
>
iex(2)> expected_det = 6 * (-2) * 7 + 1 * 5 * 2 + 1 * 4 * 8 - 1 * 2 * (-2) - 1 * 4 * 7 - 6 * 5 * 8
-306
iex(3)> {l, u} = Nx.LinAlg.lu(t)                                                                  
** (MatchError) no match of right hand side value: {#Nx.Tensor<
   s64[3][3]
   [
     [1, 0, 0],
     [0, 0, 1],
     [0, 1, 0]
   ]
>, #Nx.Tensor<
   f32[3][3]
   [
     [1.0, 0.0, 0.0],
     [0.3333333432674408, 1.0, 0.0],
     [0.6666666865348816, -0.3478260934352875, 1.0]
   ]
>, #Nx.Tensor<
   f32[3][3]
   [
     [6.0, 1.0, 1.0],
     [0.0, 7.666666507720947, 6.666666507720947],
     [0.0, 0.0, 6.65217399597168]
   ]
>}
iex(3)> {p, l, u} = Nx.LinAlg.lu(t)
{#Nx.Tensor<
   s64[3][3]
   [
     [1, 0, 0],
     [0, 0, 1],
     [0, 1, 0]
   ]
>, #Nx.Tensor<
   f32[3][3]
   [
     [1.0, 0.0, 0.0],
     [0.3333333432674408, 1.0, 0.0],
     [0.6666666865348816, -0.3478260934352875, 1.0]
   ]
>, #Nx.Tensor<
   f32[3][3]
   [
     [6.0, 1.0, 1.0], 
     [0.0, 7.666666507720947, 6.666666507720947],
     [0.0, 0.0, 6.65217399597168]
   ]
>}
iex(4)> t
#Nx.Tensor<
  s64[3][3]
  [
    [6, 1, 1],
    [4, -2, 5],
    [2, 8, 7]
  ]
>
iex(5)> diag1 = Nx.select(Nx.eye(t), l, Nx.broadcast(0, t))
#Nx.Tensor<
  f32[3][3]
  [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0]
  ]
>
iex(6)> diag2 = Nx.select(Nx.eye(t), u, Nx.broadcast(0, t))
#Nx.Tensor<
  f32[3][3]
  [
    [6.0, 0.0, 0.0],
    [0.0, 7.666666507720947, 0.0],
    [0.0, 0.0, 6.65217399597168]
  ]
>
iex(7)> Nx.product(diag1)
#Nx.Tensor<
  f32
  0.0
>
iex(8)> Nx.sum(diag1, axes: [0])
#Nx.Tensor<
  f32[3]
  [1.0, 1.0, 1.0]
>
iex(9)> Nx.sum(diag1, axes: [0]) |> Nx.product()
#Nx.Tensor<
  f32
  1.0
>
iex(10)> Nx.sum(diag1, axes: [0]) |> Nx.product() |> Nx.multiply(Nx.sum(diag2, axes: [0]) |> Nx.product())
#Nx.Tensor<
  f32
  306.0
>

Since the determinant of a product is the product of the determinants, we can almost use LU to calculate it. This also uses the fact that the determinant of a triangular matrix is the product of its diagonal elements. Note that the sign is wrong because of the permutation matrix. Perhaps there's an easier way to determine the sign of the permutation matrix using another process (perhaps a dot product or something could be useful here).

edit: This is slower than calculating it in a backend-specific operation, but it has the advantage of having the gradient ready for us. Anyway, even if we decide to go the backend-specific route, this implementation can be useful for the binary backend.

edit2: From this reference, the determinant of a permutation matrix is (-1)**num_permutations, where num_permutations is the number of row swaps (transpositions) that p encodes, so if we can find a way to count them in p, we get the missing factor.
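The same idea can be sketched without an explicit permutation matrix: Gaussian elimination with partial pivoting produces U's diagonal as the pivots, and flipping the sign on every row swap supplies det(P) = (-1)^swaps directly. This is a plain-list illustration, not the Nx implementation; `LUDet` is a hypothetical module name and it needs Elixir 1.12+ for `Enum.zip_with/3`.

```elixir
defmodule LUDet do
  # Hypothetical module, plain nested lists. det(A) = det(P) * det(L) * det(U),
  # with det(L) = 1 (unit diagonal), det(U) = product of the pivots and
  # det(P) = (-1)^row_swaps.
  def det(a) do
    rows = Enum.map(a, fn row -> Enum.map(row, &(&1 * 1.0)) end)
    eliminate(rows, 0, length(rows), 1.0)
  end

  defp eliminate(_rows, k, n, acc) when k == n, do: acc

  defp eliminate(rows, k, n, acc) do
    # partial pivoting: pick the row in k..n-1 with the largest |a[i][k]|
    p = Enum.max_by(k..(n - 1), fn i -> abs(rows |> Enum.at(i) |> Enum.at(k)) end)

    # every row swap flips the sign of the determinant
    {rows, acc} = if p == k, do: {rows, acc}, else: {swap(rows, k, p), -acc}
    pivot_row = Enum.at(rows, k)
    pivot = Enum.at(pivot_row, k)

    if pivot == 0.0 do
      # the whole remaining column is zero, so the matrix is singular
      0.0
    else
      # eliminate the entries below the pivot
      rows =
        rows
        |> Enum.with_index()
        |> Enum.map(fn
          {row, i} when i > k ->
            m = Enum.at(row, k) / pivot
            Enum.zip_with(row, pivot_row, fn x, y -> x - m * y end)

          {row, _i} ->
            row
        end)

      eliminate(rows, k + 1, n, acc * pivot)
    end
  end

  defp swap(rows, i, j) do
    ri = Enum.at(rows, i)
    rj = Enum.at(rows, j)
    rows |> List.replace_at(i, rj) |> List.replace_at(j, ri)
  end
end
```

On the matrix from the session above, `LUDet.det([[6, 1, 1], [4, -2, 5], [2, 8, 7]])` makes one row swap and multiplies the pivots 6, 7.666..., 6.652..., giving approximately -306.0, the expected determinant with the correct sign.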

josevalim commented 2 years ago

The way I typically answer these questions is: how does JAX implement this functionality? It most likely implements it on top of existing operations, so we should just follow along.

We can allow backends to optimize it later (it is in our backlog to allow "backend overrides").

msluszniak commented 2 years ago

@polvalente I have a solution to your problem. To calculate (-1)**num_permutations you don't need the number of permutations itself, only its parity. This is directly related to the number of inversions in the sequence: every time you swap two elements, the parity of the number of inversions changes. Thus you need to count the inversions. I think the easiest way to calculate (-1)**num_permutations is:

An efficient way to count the inversions is an algorithm similar to merge sort, O(N log N), where N is the number of rows in the matrix. I have an implementation of the algorithm in C++ that I can send you (it's very useful in competitive programming).
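An Elixir sketch of the merge-sort inversion count described above (not the C++ implementation mentioned). While merging, taking an element from the right half before the remaining left-half elements creates one inversion with each of them. `Inversions` is a hypothetical module name; `sign/1` expects a permutation as a list of 0-based indices, for example read off a permutation matrix row by row.

```elixir
defmodule Inversions do
  # Hypothetical module. count/1 returns {sorted_list, inversion_count}.
  def count([]), do: {[], 0}
  def count([x]), do: {[x], 0}

  def count(list) do
    {left, right} = Enum.split(list, div(length(list), 2))
    {left, a} = count(left)
    {right, b} = count(right)
    {merged, c} = merge(left, right, length(left), [], 0)
    {merged, a + b + c}
  end

  # left_len tracks how many elements remain in `left`.
  defp merge([], right, _left_len, acc, inv), do: {Enum.reverse(acc, right), inv}
  defp merge(left, [], _left_len, acc, inv), do: {Enum.reverse(acc, left), inv}

  defp merge([l | lt] = left, [r | rt] = right, left_len, acc, inv) do
    if l <= r do
      merge(lt, right, left_len - 1, [l | acc], inv)
    else
      # r jumps ahead of every remaining left element: left_len inversions
      merge(left, rt, left_len, [r | acc], inv + left_len)
    end
  end

  # Sign of a permutation (and determinant of its permutation matrix).
  def sign(perm) do
    {_sorted, inv} = count(perm)
    if rem(inv, 2) == 0, do: 1, else: -1
  end
end
```

For the p from the session above, `Enum.map(p_rows, fn row -> Enum.find_index(row, &(&1 == 1)) end)` with `p_rows = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]` gives `[0, 2, 1]`, which has one inversion, so `Inversions.sign/1` returns -1: the missing factor that turns 306 into -306.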

polvalente commented 2 years ago

@msluszniak thanks for the comment! I ended up finding exactly this solution in an article and then in the JAX code, and added it in #597 and #603