jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

How can I implement a custom operation/class instance with a stateful resource? #22412

Open MoFHeka opened 2 months ago

MoFHeka commented 2 months ago

For example, there is a lookup-table backend, and the operation needs to get the table handle in order to reach it. In TensorFlow this feature is implemented via tf.resource.

Can jax.pure_callback and jax.experimental.io_callback meet my requirements with good enough performance, e.g. passing a JAX array zero-copy into a custom C++ library?

justinjfu commented 2 months ago

Yes, jax.experimental.io_callback would be the recommended way to implement this. You may also want to look at jax.dlpack to avoid copying arrays and use the underlying buffer directly.
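
For illustration, here is a minimal sketch of the io_callback route. It assumes a hypothetical host-side `table` object with a NumPy-facing lookup(ids) method and an illustrative embedding width EMB_DIM; none of these names come from JAX itself:

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

EMB_DIM = 64  # illustrative embedding width

def host_lookup(ids):
    # Runs on the host; `ids` arrives as a NumPy array.
    return table.lookup(np.asarray(ids)).astype(np.float32)

def lookup(ids):
    out_shape = jax.ShapeDtypeStruct((ids.shape[0], EMB_DIM), jnp.float32)
    # ordered=True keeps this call ordered relative to other side-effecting
    # callbacks (e.g. table inserts), which matters for a mutable table.
    return io_callback(host_lookup, out_shape, ids, ordered=True)

And on the jax.dlpack side, a sketch of handing a buffer to a hypothetical pybind11 extension my_ext that consumes DLPack capsules:

import jax
import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32)
capsule = jax.dlpack.to_dlpack(x)  # PyCapsule over the underlying buffer; no copy here
# my_ext.consume(capsule)          # hypothetical C++ entry point
# y = jax.dlpack.from_dlpack(...)  # a DLPack tensor from C++ becomes a jax.Array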

MoFHeka commented 2 months ago

@justinjfu What is the right way to achieve optimal performance using jax.experimental.io_callback and jax.dlpack? Are there any examples you can point me to?

superbobry commented 2 months ago

All callbacks copy inputs to the host and copy outputs to the device, so there are at least two copies if you are running on GPU/TPU.

On CPU there is a single copy of callback outputs into the pre-allocated XLA output buffers. I have a prototype which allows writing zero-copy CPU callbacks for #20701, but it is not quite ready yet.

That said, it sounds like you are asking about something else. Would the table handle be shared by all invocations of the callback? How do you currently pass it in?

MoFHeka commented 2 months ago

@superbobry All callbacks copying twice sounds unacceptable. We're migrating code from TensorFlow to JAX, and we found that a custom operator in JAX cannot take a resource-handle input like TensorFlow's ResourceBase. We have several custom lookup-table implementations (including a high-performance GPU table); we usually initialize them as tf.resource objects, just like tf.LookupTable, and then pass them to custom operators that take a tf.resource, such as lookup, insert, and export. But is this path impossible when using JAX?

superbobry commented 2 months ago

Can you clarify if you want all invocations of the callback to use the same resource?

MoFHeka commented 2 months ago

Yes, of course. This resource is a hash table, and we need to use it in the forward and backward passes of every training step: we look up huge embeddings from the hash table in the forward pass and insert the updated embeddings in the backward pass. @superbobry Any good ideas?
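
To make the forward/backward wiring concrete, here is a sketch using jax.custom_vjp, assuming hypothetical lookup_fun(ids) and insert_fun(ids, emb) host wrappers (e.g. built on io_callback as above); writing the raw gradient back stands in for a real optimizer update:

import numpy as np
import jax

@jax.custom_vjp
def embedding_lookup(ids):
    return lookup_fun(ids)

def _lookup_fwd(ids):
    # Save ids as residuals so the backward pass can address the table.
    return lookup_fun(ids), ids

def _lookup_bwd(ids, grad_emb):
    # Side effect: push updated embeddings back into the shared table.
    insert_fun(ids, grad_emb)
    # ids are integers, so their cotangent must have dtype float0.
    return (np.zeros(ids.shape, dtype=jax.dtypes.float0),)

embedding_lookup.defvjp(_lookup_fwd, _lookup_bwd)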

MoFHeka commented 2 months ago

@superbobry Any progress?

superbobry commented 2 months ago

I'm afraid the problem you are trying to solve is still not quite clear to me. I understand the use case, but I cannot advise unless you share the code (even a sketch) of the JAX part of the custom operation.

From your description so far it sounds like you could probably pass the pointer to the hash table via backend_config=?
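
Concretely, a sketch of that idea with the in-progress FFI interface: the table address goes in as a scalar attribute, which XLA encodes into the custom call's backend_config. The target name "gpu_hash_table_lookup" and the registered C++ handler are hypothetical:

import numpy as np
import jax
import jax.numpy as jnp
import jax.extend as jex

def lookup(ids, table_ptr, emb_dim):
    out_type = jax.ShapeDtypeStruct((ids.shape[0], emb_dim), jnp.float32)
    return jex.ffi.ffi_call(
        "gpu_hash_table_lookup",  # hypothetical registered FFI target
        out_type,
        ids,
        # Scalar kwargs become backend_config attributes; the C++ handler
        # would read this via .Attr<uint64_t>("table_ptr") and cast it back.
        table_ptr=np.uint64(table_ptr),
    )

Note that a raw address ties the traced program to the current process, so the table must outlive every compiled executable that uses it.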

MoFHeka commented 2 months ago

@superbobry Here's the pseudo-code

# Pseudo-code: pybind11 bindings over a C++ GPU hash table.
my_table_ptr = pybind11(new GPUHashTable())  # opaque handle, created once at startup
lookup_fun = pybind11(GPUHashTableLookup(void* table_ptr, K* ids))
insert_fun = pybind11(GPUHashTableInsert(void* table_ptr, K* ids, V* emb_value))

def forward(ids):
    # Read embeddings for `ids` from the shared table.
    return lookup_fun(my_table_ptr, ids)

def backward(ids, emb_value):
    # Write the updated embeddings back into the same table.
    return insert_fun(my_table_ptr, ids, emb_value)

This requirement is a bit like Extending TorchScript with Custom C++ Classes, which allows passing a custom class instance. my_table_ptr is the same table throughout the whole training run. It seems I cannot pass a C++ instance created at runtime through the FFI backend_config? Or could I pass the address of the C++ class instance through backend_config?

MoFHeka commented 2 months ago

I found this document and will try it later: https://jax--22095.org.readthedocs.build/en/22095/ffi/ffi.html