DiffeoInvariant opened this issue 3 years ago
My concern with option 1 is that it doesn't (on its own) allow comparing JAX backends to those we've already written, and that the JIT might be very slow (based on the times you've reported in CRIKit experiments).
Regarding the device array handoff, there are a few places in the source that reference https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html (`git grep cuda_array_interface`). My understanding is that nothing comparable exists for JAX.
So the reason the CRIKit compilations are slow is that the compilation has to unroll a large Python loop. No such issue should be present in libCEED if we write a pure JAX backend, especially since most of the operations can be expressed as `jnp.einsum` calls.
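To make the `jnp.einsum` claim concrete, here is a minimal sketch of the kind of batched tensor contraction libCEED performs (applying a basis interpolation matrix across all elements at once). The shapes and index labels are made up for illustration, and `np.einsum` is used as a runnable stand-in for `jnp.einsum`, which has the same semantics:

```python
import numpy as np  # np.einsum mirrors jnp.einsum's semantics

# Hypothetical shapes: a basis interpolation matrix B (Q quadrature
# points x P nodes) applied to element-local data u (nelem x P x ncomp).
nelem, P, Q, ncomp = 10, 3, 4, 2
B = np.ones((Q, P))
u = np.ones((nelem, P, ncomp))

# One einsum applies B over every element and component simultaneously,
# with no Python loop for JAX's compiler to unroll.
v = np.einsum('qp,epc->eqc', B, u)
assert v.shape == (nelem, Q, ncomp)
```

Because the whole operation is a single primitive, XLA sees one fused contraction rather than a long unrolled trace.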
JAX does support `__cuda_array_interface__` as of 2021 (https://github.com/google/jax/pull/2133), so if you already have code to deal with that, JAX should be usable. The PR for it was merged on Dec 31, 2020, though, so it probably wasn't there when you last looked.
Ah, so we need to make better use of JAX looping primitives in CRIKit. `CeedElemRestriction` transpose does accumulation, so we'd need a way to avoid conflicting overlapping writes (atomics are faster, though non-deterministic; the libCEED docs classify backends by determinism: https://libceed.readthedocs.io/en/latest/gettingstarted/#backends).
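For context, the accumulation in a transpose restriction is a scatter-add: element-local values are summed into global dofs, and shared dofs receive contributions from multiple elements. In JAX the deterministic way to express this is `jnp.zeros(n).at[indices].add(values)`; the sketch below uses NumPy's equivalent `np.add.at` so it runs without JAX, with made-up indices:

```python
import numpy as np

# E-vector -> L-vector (transpose restriction): accumulate element-local
# values into the global vector. dof 1 is shared by two elements, so its
# two contributions must be summed, not overwritten.
indices = np.array([0, 1, 1, 2])
evalues = np.array([1.0, 2.0, 3.0, 4.0])
lvec = np.zeros(3)
np.add.at(lvec, indices, evalues)  # unbuffered add: handles repeated indices
assert lvec.tolist() == [1.0, 5.0, 4.0]
```

Unlike GPU atomics, the `.at[].add()` form in JAX has a deterministic result regardless of execution order.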
Has this been addressed in some later PR?
The reverse direction (importing arrays into JAX) is yet to come.
As to why we can't use the JAX loop primitives in CRIKit right now: we need to index into a Python list, but the loop counter has to be a JAX tracer object, which doesn't implement the `__index__` method. (The list can't be converted to a single array because it would have to be a ragged array, which JAX doesn't support; allocating enough memory to pad it out and copy the data into one big array would be prohibitively expensive for large grids.) As I hinted in the recent CRIKit MR related to that, perhaps this will become doable with a future JAX update, but I'm honestly not even certain what the right semantics for this operation would be, so I wouldn't expect it soon. I have another idea for how to solve it that's totally unrelated to this, though; we'll see if it works soon enough.
Anyway, I just looked through the JAX PRs, and it doesn't look like the reverse direction has happened yet. I'll look into doing it myself, though; it doesn't look like there's all that much code that would have to change, and I can probably take some hints from the Numba codebase.
Another possibility (?) would be to use DLPack, which JAX supports. Not sure if that will work on GPU, but if it does, it would be easier than using `__cuda_array_interface__`.
Sounds great. PETSc has DLPack support and I think it's at a level that makes sense for libCEED too.
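As a quick illustration of what the DLPack protocol buys us, here is a zero-copy handoff between two array namespaces. NumPy (>= 1.22) is used here because it implements the same `__dlpack__`/`from_dlpack` protocol that `jax.dlpack` speaks:

```python
import numpy as np  # NumPy >= 1.22 implements the DLPack protocol

a = np.arange(4.0)
b = np.from_dlpack(a)          # handoff through __dlpack__, no copy made
assert np.shares_memory(a, b)  # both names view the same buffer
```

The point of a `CeedVector` <-> DLPack bridge would be exactly this: the consumer wraps the producer's buffer instead of copying it, on host or device.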
So where would be the natural place to put the code for this? Would we add a new pair of functions to `ceed.h`, say `int CeedVectorToDLPack(CeedVector, DLManagedTensor **)` and `int CeedVectorFromDLPack(CeedVector *, DLManagedTensor *)`, and implement them for every backend (and thus implement it in Python through the FFI), or something else?
EDIT: after taking a closer look at ceed.h, this makes more sense to me:

```c
#ifdef CEED_DLPACK
#include <dlpack.h>
CEED_EXTERN int CeedVectorTakeFromDLPack(CeedVector vec, DLManagedTensor *dl_tensor,
                                         CeedMemType mem_type);
CEED_EXTERN int CeedVectorToDLPack(CeedVector vec, DLManagedTensor **dl_tensor,
                                   CeedMemType mem_type);
#endif
```

Open to suggestions, though. Development is happening in https://github.com/CEED/libCEED/tree/emily/dlpack for now.
I'm not big on hiding includes behind ifdefs at compile time.
I'm perhaps not understanding, but does this need to be in the C interface? I thought this was for passing data back and forth in Python?
If it has to be in the C interface, perhaps a separate header and file, like `ceed/cuda.h` and `ceed/hip.h`.
Well it's about getting data to and from C and Python (so that you can use JAX qfunctions with any CEED backend). I suppose it doesn't have to be in the C API, but there has to be a function we can call from the cffi in the Python implementation, and it should be the same function regardless of what backend is in use. I think this would be a useful function to have in the main C interface though because it would more easily enable users to use, say, C++ TensorFlow code or any other C++ ML code as a qfunction instead of just JAX.
That makes sense. I'd be in favor of adding `ceed/dlpack.h` and `interface/ceed-dlpack.c` if we add this to the C interface, in anticipation of future flexibility.
Is there currently a good way to determine which backend a `CeedVector` belongs to? (So I can ensure that the memory coming in from DLPack or going to it is in the right place, e.g. that we return an error or otherwise do the right thing if you try to give a CUDA backend a CPU or ROCm array.)
You can give a host array to a GPU backend just fine. CPU backends getting a device array will error.
See the PETSc example for an example of querying the backend preferred memory type and setting that as the handoff memory type.
We make the assumption that users will hand off the correct type of device array for the backend.
So should I assume that the user has downloaded `dlpack.h` and set the environment variable `DLPACK_DIR` (or something like that) for makefile purposes (e.g. like seems to be done with `OCCA_DIR`, `HIP_DIR`, etc.)? If not, where should I download `dlpack.h` to?
By the way, in case you want to see an example of using `ceed/dlpack.h`, check out `tests/t124-vector.c`, which was my best guess as to what I should name the host-memory-only test (where I create one `CeedVector`, set its values, then transfer its data pointer to another `CeedVector` via an intermediate `DLManagedTensor`). Let me know if I should name it something different, though.
Oh, after thinking about it a bit more, we'll also need functions to transform the input/output context data for a QFunction to/from DLPack, right? (So that we can pass the input fields as `DLManagedTensor`s to JAX, then pass the `DLManagedTensor`s representing the output data to the output fields.) If so, where is the data for those fields located, and how would I get it? As far as I can tell, I can use `CeedQFunctionGetFields` to get a pointer to an array of structs containing the name, size, and eval mode of each field, but I don't see how to get the pointers to the actual data for those fields.
Hmm, the PETSc interface to DLPack is pure Cython (without mentioning `dlpack.h`): https://gitlab.com/petsc/petsc/-/commit/6662888c9dbb6e9928870ec82930c544853c0c49
I don't know if it's better to include the header or do it dynamically in Python using cffi/ctypes.
The contexts are plain data (no pointers) of specified size, so they can just be copied. Contexts are technically unnecessary in languages with closures or some other dynamic way to construct functions, since those let a function be parametrized without depending on global variables.
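To illustrate the closure point: in Python the parameter a plain-data context would carry can simply be captured. The qfunction shape here is made up for illustration and is not the libCEED signature:

```python
# Sketch: a closure captures the parameter that a plain-data QFunction
# context would otherwise carry as a bytes blob, so no global variable
# (and no context-copy machinery) is needed.
def make_scaled_mass_qf(scale):
    def qf(weights, u):
        # apply quadrature weights and the captured scale factor
        return [scale * w * x for w, x in zip(weights, u)]
    return qf

qf = make_scaled_mass_qf(2.0)
assert qf([1.0, 0.5], [3.0, 4.0]) == [6.0, 4.0]
```

Each call to `make_scaled_mass_qf` produces an independently parametrized qfunction, which is exactly what a context struct does for C qfunctions.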
I don't think I understand your envisioned outcome.
If we go the Python route, I would actually change `python/ceed_qfunction.py`. I'd first make a subclass of the Python interface's QFunction objects that can use native Python QFunctions, like we did in `rust/libceed/src/qfunction.rs`. From there, you could also make JAX QFunctions in Python and use them with all backends.
If we go the C route, we'd make a brand-new backend that delegates back to the different backends. We'd have something like `/cpu/self/jax/serial`, `/cpu/self/jax/blocked`, `/gpu/cuda/jax`, and `/gpu/hip/jax`. It would end up looking sort of like `backends/memcheck`, though I suppose the GPU variants would be trickier.
I feel like the Python route might be easier (I've wanted to add native Python QFunctions for a while but haven't had the time).
So my envisioned outcome is essentially what you're describing with `python/ceed_qfunction.py`, @jeremylt. However, in order to do that, we'll need some way to get the input data into JAX `DeviceArray` instances, and some way to get the output `DeviceArray`(s) back into libCEED data structures. In order to do that, we need an interface between libCEED and DLPack, and the only way to do that (AFAIK) is either to write C functions for this and call them through CFFI and/or ctypes (which seems most consistent with how the rest of the Python interface is implemented), as is done here: https://github.com/dmlc/dlpack/blob/main/apps/from_numpy/main.py, or to write the implementation in Cython like PETSc does.
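The hard requirement in both routes is viewing memory that C owns without copying it. A minimal runnable sketch of that property, with a ctypes array standing in for a `CeedScalar *` buffer from libCEED (the names are illustrative, not libCEED API):

```python
import ctypes
import numpy as np

# View memory owned by "C" (a ctypes array standing in for a CeedScalar*
# returned by libCEED) as a NumPy array without copying. A DLPack-based
# handoff into a JAX DeviceArray needs this same zero-copy property.
n = 6
cbuf = (ctypes.c_double * n)(*range(n))
view = np.ctypeslib.as_array(cbuf)   # shares cbuf's memory, no copy
view[0] = 99.0                       # write through the view...
assert cbuf[0] == 99.0               # ...is visible on the "C" side
```

DLPack generalizes this pattern across libraries and devices by describing the buffer (pointer, dtype, shape, strides, device) in a common struct instead of relying on NumPy-specific machinery.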
I think writing QFunctions in Python/JAX is what we want. I'd like to preserve the ability to use QFunctions that were written in C from Python; I'm not sure of the value of a straight C interface. Looking at `dlpack.h`, I would not object to a `ceed/dlpack.h` that provides the conversion/accessors. It looks like everyone is doing the equivalent of vendoring `dlpack.h`.
Writing a `/gpu/cuda/jax` backend would have limited end-user value, though it could be an interesting test of JAX fusion/expressiveness.
@jedbrown, just to be entirely unambiguous about what you're thinking, are you thinking of a `ceed/dlpack.h` that looks something like this:

```c
#ifndef _ceed_dlpack_h
#define _ceed_dlpack_h

#include <ceed/ceed.h>
#include <dlpack.h>

CEED_EXTERN int CeedVectorTakeFromDLPack(Ceed ceed, CeedVector vec,
                                         DLManagedTensor *dl_tensor,
                                         CeedCopyMode copy_mode);
CEED_EXTERN int CeedVectorToDLPack(Ceed ceed, CeedVector vec,
                                   CeedMemType dl_mem_type,
                                   DLManagedTensor **dl_tensor);

#endif
```
(perhaps with functions related to a QFunction instead of a Vector), or were you thinking of copy-pasting the contents of `dlpack.h` into `ceed/dlpack.h` as well? And if you're thinking of the former, how should we package `dlpack.h`? (I.e., should users have to download it themselves and set `DLPACK_DIR` when compiling, or something else?)
Yes, the above functions look about right.
I think we should do what others do and "vendor" the header from upstream -- copy it into the libCEED repository. That'll allow us to implement those public interfaces without configuration options. One choice would be to keep the header private (don't install it) and only include it in `interface/ceed-vector-dlpack.c`, requiring callers to include their own `dlpack.h` from upstream (or set flags so `#include <dlpack.h>` finds the header) when they include `ceed/dlpack.h`. Is it possible to do the Python/JAX handoff with `DLManagedTensor` being an opaque pointer?
I'm testing that right now in `python/tests/test_vector.py`, test number 124 (`tests/t124-vector.c` tests the C functions in host-memory mode). I'm running into some issues building the FFI (via `pip install .`), though; for some reason the DLPack types aren't being recognized by CFFI, so I get a whole bunch of errors from the CFFI-generated code, like

```
build/temp.linux-x86_64-3.8/_ceed_cffi.c:1336:47: error: unknown type name 'DLDataType'
 1336 | static void _cffi_checkfld_typedef_DLDataType(DLDataType *p)
```
FWIW, `DLDataType` is defined (now in `ceed/dlpack.h` -- we can of course move that to `interface/ceed-dlpack.c` if we want) as

```c
typedef struct {
  /*!
   * \brief Type code of base types.
   * We keep it uint8_t instead of DLDataTypeCode for minimal memory
   * footprint, but the value should be one of DLDataTypeCode enum values.
   */
  uint8_t code;
  /*! \brief Number of bits, common choices are 8, 16, 32. */
  uint8_t bits;
  /*! \brief Number of lanes in the type, used for vector types. */
  uint16_t lanes;
} DLDataType;
```
and the same happens with the other DLPack types. Any idea what I might be doing wrong? (I should have everything updated in the remote, so if you want to see the errors for yourself, cloning my branch and pip-installing should work.)
EDIT: ah, figured it out; I had to include `ceed/dlpack.h` in `ffibuilder.set_source()`.
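As a sanity check on the `DLDataType` definition quoted earlier in the thread, the same layout can be mirrored in ctypes; its two `uint8` fields plus one `uint16` pack into exactly 4 bytes under natural alignment, which is handy for cross-checking what CFFI generates:

```python
import ctypes

# ctypes mirror of DLPack's DLDataType: {uint8 code; uint8 bits; uint16 lanes}
class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]

assert ctypes.sizeof(DLDataType) == 4  # no padding: 1 + 1 + 2 bytes
assert DLDataType.lanes.offset == 2    # lanes is naturally 2-byte aligned
```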
So it turns out that TensorFlow uses pybind11, so the `DLManagedTensor` returned from `jax.dlpack.to_dlpack()` is really a `PyCapsule` holding the struct. I'm not sure yet how to correctly unpack that without linking Python, but I'm sure there's some way to do it with appropriate `ctypes` usage or something like that.
EDIT: figured this one out; use `ctypes.pythonapi.PyCapsule_GetPointer()` to get the address as an int and `ffi.cast()` to cast it to a pointer of the correct type.
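A self-contained sketch of that capsule round trip, using a plain buffer as a stand-in for the `DLManagedTensor` (the cffi cast at the end is shown only as a comment since it needs the libceed FFI object). Declaring `argtypes`/`restype` first matters: ctypes otherwise assumes `int` returns, which truncates 64-bit pointers:

```python
import ctypes

# Declare signatures for the Python C-API capsule helpers.
capi = ctypes.pythonapi
capi.PyCapsule_New.restype = ctypes.py_object
capi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
capi.PyCapsule_GetPointer.restype = ctypes.c_void_p
capi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]

name = b"dltensor"                     # kept alive: the capsule stores this pointer
buf = ctypes.create_string_buffer(64)  # stand-in for a DLManagedTensor allocation
addr = ctypes.addressof(buf)
capsule = capi.PyCapsule_New(addr, name, None)       # None -> NULL destructor
recovered = capi.PyCapsule_GetPointer(capsule, name)  # address back as an int
assert recovered == addr
# With cffi one would then do: ffi.cast("DLManagedTensor *", recovered)
```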
But I thought you were inside of Python?
Yes, but JAX calls out to XLA, which is a C++ library (and a part of TensorFlow). See https://github.com/google/jax/blob/master/jax/_src/dlpack.py#L43 and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/dlpack.cc#L250
But why can't the conversion to DLPack tensors occur in Python? When a QFunction is created, a user provides a pointer to a function (or the interface takes the user's function and converts it into an appropriate pointer; note that this function can be written in any language) which takes in raw pointers to the context data, input arrays, and output arrays. Why not convert the raw array pointers to DLPack tensors inside of a Python function call?
Oh, I see what you mean; no reason it couldn't happen inside of Python. I just don't see what advantage that would confer over also having the code in C, especially now that I've figured out how to get around the problem I was just describing (`ctypes.pythonapi.PyCapsule_GetPointer()` to get the address as an int and `ffi.cast()` to cast it to a pointer of the correct type).
So I could rewrite the C portion of this code in Python using CFFI and/or ctypes if you have a good reason for it to be there instead of in C, but as far as I can see, literally the only difference would be that the code would then only be accessible through the Python API, as opposed to potentially being available in every libCEED API. Either way, the C code I have works to get a read-only array from JAX. Still working on a read-write-compatible one; for some reason, weird things happen with that (if you pass `True` as the second parameter of `jax.dlpack.to_dlpack()`).
Unrelated to the comments immediately above this one, I do have one question of semantics: if the incoming DLPack `DLManagedTensor` is carrying `float32`s (common with JAX unless you run `jax.config.update('jax_enable_x64', True)` or set `JAX_ENABLE_X64=True` in the environment) or another datatype that is smaller than `CeedScalar` (but `sizeof(CeedScalar) == k * sizeof(the_dl_managed_tensor_datatype)` for some integer `k`), should we cast the elements to `double` or return an error? Right now, I have the function return an error if the incoming datatype doesn't have the same size as `CeedScalar`, but that behavior could of course be changed.
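A minimal sketch of that size check in Python. The helper name and its interface are hypothetical (the real check lives in the C conversion function); it only mirrors the described behavior of erroring instead of silently casting:

```python
import struct

CEED_SCALAR_BITS = 8 * struct.calcsize('d')   # CeedScalar assumed double: 64 bits

# Hypothetical helper: reject any incoming DLPack dtype whose total
# width (bits * lanes) differs from CeedScalar's, rather than casting.
def check_dl_dtype(bits, lanes):
    if bits * lanes != CEED_SCALAR_BITS:
        raise TypeError(f"DLPack dtype is {bits * lanes} bits; "
                        f"CeedScalar is {CEED_SCALAR_BITS} bits")

check_dl_dtype(bits=64, lanes=1)      # float64: accepted
rejected = False
try:
    check_dl_dtype(bits=32, lanes=1)  # float32: rejected, no silent upcast
except TypeError:
    rejected = True
assert rejected
```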
Erroring is the right behavior in case of precision mismatch. See also #778.
As for where the conversion code lives, we just want to keep build-time configuration as simple as possible. Python is fully dynamic, but the cffi code is basically equivalent to vendoring `dlpack.h` in the C code.
Just posting an update here in case anyone has ideas about how to get around the issue I'm working through. The main problem is that `jax.dlpack.from_dlpack()` expects a `PyCapsule` containing the pointer to the `DLManagedTensor` (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/dlpack.cc#L322; despite the fact that the TF devs immediately `static_cast` the capsule to `DLManagedTensor *`, if you try to pass a raw `DLManagedTensor *` to `jax.dlpack.from_dlpack()`, it throws an exception complaining about a type mismatch). The Python function `libceed.Vector.to_dlpack()` looks like this:
```python
def to_dlpack(self, mem_type, return_capsule=False):
    # return a PyCapsule if return_capsule is True
    dl_tensor = ffi.new("DLManagedTensor *")
    ierr = lib.CeedVectorToDLPack(self._ceed._pointer[0],
                                  self._pointer[0], mem_type,
                                  dl_tensor)
    self._ceed._check_error(ierr)
    if return_capsule:
        ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.py_object,
                                                   ctypes.c_char_p]
        ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
        # line below this comment causes segfault --- why??
        dl_tensor = ctypes.pythonapi.PyCapsule_New(dl_tensor, b'dltensor')
    return dl_tensor
```
but it segfaults for some reason I can't yet figure out. An alternative that should "just work" would be to write a C or C++ function that creates the `PyCapsule`, but that would require linking Python (which is apparently not a good idea when you're using CFFI, at least according to the CFFI docs), so if anyone has a better idea of how to get around this problem, please let me know.
@caidao22 Do you have experience/suggestions for this issue connecting vectors exposed via DLPack with JAX?
Shouldn’t PyCapsule_New take three input arguments with the third one being a destructor?
The destructor is optional (see https://docs.python.org/3/c-api/capsule.html), and although omitting it might cause a memory leak, it shouldn't cause `PyCapsule_New` to segfault. I was planning on adding an appropriate destructor once I got the basic handoff working. For what it's worth, if you change `PyCapsule_New.restype` to `ctypes.c_void_p`, it returns a Python `int` holding the returned pointer, which leads to the following exception in `jax.dlpack.from_dlpack()`:
```
E TypeError: dlpack_managed_tensor_to_buffer(): incompatible function arguments. The following argument types are supported:
E     1. (arg0: capsule, arg1: jaxlib.xla_extension.Client) -> StatusOr[xla::PyBuffer::pyobject]
E
E Invoked with: 140329718942960, <jaxlib.xla_extension.Client object at 0x7fa10ee22930>
```
What if you comment out these two lines?

```python
ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.py_object, ctypes.c_char_p]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
```

I would also try passing NULL as the destructor before the destructor is implemented.
If you comment out those two lines, then `PyCapsule_New` throws a `TypeError` saying that it can't convert argument number 1; you have to tell most of the functions in the `ctypes.pythonapi` module what input and output types to use before you call them (the only existing docs on this, as far as I can tell, are the source itself and some Stack Exchange questions). Good point about making the unimplemented dtor explicit, though; I've changed that.
In PETSc I used `PyCapsule_New` through the Cython header `Python.h` and did not need to set the types. Perhaps it is easier to follow: https://gitlab.com/petsc/petsc/-/blob/main/src/binding/petsc4py/src/PETSc/petscvec.pxi#L580
@jedbrown, how do you feel about having a Cython module with a single function that handles the `PyCapsule`s and is called inside of `libceed.Vector.to_dlpack()`? On one hand, that seems like a pretty heavy dependency for one helper function in the Python API, especially given that we'd have to make sure everything plays nicely with CFFI; on the other, I can't think of a better solution.
Agreed that it's a heavy dependency, but let's try that and once it's working, maybe we'll understand the problem well enough we can drop the Cython dependency. And if not, no big deal.
This probably won't matter too much for this Cython code, but once I have it working, am I correct in assuming the next step would be to write a pair of functions like `int CeedQDataToDLPack(CeedInt ndim, CeedInt *shape, const CeedScalar *const in, DLManagedTensor *dl_tensor)` and `int CeedQDataFromDLPack(DLManagedTensor *dl_tensor, CeedScalar **out)`? (Note that `out` is not `CeedScalar *const` because that would require a copy.)
Just leaving an update on my progress here -- getting data from JAX into a `libceed.Vector` works as expected (see tests 124, 127, and 128 in `python/tests/test-1-vector.py`), but the other way around, getting data from a `libceed.Vector` into JAX, does not work (see test 125 in the same file). This looks to be because TensorFlow somehow has a different representation of the `DLManagedTensor` (i.e., struct members have a different `offsetof` in their code as opposed to ours). I can tell because I checked which DLPack version they use, and since it's 020 (and we had 050 in this code previously), I switched our `dlpack.h` to match theirs. Despite passing in an appropriate capsule around a managed DLPack tensor (see `int CeedPrintDLManagedTensor(DLManagedTensor *)`, or `libceed.Vector.print_dlpack()`, a Python wrapper around that C function, which can be used to verify the contents of the `DLManagedTensor` before passing it to JAX/TensorFlow), one test fails with

```
RuntimeError: Invalid argument: Unknown/unsupported DLPack device type 13717520
```

and sometimes instead gives

```
RuntimeError: Invalid argument: Number of dimensions in DLManagedTensor must be nonnegative, got -1276630432
```

which seems to indicate, as I mentioned above, that TensorFlow is reading the wrong fields, in part because I know that every field in the struct I'm passing is initialized, since I'm passing this (output of the above-mentioned printing function):
```c
struct DLManagedTensor {
  DLTensor dl_tensor == {
    void *data == 0x251c080;
    DLContext ctx == {
      DLDeviceType device_type == 1;
      int device_id == 0;
    };
    int ndim == 1;
    DLDataType dtype == {
      uint8_t code == 2;
      uint8_t bits == 64;
      uint16_t lanes == 1;
    };
    int64_t *shape == [10];
    int64_t *strides == NULL;
    uint64_t byte_offset == 0;
  };
  void *manager_ctx == 0x255c500;
  void (*deleter)(struct DLManagedTensor *self) == 0x7f41b399955b;
};
```
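One way to pin down an ABI mismatch like this is to mirror the struct layout in ctypes and compare the Python-side field offsets against `offsetof()` values printed from the C side. The layout below follows the dump above (with the DLPack 0.2 `DLContext` spelled `DLDevice` here, as later DLPack versions name it); it's a debugging sketch, not libCEED code:

```python
import ctypes

class DLDevice(ctypes.Structure):          # named DLContext in DLPack 0.2
    _fields_ = [("device_type", ctypes.c_int),
                ("device_id", ctypes.c_int)]

class DLDataType(ctypes.Structure):
    _fields_ = [("code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]

class DLTensor(ctypes.Structure):
    _fields_ = [("data", ctypes.c_void_p),
                ("device", DLDevice),
                ("ndim", ctypes.c_int),
                ("dtype", DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),
                ("byte_offset", ctypes.c_uint64)]

# Layout facts to compare against C-side offsetof() output:
assert DLTensor.device.offset == ctypes.sizeof(ctypes.c_void_p)
assert DLTensor.dtype.offset == DLTensor.ndim.offset + 4  # int is 4 bytes
```

If the offsets computed this way differ from what the consumer (TensorFlow/XLA) compiled against, you get exactly the kind of garbage `device_type` and `ndim` values shown in the errors above.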
@jedbrown and I have been discussing the possibility of using JAX to write qfunctions, since it supports JIT compilation and automatic differentiation. I see several ways to go about this, and several potential roadblocks, so I'm opening this issue for discussion. First, we need to decide what sort of architecture we want -- here are a few options:
DeviceArray instances. The major advantage of this approach would be that it's at least to some extent backend-independent (perhaps not the getting-data-into-`DeviceArray`s part) and would require writing less Python code (i.e., not having to implement most libCEED functions in Python), but the biggest disadvantage would probably be that it's not necessarily easy to get the data into a JAX array on the device with no copying. The goal would be to avoid having to write C++ code that depends on XLA itself, since such code can really only be compiled in any reasonable manner by Bazel.

If any of the libCEED devs have thoughts on this or are interested in working with me on implementing it, please let me know.