scientific-python / summit-2024


PyTorch as the core compute layer #18

Open · thomasjpfan opened 3 months ago

thomasjpfan commented 3 months ago

During Array API adoption and from conversations with SciPy devs, I've seen attempts to dispatch to underlying libraries for compute. For example: https://github.com/scipy/scipy/pull/20772. The motivation for this special casing is that the Array API standard does not contain all the required APIs. Practically, I think the standard will always lag behind what libraries offer, and there will be some APIs that may never be standardized.

An alternative proposal is to use dlpack to do a zero copy transfer to PyTorch, use the PyTorch API for compute, and use dlpack to transfer back to the original array container. Here are the pros and cons (a minimal sketch of the round-trip follows the list):

Pros

  1. We can use the full PyTorch API without the limitations of the Array API Standard.
  2. We get all the advantages of PyTorch, such as torch.compile or torch.export in the future.

Cons

  1. PyTorch feels more corporate than NumPy.
  2. The Array API standard covers JAX, Dask, and all future array libraries that adopt the standard. Going with PyTorch as the core compute layer would reduce coverage.
  3. The PyTorch CPU wheel is ~183 MB, which is much bigger than NumPy's.
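
Concretely, a minimal sketch of the round-trip (`logsumexp` is just an arbitrary example, and `array_namespace` comes from array-api-compat; none of this is a settled design):

```python
import torch
from array_api_compat import array_namespace

def logsumexp_via_torch(x):
    # Remember which array library the caller handed us.
    xp = array_namespace(x)
    # Zero-copy view into PyTorch via DLPack; the data stays on its device.
    x_t = torch.from_dlpack(x)
    # Compute with the full PyTorch API, unconstrained by the standard.
    out = torch.logsumexp(x_t, dim=-1)
    # Zero-copy transfer back to the caller's array type.
    return xp.from_dlpack(out)
```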

Currently, I am -0 on such a move.

mdhaber commented 3 months ago

IIUC, another motivation for the suggestion is that the transfer could be zero-copy regardless of which device the data is on, whereas using NumPy requires transfer to CPU and back when the original data is elsewhere?

In some cases, this would work and be useful. I'm not sure if there would be an advantage in the special equivalent of that PR, though. There, we would still want to fall back to NumPy/scipy.special, since PyTorch doesn't have nearly the same coverage of special functions at the moment, right? (It's looking like there will soon be a separate special function library that can be used on CPU or GPU, though, so this is changing.)
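
For illustration, the fallback I mean might look like this (assuming, as I believe, that torch currently lacks an equivalent of `scipy.special.ellipk`):

```python
import scipy.special
import torch

def ellipk(x: torch.Tensor) -> torch.Tensor:
    # No torch-native implementation to dispatch to, so round-trip
    # through NumPy: detach, move to CPU, call scipy.special, and
    # move the result back to the input's device.
    out = scipy.special.ellipk(x.detach().cpu().numpy())
    return torch.from_numpy(out).to(x.device)
```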

But maybe I don't understand the proposal - is it broader than cases like scipy/scipy#20772, which mostly tries to dispatch to existing functions in special libraries if they have a direct equivalent? What about cases like adding array API support to scipy.stats? There, we are not just dispatching calls of entire scipy.stats functions to other libraries; rather, we're replacing NumPy API calls within scipy.stats functions with array API standard calls. Would this proposal suggest replacing NumPy API calls with PyTorch calls, instead?

ev-br commented 3 months ago

ISTM a simple version would be to build a torch_scipy (a made-up name) collection of routines with the SciPy API and torch-native implementations. Then SciPy at least can dispatch to it. That would address the "more corporate" concern, too. Bits and pieces exist in torch itself and elsewhere, e.g. https://github.com/lezcano/expm/tree/master/pytorch_expm and https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py
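
A sketch of what I mean, using the matrix exponential (`torch_scipy` is the made-up name from above; `torch.matrix_exp` is real, but the dispatch hook is hypothetical):

```python
import torch

# torch_scipy/linalg.py (made-up package): SciPy's API, torch-native body.
def expm(a: torch.Tensor) -> torch.Tensor:
    """Matrix exponential, mirroring scipy.linalg.expm's signature."""
    return torch.matrix_exp(a)

# SciPy side: hand torch tensors to the torch-native routine
# (this dispatch hook is hypothetical, not an existing SciPy mechanism).
def expm_dispatch(a):
    if isinstance(a, torch.Tensor):
        return expm(a)
    import scipy.linalg
    return scipy.linalg.expm(a)  # existing NumPy/compiled path
```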

steppi commented 3 months ago

> ISTM a simple version would be to build a torch_scipy (a made-up name) collection of routines with the SciPy API and torch-native implementations. Then SciPy at least can dispatch to it. That would address the "more corporate" concern, too. Bits and pieces exist in torch itself and elsewhere, e.g. https://github.com/lezcano/expm/tree/master/pytorch_expm and https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py

This is what I was thinking too, and how I'd envisioned doing things for special.

lucascolley commented 3 months ago

> During Array API adoption and from conversations with SciPy devs, I've seen attempts to dispatch to underlying libraries for compute. For example: https://github.com/scipy/scipy/pull/20772. The motivation for this special casing is that the Array API standard does not contain all the required APIs. Practically, I think the standard will always lag behind what libraries offer, and there will be some APIs that may never be standardized.

To come at this from a slightly different angle, if we think about the scope of the array libraries, guided by what is included in the array API standard right now, it is likely the case that some APIs will not be standardised, and deliberately so. For example, many SciPy functions, like those in scipy.stats and scipy.signal, sit firmly in the 'array-consumer' territory, rather than the array library territory (they are certainly more suited to standard extensions, rather than the core API, but even the scope of extensions doesn't extend to the whole of SciPy IMO).

The fact that modules like cupyx.scipy and torch.signal exist seems to indicate the usefulness of, and need for, these functions in those ecosystems, rather than the view that they belong at the array-library level. As Evgeni says, a separate package which consumes the array library of choice (whether that be PyTorch, or something else) seems to make more sense to dispatch to from SciPy, rather than trying to get every compiled-code part of SciPy merged into (an) alternative array librar(y/ies).


In any case, I think that, where possible, the focus should first be on consuming the standard API, and we should only extend support / switch from NumPy to another library when compiled code forces our hand and the potential gain is judged worth it.

> But maybe I don't understand the proposal - is it broader than cases like https://github.com/scipy/scipy/pull/20772, which mostly tries to dispatch to existing functions in special libraries if they have a direct equivalent? What about cases like https://github.com/scipy/scipy/issues/20544? There, we are not just dispatching calls of entire scipy.stats functions to other libraries; rather, we're replacing NumPy API calls within scipy.stats functions with array API standard calls. Would this proposal suggest replacing NumPy API calls with PyTorch calls, instead?

If so, that seems pretty drastic, but maybe that is just a knee-jerk reaction from me. There does exist sentiment among some NumPy users that NumPy is fine, and they do not want to invest any time in learning to use another array library (or indeed, in learning to develop for another array library, cf. https://github.com/scipy/scipy/issues/18286#issuecomment-2027886737). Even if the SciPy API remained the same but PyTorch became a required runtime dependency, I think we would get some negative backlash. There is also the concern above that the "PyTorch CPU wheel is ~183 MB, which is much bigger than NumPy's".


> The Array API standard covers JAX, Dask, and all future array libraries that adopt the standard. Going with PyTorch as the core compute layer would reduce coverage.

This does sound like we are discussing moving away from the standard altogether, which I don't think makes sense. I think the wider coverage is the defining point of the push: while PyTorch may provide a lot of what NumPy users wish for right now (GPU support, plus the other nice things you have already mentioned), it may not in the future (as wishes develop/change). The "all future array libraries that adopt the standard" part is really the star of the show IMO.

As above, moving from NumPy to PyTorch would be a separate discussion (default fallback implementation vs. dispatching to known xp-native-modules), but from my perspective that seems like a bridge to cross at rather a later point.

betatim commented 3 months ago

> An alternative proposal is to use dlpack to do a zero copy transfer to PyTorch, use the PyTorch API for compute, and use dlpack to transfer back to the original array container.

I should know but I don't: could you use this to solve the problem of "I want my (array consuming) library code to work with user inputs that are cupy, numpy, pytorch, etc. without having to write code that contains lots of `if is_cupy(x): ... elif is_dask(x): ... elif is_numpy(x): ...`"?

I think it is possible. You'd have a call at the start of your library function that performs the dlpack-based move, then exclusively use `torch.foo(x_torch)` calls in your function, and at the end move back to the original input array type via dlpack. Is that about right?
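
Something like this, I imagine (`array_namespace` is array-api-compat's helper; `vector_norm` is just an arbitrary example):

```python
import torch
from array_api_compat import array_namespace

# Instead of one branch per library, e.g.
#     if is_cupy(x): ...
#     elif is_dask(x): ...
#     elif is_numpy(x): ...
# the round-trip collapses everything onto a single torch code path:

def vector_norm(x):
    xp = array_namespace(x)              # remember the caller's library
    x_t = torch.from_dlpack(x)           # move in via DLPack (zero-copy)
    out = torch.linalg.vector_norm(x_t)  # all compute through torch.*
    return xp.from_dlpack(out)           # move back to the input's type
```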

lucascolley commented 3 months ago

> could you use this to solve the problem of "I want my (array consuming) library code to work with user inputs that are cupy, numpy, pytorch, etc. without having to write code that contains lots of `if is_cupy(x): ... elif is_dask(x): ... elif is_numpy(x): ...`"?

I think so, yeah. At the cost of PyTorch becoming a required runtime dependency, contributors having to use the PyTorch API, and becoming reliant on PyTorch to implement support for new devices etc.

The problem it doesn't solve is "I want to keep everything native to my array library of choice (because of some unique feature), wherever only basic/fundamental functions are needed". I think the standard is needed for that.