apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

supporting matrix inversion and determinant #14360

Open ketranm opened 5 years ago

ketranm commented 5 years ago

Hi there, I was wondering whether there is any plan to support matrix inversion and computing the determinant of a matrix in MXNet. These functions are supported in PyTorch (e.g. torch.inverse, torch.det, torch.logdet) through MAGMA.

Thanks!

mxnet-label-bot commented 5 years ago

Hey, this is the MXNet Label Bot. Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it. Here are my recommended labels: Feature

stephenrawls commented 5 years ago

+1

For example, these would be useful for these types of models: http://www.aclweb.org/anthology/D07-1015 https://arxiv.org/abs/1702.00887

stephenrawls commented 5 years ago

@mxnet-label-bot Feature

zboldyga commented 5 years ago

@ketranm

Looks like inversion via Cholesky factorization is supported, and there's also an API for computing the inverse from that factorization:

https://mxnet.apache.org/api/python/ndarray/linalg.html#linear-algebra

- potrf - get the Cholesky factorization (triangular matrix)
- potri - calculate the inverse (edit: using the Cholesky factorization from potrf)
- sumlogdiag - may be useful for calculating the log determinant (my linear algebra is a little rusty)

There isn't a shortcut for getting the determinant or log determinant, but both are simple to compute from the Cholesky factorization.
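For illustration, here's a minimal sketch of how those operators compose for a symmetric positive-definite input (the matrix values are made up, and this is just how I'd expect the ops to fit together, not a tested recipe):

```python
# Minimal sketch: inverse and log-determinant of a symmetric
# positive-definite matrix via the existing linalg operators.
# Matrix values are illustrative only.
import mxnet as mx

A = mx.nd.array([[4.0, 1.0],
                 [1.0, 3.0]])            # symmetric positive-definite

L = mx.nd.linalg.potrf(A)                # Cholesky factor: A = L * L^T
A_inv = mx.nd.linalg.potri(L)            # inverse of A from its Cholesky factor
logdet = 2 * mx.nd.linalg.sumlogdiag(L)  # log det(A) = 2 * sum(log(diag(L)))

print(A_inv.asnumpy(), logdet.asnumpy())
```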

It seems to me that all of this should be clarified in the documentation at a minimum, and we should probably add API calls for det and logdet. I've also made a request to have a single 'inverse' operation as in Torch. I opened a JIRA ticket and will start implementing these as soon as someone more internal to the MXNet project signs off!

arcadiaphy commented 5 years ago

+1

I'm implementing thin plate spline interpolation, which calls for an easy-to-use matrix_inverse. Using potrf and potri is too awkward.

arcadiaphy commented 5 years ago

@zboldyga How is your work on the matrix operations going? I've looked at the PyTorch source: its matrix inversion is implemented using LU factorization (GETRI in LAPACK). I'm not sure whether that's more efficient than Cholesky factorization. Also, do you plan to introduce MAGMA to MXNet?

zboldyga commented 5 years ago

@arcadiaphy I got distracted for the past few months! I just picked this up today. I looked over the changes and will start coding tomorrow; I should have a PR open within a few days.

I immediately realized that the potrf/potri approach (which relies on LAPACK's Cholesky factorization) only works for symmetric positive-definite matrices. For matrices that meet that criterion, it's (probably) the fastest way to get the inverse.

That said, I'm guessing most cases where the inverse is needed will violate this requirement, so we'll need to create an inverse function based on the LU factorization, which LAPACK also supports (this should be easy). LU factorization is a good, relatively quick approach.

Since the Cholesky operators were already added to the library, I'm going to assume they are needed somewhere. I'll leave those as is.

I'll approach the inverse, determinant, and logdet based on alternate LAPACK routines.
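For reference, here's a minimal NumPy/SciPy sketch (not the MXNet implementation) of the LU-based route to the inverse, determinant, and log|det|; the matrix values are made up:

```python
# LU-based inverse, determinant, and log|det| via LAPACK's GETRF
# (exposed through scipy.linalg). Illustrative only.
import numpy as np
from scipy.linalg import lu_factor, lu_solve

A = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])

lu, piv = lu_factor(A)                   # P*A = L*U  (GETRF)
A_inv = lu_solve((lu, piv), np.eye(3))   # solve A*X = I for the inverse

swaps = np.sum(piv != np.arange(3))      # number of row interchanges
sign = (-1.0) ** swaps * np.prod(np.sign(np.diag(lu)))
logabsdet = np.sum(np.log(np.abs(np.diag(lu))))
det = sign * np.exp(logabsdet)

print(A_inv, det, logabsdet)
```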

As for MAGMA, I don't think this thread is the appropriate place to introduce it. That may be a good candidate for discussion on the V1.6 release roadmap (here's V1.5: https://github.com/apache/incubator-mxnet/issues/14619; as of writing, V1.5 is closed and V1.6 will probably open soon).

arcadiaphy commented 5 years ago

@zboldyga I've worked on it myself recently; you can have a look at my PRs: #14963 and #15007

zboldyga commented 5 years ago

@arcadiaphy OK great, thanks for letting me know!

Do you think the low-level implementations are going to be merged, or do the ops need to be written at the Python level?

If your approach is going to be merged, do you need a hand with the last steps: pinverse or svd?

arcadiaphy commented 5 years ago

@zboldyga I think the high-level ops need to be implemented in C++ so the backward pass can be handled properly; meanwhile, we can expose enough low-level ops for users.
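For what it's worth, here's a small NumPy check (not MXNet code) of the backward rule such a C++ op would encode for logdet, i.e. d logdet(A)/dA = inv(A)^T, verified against a finite difference; the matrix is made up:

```python
# Backward rule for logdet: d logdet(A) / dA = inv(A)^T,
# checked against a central finite difference. Illustrative only.
import numpy as np

A = np.array([[3.0, 1.0],
              [0.5, 2.0]])
grad_analytic = np.linalg.inv(A).T

eps = 1e-6
grad_fd = np.zeros_like(A)
for i in range(2):
    for j in range(2):
        E = np.zeros_like(A)
        E[i, j] = eps
        grad_fd[i, j] = (np.log(np.linalg.det(A + E)) -
                         np.log(np.linalg.det(A - E))) / (2 * eps)

print(np.allclose(grad_analytic, grad_fd, atol=1e-4))  # expect True
```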

I'll try to get the determinant PR merged. Any help with pinverse and svd is welcome; I'm really busy these days.

zboldyga commented 5 years ago

@arcadiaphy OK great. Moving this discussion about svd and pinverse over to issue #14962