openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[XLA:GPU] Add command buffer custom call targets recording for legacy custom call registry API #14266

Open shawnwang18 opened 3 days ago

shawnwang18 commented 3 days ago

This PR enables lowering TE (Transformer Engine) custom call kernels to command buffers through a whitelist. This is temporary support, since third-party library calls should register command buffer support through the FFI API; we will remove it once TE has finished its FFI API migration.
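
For concreteness, a minimal sketch of the whitelist approach (the helper name and TE target names below are illustrative placeholders, not the PR's actual list):

```cpp
// Sketch only: legacy custom call targets known to be safe to record into
// a command buffer are hardcoded, and the command buffer scheduler consults
// this set. Helper name and target names are hypothetical.
#include <string_view>

#include "absl/container/flat_hash_set.h"

bool IsKnownCommandBufferCompatibleTarget(std::string_view target) {
  // Hypothetical TE kernel names; the real list lives in the PR and tracks
  // TE's registered custom call targets.
  static const auto* const kTargets =
      new absl::flat_hash_set<std::string_view>{
          "te_gemm", "te_layernorm_forward", "te_layernorm_backward"};
  return kTargets->contains(target);
}
```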

shawnwang18 commented 2 days ago

> Maybe an attribute could be added to custom-call instead? I can see this list quickly getting out of date.

This is temporary support; the proper approach is to register through the new FFI API, but TE's migration to FFI may take some time. We will remove this code once TE's FFI API migration is done.
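
For reference, roughly what the standard registration would look like once TE migrates. This is a minimal sketch based on my reading of XLA's FFI headers; the handler name, kernel signature, and exact trait spelling are assumptions and may vary across versions:

```cpp
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Hypothetical kernel implementation; the body is elided because only the
// registration matters here.
static ffi::Error TeGemmImpl(ffi::AnyBuffer lhs, ffi::AnyBuffer rhs,
                             ffi::Result<ffi::AnyBuffer> out) {
  // ... launch the kernel on the stream ...
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(kTeGemm, TeGemmImpl,
                       ffi::Ffi::Bind()
                           .Arg<ffi::AnyBuffer>()    // lhs
                           .Arg<ffi::AnyBuffer>()    // rhs
                           .Ret<ffi::AnyBuffer>());  // out

// Registering with the command-buffer-compatible trait tells XLA the call
// may be recorded into a command buffer (CUDA graph), so no whitelist entry
// is needed.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "te_gemm", "CUDA", kTeGemm,
                         XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE);
```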

cheshire commented 2 days ago

Maybe we could wait until then? I'd prefer not to hardcode a large, potentially outdated, potentially colliding list of strings.

ezhulenev commented 2 days ago

My preference is also to migrate to FFI and use a standard mechanism. Let's agree on a timeline to do that, e.g. on Aug 1st (or Sept 1st) we remove this list, and TE should be ready by then. I don't want to keep this list growing forever and encourage other folks to follow TE in a bad way.

shawnwang18 commented 2 days ago

> My preference is also to migrate to FFI and use a standard mechanism. Let's agree on a timeline to do that, e.g. on Aug 1st (or Sept 1st) we remove this list, and TE should be ready by then. I don't want to keep this list growing forever and encourage other folks to follow TE in a bad way.

Will do. I added this temporary fix because customers training some of the largest LLM models are waiting on this feature right now, but TE's FFI migration is still months away; since this temporary fix is quite simple, we hoped to try it first.

cheshire commented 1 day ago

As a temporary workaround, maybe add a flag with a comma-separated list of custom kernels that are CUDA-graphable?

shawnwang18 commented 18 hours ago

> As a temporary workaround, maybe add a flag with a comma-separated list of custom kernels that are CUDA-graphable?

Sounds like a good idea; this can support all custom calls that use the legacy custom call registry API. I have updated the PR accordingly.
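
For concreteness, a minimal sketch of the flag parsing, with an illustrative flag name (the PR defines the actual debug option):

```cpp
// Sketch only: split the comma-separated flag value once into a set, then
// check custom call targets against it during command buffer scheduling.
#include <string>
#include <string_view>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_split.h"

absl::flat_hash_set<std::string> ParseCommandBufferCustomCallTargets(
    std::string_view flag_value) {
  // e.g. --xla_gpu_command_buffer_custom_call_targets=te_gemm,te_layernorm
  // (flag name is hypothetical; see the PR for the real debug option).
  absl::flat_hash_set<std::string> targets =
      absl::StrSplit(flag_value, ',', absl::SkipEmpty());
  return targets;
}
```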