MoFHeka opened this issue 2 months ago
Yes, jax.experimental.io_callback would be the recommended way to implement this. You may also want to look at jax.dlpack to avoid copying arrays and use the underlying buffers directly.
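For concreteness, here is a rough sketch of the io_callback half (the host_lookup function, shapes, and embedding width are made up; the real version would call into your table library):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

EMB_DIM = 8  # hypothetical embedding width

def host_lookup(ids):
    # Placeholder for the real host-side table lookup (e.g. a pybind11
    # binding). io_callback delivers `ids` as a NumPy array.
    return np.zeros((ids.shape[0], EMB_DIM), dtype=np.float32)

@jax.jit
def lookup(ids):
    out = jax.ShapeDtypeStruct((ids.shape[0], EMB_DIM), jnp.float32)
    # ordered=True preserves program order, which matters for a mutable table.
    return io_callback(host_lookup, out, ids, ordered=True)

print(lookup(jnp.arange(4)).shape)  # (4, 8)
```

On the dlpack side, jax.dlpack.from_dlpack can wrap an external buffer as a jax.Array without copying.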
@justinjfu What is the right way to achieve optimal performance using jax.experimental.io_callback and jax.dlpack? Are there any examples you can point me to?
All callbacks copy inputs to the host and copy outputs to the device, so there are at least two copies if you are running on GPU/TPU.
On CPU there is a single copy of callback outputs into the pre-allocated XLA output buffers. I have a prototype which allows writing zero-copy CPU callbacks for #20701, but it is not quite ready yet.
That said, it sounds like you are asking about something else. Would the table handle be shared by all invocations of the callback? How do you currently pass it in?
@superbobry Having every callback copy twice sounds unacceptable. We're migrating code from TensorFlow to JAX, and we found that a custom operator in JAX cannot take a resource-handle input like TensorFlow's ResourceBase. We have several custom lookup-table implementations (including a high-performance GPU table); we usually initialize them as tf.resource objects, just like tf.LookupTable, and then pass them to custom operators that take a tf.resource input, such as lookup, insert, and export. But it seems this path is impossible in JAX?
Can you clarify if you want all invocations of the callback to use the same resource?
Yes, of course. This resource is a hash table, and we need it in the forward and backward passes of every training step: we look up huge embeddings from the hash table in the forward pass, and insert the updated embeddings in the backward pass. @superbobry Any good ideas?
@superbobry Any progress?
I'm afraid the problem you are trying to solve is still not quite clear to me. I understand the use case, but I cannot advise unless you share the code (even a sketch) of the JAX part of the custom operation.
From your description so far it sounds like you can probably pass the pointer to the hash table via backend_config=?
@superbobry Here's the pseudo-code:

```python
# C++ side is exposed to Python via pybind11 (pseudo-code):
#   GPUHashTableLookup(void* table_ptr, K* ids) -> V* emb_value
#   GPUHashTableInsert(void* table_ptr, K* ids, V* emb_value)
my_table_ptr = pybind11(new GPUHashTable())  # opaque handle to the table
lookup_fun = pybind11(GPUHashTableLookup)
insert_fun = pybind11(GPUHashTableInsert)

def forward(ids):
    return lookup_fun(my_table_ptr, ids)

def backward(ids, emb_value):
    return insert_fun(my_table_ptr, ids, emb_value)
```
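In JAX I imagine wiring this into autodiff roughly like so (a sketch on top of io_callback, assuming the pybind11 bindings above; EMB_DIM and the float32 dtype are made up):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

EMB_DIM = 8  # hypothetical embedding width

def _host_lookup(ids):
    return lookup_fun(my_table_ptr, ids)      # np.float32 array of [n, EMB_DIM]

def _host_insert(ids, emb_value):
    insert_fun(my_table_ptr, ids, emb_value)  # mutates the table, no output

@jax.custom_vjp
def embedding_lookup(ids):
    out = jax.ShapeDtypeStruct((ids.shape[0], EMB_DIM), jnp.float32)
    return io_callback(_host_lookup, out, ids, ordered=True)

def _fwd(ids):
    return embedding_lookup(ids), ids

def _bwd(ids, g):
    # Push the updated embeddings back into the table on the backward pass.
    io_callback(_host_insert, None, ids, g, ordered=True)
    # ids are integers, so their cotangent has dtype float0.
    return (np.zeros(ids.shape, dtype=jax.dtypes.float0),)

embedding_lookup.defvjp(_fwd, _bwd)
```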
This requirement is a little like Extending TorchScript with Custom C++ Classes, which allows passing a custom class instance. my_table_ptr is the same table throughout the whole training run. It seems I can't pass a C++ instance created at runtime through the FFI backend_config? Or could I pass the address of the C++ class instance through backend_config?
I found this document and will try it later: https://jax--22095.org.readthedocs.build/en/22095/ffi/ffi.html
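Based on that document, passing the table address as a scalar attribute looks feasible. A sketch, assuming a hypothetical table_lookup FFI target and helper functions exported from C++ (the Python entry point has moved between jax.extend.ffi and jax.ffi across JAX versions):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical C++ side: an XLA FFI handler built with
#   xla::ffi::Ffi::Bind().Attr<uint64_t>("table_ptr")...
# exported through pybind11 as a PyCapsule, plus the table's raw address:
# jax.ffi.register_ffi_target("table_lookup", my_module.lookup_capsule(),
#                             platform="CUDA")
# table_addr = my_module.table_address()  # uintptr_t of the GPUHashTable*

def lookup(ids, table_addr, emb_dim=8):
    out = jax.ShapeDtypeStruct((ids.shape[0], emb_dim), jnp.float32)
    # Keyword attributes are serialized into the call's backend_config, so
    # every invocation sees the same table pointer, and the id/embedding
    # buffers stay on device (no host round-trip as with callbacks).
    return jax.ffi.ffi_call("table_lookup", out)(
        ids, table_ptr=np.uint64(table_addr))
```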
For example, there is a lookup-table backend, and the operation needs the table handle to reach it. In TensorFlow this feature is implemented with tf.resource.
Can jax.pure_callback and jax.experimental.io_callback meet my requirements with good enough performance, e.g. zero-copy JAX arrays passed into a custom C++ library?