Zantares closed this issue 4 months ago.
I'm guessing this relates to wanting DLPack support in JAX, not necessarily to XLA itself?
@skye and @jyingl3 are actually looking at this. I believe their plan is to simplify JAX so that it does not call this API; see https://github.com/google/jax/pull/17941
If this question doesn't pertain to JAX, can you clarify which DLPack support you mean?
Like Peter says, we don't believe this API is necessary to implement DLPack. We think you should be able to use AcquireExternalReference instead:
https://github.com/openxla/xla/blob/59ca1209b2d207b9bae0920a799eee801e0f9148/xla/pjrt/pjrt_client.h#L967-L984
which is implemented in the PJRT C API.
The main difference between AcquireExternalReference and ReleaseDeviceMemoryOwnership is that the latter makes it an error to use the released buffer with the original client. We didn't think it was worth having a separate API where the only difference is error checking, but let us know if you disagree.
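To make that difference concrete, here is a toy C++ model. It is not the PJRT API: ToyBuffer, ToyDeviceMemory and the use of std::shared_ptr to keep memory alive are all hypothetical. Both calls hand the memory to an external framework; the only behavior that differs is that reading through the original buffer fails after a release.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for on-device memory (NOT an XLA/PJRT type).
using ToyDeviceMemory = std::vector<float>;

// Hypothetical toy buffer contrasting the two hand-off styles.
class ToyBuffer {
 public:
  explicit ToyBuffer(ToyDeviceMemory data)
      : memory_(std::make_shared<ToyDeviceMemory>(std::move(data))) {}

  // Modeled on AcquireExternalReference: the external framework gets a
  // reference that keeps the memory alive; this buffer stays usable.
  std::shared_ptr<ToyDeviceMemory> AcquireExternalReference() const {
    return memory_;
  }

  // Modeled on ReleaseDeviceMemoryOwnership: the external framework takes
  // the memory and this buffer is marked as released.
  std::shared_ptr<ToyDeviceMemory> ReleaseDeviceMemoryOwnership() {
    released_ = true;
    return std::move(memory_);
  }

  // Using this buffer after a release is an error. In this model, that check
  // is the only behavioral difference between the two hand-offs.
  std::optional<float> Read(std::size_t i) const {
    if (released_) return std::nullopt;
    assert(i < memory_->size() && "index out of range");
    return (*memory_)[i];
  }

 private:
  std::shared_ptr<ToyDeviceMemory> memory_;
  bool released_ = false;
};
```

In this sketch, both methods return the same memory to the caller; only the later Read() on the original buffer behaves differently.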
I believe @jyingl3 is also working on CreateViewOfDeviceBuffer support:
https://github.com/openxla/xla/blob/59ca1209b2d207b9bae0920a799eee801e0f9148/xla/pjrt/pjrt_c_api_client.h#L269-L275
This lets you implement the opposite direction of DLPack (accepting an external buffer, vs. the external framework using a buffer from the original client).
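The import direction can be modeled the same way. In the toy below, ToyBufferView is hypothetical, not an XLA type. It wraps memory owned by another framework without copying it and runs a callback, such as a DLPack tensor's deleter, when the view is destroyed. This sketch assumes the real CreateViewOfDeviceBuffer follows a similar pattern of taking an on-delete callback.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

// Hypothetical non-owning view over externally owned memory (NOT the PJRT
// API). It never frees the memory itself; it only notifies the owner.
class ToyBufferView {
 public:
  ToyBufferView(void* device_ptr, std::size_t size_bytes,
                std::function<void()> on_delete_callback)
      : ptr_(device_ptr),
        size_bytes_(size_bytes),
        on_delete_(std::move(on_delete_callback)) {
    assert(ptr_ != nullptr && "view needs a valid external pointer");
  }

  // Non-copyable, so the owner's callback runs exactly once.
  ToyBufferView(const ToyBufferView&) = delete;
  ToyBufferView& operator=(const ToyBufferView&) = delete;

  // Hand the memory back to its owner (e.g. invoke the DLPack deleter).
  ~ToyBufferView() {
    if (on_delete_) on_delete_();
  }

  void* data() const { return ptr_; }
  std::size_t size_bytes() const { return size_bytes_; }

 private:
  void* ptr_;
  std::size_t size_bytes_;
  std::function<void()> on_delete_;
};
```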
Like Peter says, we'll make sure JAX uses just these APIs, so they should be all you need to interface with JAX in particular.
Thanks for the answers, they are really useful.
I'm guessing this relates to wanting DLPack support in JAX, not necessarily to XLA itself?
@skye and @jyingl3 are actually looking at this. I believe their plan is to simplify JAX so it does not call this API, see google/jax#17941
If this question doesn't pertain to JAX, can you clarify which DLPack support you mean?
Yes, we want DLPack support in JAX. I asked this question because I have seen that most of the related code changes are happening in XLA.
I mentioned ReleaseDeviceMemoryOwnership() because we ran a simple DLPack experiment in JAX and hit the assertion in ReleaseDeviceMemoryOwnership(). I see that new DLPack-related code has been merged into the main branch, so we will try it and work out how to use AcquireExternalReference(). I'll report back in a few days, so please keep this issue open for a while. Thanks!
All the missing APIs are now implemented, and our extension works well with an internal patch after rebasing onto the latest XLA. However, I still have a JAX-related question: although a new device can be added to JAX/XLA through PJRT, the DLPack code path does not currently handle extended devices. That's why the internal patch is needed. It looks like this:
```diff
@@ -254,6 +256,12 @@ StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client,
       }
       TF_RET_CHECK(gpu_client->platform_id() == RocmId());
       return gpu_client->LookupAddressableDevice(context.device_id);
+    case kDLOneAPI:
+      if (gpu_client == nullptr)
+        return InvalidArgument(
+            "DLPack tensor is on extended device, but no backend was provided.");
+      return gpu_client->LookupAddressableDevice(context.device_id);
     default:
       return InvalidArgument("Unknown/unsupported DLPack device type %d",
                              context.device_type);
```
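The dispatch logic this patch extends can be sketched in a self-contained form. Everything here is a hypothetical stand-in: ToyDLDeviceType for DLPack's device-type enum, ToyClient/ToyDevice for the PJRT client and device classes, and a std::variant for StatusOr. As in the patch, a oneAPI tensor is routed to the client passed in the gpu_client slot.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical stand-ins; only a few DLPack device types are modeled.
enum class ToyDLDeviceType { kCPU, kCUDA, kROCM, kOneAPI, kVulkan };

struct ToyClient {
  std::string platform_name;
};

struct ToyDevice {
  const ToyClient* client;
  int id;
};

// Either the resolved device or an error message.
using ToyDeviceOrError = std::variant<ToyDevice, std::string>;

// Same shape as the patched switch: each DLPack device type is routed to the
// client that owns it, and a missing client is reported as an error.
ToyDeviceOrError ToyDeviceForDLDevice(const ToyClient* cpu_client,
                                      const ToyClient* gpu_client,
                                      ToyDLDeviceType type, int device_id) {
  assert(device_id >= 0);
  switch (type) {
    case ToyDLDeviceType::kCPU:
      if (cpu_client == nullptr) return std::string("no CPU backend");
      return ToyDevice{cpu_client, device_id};
    case ToyDLDeviceType::kCUDA:
    case ToyDLDeviceType::kROCM:
      // The real code also checks the client's platform ID (e.g. RocmId()).
      if (gpu_client == nullptr) return std::string("no GPU backend");
      return ToyDevice{gpu_client, device_id};
    case ToyDLDeviceType::kOneAPI:
      // The case added by the patch: reuse the client in the gpu_client
      // slot, with no platform-ID check.
      if (gpu_client == nullptr)
        return std::string(
            "DLPack tensor is on extended device, but no backend was provided.");
      return ToyDevice{gpu_client, device_id};
    default:
      return std::string("Unknown/unsupported DLPack device type");
  }
}
```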
In addition, XLA needs to recognize a new device named XPU (or another permitted name):
```diff
--- a/xla/pjrt/pjrt_compiler.h
+++ b/xla/pjrt/pjrt_compiler.h
@@ -48,6 +48,10 @@ inline const char* TpuName() {
   static constexpr char kTpuName[] = "tpu";
   return kTpuName;
 }
+inline const char* XpuName() {
+  static constexpr char kXpuName[] = "xpu";
+  return kXpuName;
+}
 inline PjRtPlatformId CpuId() {
   static const PjRtPlatformId kCpuId = tsl::Fingerprint64(CpuName());
   return kCpuId;
@@ -64,6 +68,10 @@ inline PjRtPlatformId TpuId() {
   static const PjRtPlatformId kTpuId = tsl::Fingerprint64(TpuName());
   return kTpuId;
 }
+inline PjRtPlatformId XpuId() {
+  static const PjRtPlatformId kXpuId = tsl::Fingerprint64(XpuName());
+  return kXpuId;
+}
```
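In this patch the platform ID is a 64-bit fingerprint of the platform name, so a new name mainly has to be distinct from the existing ones. The sketch below substitutes 64-bit FNV-1a for tsl::Fingerprint64. This is an assumption made so the code runs on its own, and the IDs it produces will not match XLA's. The Toy* names are hypothetical.

```cpp
#include <cstdint>
#include <string_view>

using ToyPlatformId = std::uint64_t;

// Stand-in for tsl::Fingerprint64 (assumption): 64-bit FNV-1a. Same idea,
// a stable 64-bit hash of the name, but different values from XLA's.
constexpr ToyPlatformId ToyFingerprint64(std::string_view name) {
  std::uint64_t hash = 14695981039346656037ULL;  // FNV-1a offset basis
  for (char c : name) {
    hash ^= static_cast<unsigned char>(c);
    hash *= 1099511628211ULL;  // FNV-1a 64-bit prime
  }
  return hash;
}

// Mirrors the pattern in the patch: the name lives in a static constant...
inline const char* ToyXpuName() {
  static constexpr char kXpuName[] = "xpu";
  return kXpuName;
}

// ...and the ID is computed once, on first use, from that name.
inline ToyPlatformId ToyXpuId() {
  static const ToyPlatformId kXpuId = ToyFingerprint64(ToyXpuName());
  return kXpuId;
}

// Each FNV-1a step is a bijection on 64-bit values, so names of equal length
// that differ only in their first character always get distinct IDs.
static_assert(ToyFingerprint64("xpu") != ToyFingerprint64("tpu"),
              "platform names must map to distinct IDs");
```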
Can I submit this change directly as a PR to XLA, or are there any concerns?
After integrating our extension with OpenXLA through the PJRT C API, we found that DLPack is not yet supported by the PJRT C API client: https://github.com/openxla/xla/blob/59ca1209b2d207b9bae0920a799eee801e0f9148/xla/pjrt/pjrt_c_api_client.h#L416-L420
Does the community have any plans to support it? Some users are interested in this feature.