pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io
Other
325 stars 97 forks

Add missing numba ops #161

Open twiecki opened 1 year ago

twiecki commented 1 year ago

Description

We're fairly close to having full Numba support, but a few important Numba op implementations are still missing. This is an incomplete list; we should complete it and then make a push to add the missing ops.

Here is a tutorial on how to add them: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html

lucianopaz commented 1 year ago

In Aesara this was being tracked in this milestone. The particular COps were being tracked in this issue. There's also the very important fact that we need to use Numba's NumPy-compatible random Generator objects.

twiecki commented 1 year ago

I think replacing C is too lofty of a goal for now. Instead, we should focus on making all of PyMC work with numba.

fonnesbeck commented 1 year ago

erfcx is also missing

ricardoV94 commented 1 year ago

> erfcx is also missing

Doesn't it just default to the Python implementation?

ricardoV94 commented 1 year ago

> LogSumExp (see "Add numba implementations for LogSumExp (reference implementation exists)", aesara-devs/aesara#404, for a code example)

We didn't have a LogSumExp Op last time I checked. It's built from other Ops, so we shouldn't need to implement it.

jessegrabowski commented 1 year ago

I want to shill the work I did on numba links for several linear algebra operations again:

https://github.com/numba/numba-scipy/commit/462191b4f745ed260056c534e28e8e0ba1a743a5

A core problem is that the numba.np.linalg module simply doesn't have links to several LAPACK functions we care about, including (but not limited to) SolveTriangular. There's also zero support for scipy.linalg functions. I show in that code it's not a big deal to write the hooks, but it's also not clear (to me anyway) where they should go in the code base. Numba-scipy is the most natural place, but I don't know if it's being actively developed beyond the scipy.experimental module.

twiecki commented 1 year ago

@jessegrabowski Looks interesting! You're just linking to a commit. Is that in main, or is it part of an outstanding PR? How do you propose we integrate this?

jessegrabowski commented 1 year ago

It's an outstanding PR that didn't garner any attention.

I have no idea what the best way to integrate it is. I use this code in one of my own projects, where I just shoved it into a sub-module and added an entry point for the numba overloads to my setup.py. That worked fine and lets me use the overloaded functions under @njit decorators. I don't think it's a very "principled" approach, though.

aseyboldt commented 1 year ago

@jessegrabowski I think we could just put that code in link/numba/dispatch/linalg.py. We don't need to override the scipy functions (that would change the behavior of completely unrelated code, just because someone imports pytensor); we only need to provide our own linalg functions for the ops.
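A minimal sketch of the "provide our own functions per op" idea, using only `functools.singledispatch` and NumPy. The `Op`, `SolveTriangular`, and `funcify` names here are hypothetical stand-ins for illustration, not PyTensor's actual classes or API, and the NumPy solver stands in for a real LAPACK triangular routine. The point is that the implementation is looked up by op type, so nothing in scipy is overridden or monkeypatched.

```python
from functools import singledispatch

import numpy as np


class Op:
    """Hypothetical stand-in for a graph operation."""


class SolveTriangular(Op):
    """Hypothetical op: solve a @ x = b where a is triangular."""

    def __init__(self, lower=True):
        self.lower = lower


@singledispatch
def funcify(op):
    # Fallback for ops that have no registered implementation.
    raise NotImplementedError(f"No implementation for {type(op).__name__}")


@funcify.register(SolveTriangular)
def funcify_solve_triangular(op):
    lower = op.lower

    def solve_triangular(a, b):
        # Keep only the relevant triangle, then use a general solver.
        # A real backend would call a LAPACK triangular routine instead.
        tri = np.tril(a) if lower else np.triu(a)
        return np.linalg.solve(tri, b)

    return solve_triangular


a = np.array([[2.0, 0.0], [1.0, 3.0]])
b = np.array([4.0, 11.0])
solve = funcify(SolveTriangular(lower=True))
x = solve(a, b)
```

Code that imports scipy elsewhere is unaffected, because the only thing registered is a mapping from the op type to a function.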

aseyboldt commented 1 year ago

erfcx works for me, that was added here: #46

aseyboldt commented 1 year ago

logsumexp is currently built from other parts: [image] Maybe we could improve its performance further, but for now I think this is fine.
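For readers unfamiliar with how logsumexp decomposes into simpler operations, here is a NumPy sketch of one common construction (max, subtract, exp, sum, log, add). It is an illustration under that assumption, not necessarily the exact graph PyTensor builds.

```python
import numpy as np


def logsumexp(x, axis=None):
    # Subtract the max before exponentiating so exp() cannot overflow.
    x = np.asarray(x, dtype=float)
    x_max = np.max(x, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
    return np.squeeze(out, axis=axis)


values = np.array([1000.0, 1000.0])

# Naive version: exp(1000) overflows to inf in float64.
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(values)))

# Composed version stays finite: 1000 + log(2).
stable = logsumexp(values)
```

Because every step is an elementwise op or a reduction, a backend that already supports those pieces gets logsumexp without a dedicated implementation.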

mtsokol commented 1 year ago

Hi @twiecki! When it comes to Det/Logdet from the list, it looks like Det is already available. Is that right? https://github.com/pymc-devs/pytensor/blob/38dc6c9f60c45bf8d00b7201bc64139ea88c0132/pytensor/link/numba/dispatch/nlinalg.py#L48

By Logdet do you mean slogdet?
https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html#numpy.linalg.slogdet

[EDIT] Opened a WIP PR for the latter: https://github.com/pymc-devs/pytensor/pull/172
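To illustrate why a separate slogdet matters when Det already exists, the following NumPy sketch compares taking the log of `numpy.linalg.det` with `numpy.linalg.slogdet` on a matrix whose determinant overflows float64. The matrix is an arbitrary example chosen for this illustration.

```python
import numpy as np

# Determinant is 10**500, which is far beyond the float64 range.
a = np.eye(500) * 10.0

# log(det(a)) overflows inside det, so the result is inf.
with np.errstate(over="ignore"):
    naive = np.log(np.linalg.det(a))

# slogdet returns the sign and log|det| separately and stays finite.
sign, logabsdet = np.linalg.slogdet(a)
```

Here `logabsdet` equals `500 * log(10)`, while the naive route loses the value entirely.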

twiecki commented 1 year ago

@mtsokol That looks right -- thanks for opening a PR!

ricardoV94 commented 8 months ago

Some cases of AdvancedSubtensor can be supported by clever reshaping, and indexing based on strides.

Snippet we were using some time ago, pasted here completely out of context:

x, y = design_matrix.nonzero()
*s1, s2, s3 = result.shape
return result.reshape(*s1, s2*s3)[..., x*s3 + y]
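A small NumPy check of the reshaping trick in the snippet above, with made-up shapes and a made-up `design_matrix`. It assumes `result` is C-contiguous and that `design_matrix` has the same shape as the last two dimensions of `result`. Advanced indexing on the last two axes is replaced by a single index into the flattened axis, using the row-major offset `x * s3 + y`.

```python
import numpy as np

rng = np.random.default_rng(0)
result = rng.normal(size=(4, 3, 5))

# Hypothetical design matrix matching the last two dims of `result`.
design_matrix = np.array(
    [
        [1, 0, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 1],
    ]
)

x, y = design_matrix.nonzero()

# Two integer index arrays on the last two axes (advanced indexing).
direct = result[..., x, y]

# Same selection via one flat index on the merged last axis.
*s1, s2, s3 = result.shape
flat = result.reshape(*s1, s2 * s3)[..., x * s3 + y]
```

Both expressions select the same elements, but the second needs only a reshape and a single integer-array index on one axis.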