tensorflow / runtime

A performant and modular runtime for TensorFlow
Apache License 2.0

Make ROCM platform name consistent with tensorflow stream executor. #93

Closed: rsanthanam-amd closed this 2 years ago

rsanthanam-amd commented 2 years ago

This will enable the //tensorflow/compiler/xla/service/gpu:custom_call_test unit test to work with BEF Thunk on ROCM.

Thanks for your contribution! Unfortunately, tensorflow/runtime is currently not accepting contributions. Please see the Contribution Guidelines for more information.

/cc @chsigg

chsigg commented 2 years ago

Would you mind explaining briefly why/where we are string-matching the TFRT GPU platform name in XLA?

rsanthanam-amd commented 2 years ago

This enables the custom call test for BEF Thunk on ROCm.

The platform name is one of the keys used for registration and lookup in the custom call registry.

In this case, the custom call function is registered in TF proper under the platform name 'ROCM'.

However, TFRT looks it up using 'ROCm', so the lookup fails.
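To illustrate why the mismatch breaks the lookup, here is a minimal Python sketch. It assumes a registry that stores custom call targets in a map keyed by the pair (symbol name, platform name) and compares keys as exact, case-sensitive strings. `CustomCallRegistry`, `register`, `lookup`, and `my_custom_call` are hypothetical names chosen for illustration; they are not the actual XLA or TFRT API.

```python
# Hypothetical sketch: a custom call registry keyed on (symbol, platform).
# Names are illustrative only, not the real XLA/TFRT API.
from typing import Callable, Dict, Optional, Tuple

CustomCallFn = Callable[..., None]


class CustomCallRegistry:
    """Maps (symbol, platform) pairs to custom call functions."""

    def __init__(self) -> None:
        self._targets: Dict[Tuple[str, str], CustomCallFn] = {}

    def register(self, symbol: str, platform: str, fn: CustomCallFn) -> None:
        # The platform string is part of the key and is compared exactly,
        # so "ROCM" and "ROCm" are distinct keys.
        self._targets[(symbol, platform)] = fn

    def lookup(self, symbol: str, platform: str) -> Optional[CustomCallFn]:
        # Returns None unless both symbol and platform match exactly.
        return self._targets.get((symbol, platform))


def my_custom_call(*args) -> None:
    """Stand-in for a real custom call implementation."""


registry = CustomCallRegistry()

# Registration side (TF proper, per this thread) uses "ROCM".
registry.register("my_custom_call", "ROCM", my_custom_call)

# Lookup side (TFRT, per this thread) uses "ROCm", so the lookup misses.
missed = registry.lookup("my_custom_call", "ROCm")

# When both sides agree on the spelling, the lookup succeeds.
found = registry.lookup("my_custom_call", "ROCM")

print(missed is None, found is my_custom_call)  # prints "True True"
```

Because the key comparison is an exact string match, the only requirement is that both sides use the same spelling of the platform name. The sketch does not reflect how the fix referenced later in this thread was actually implemented.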

chsigg commented 2 years ago

Rohit, I think https://github.com/tensorflow/tensorflow/commit/5cb5f52 has fixed this issue. Would you mind confirming and closing? Thanks!

rsanthanam-amd commented 2 years ago

Confirmed that the cited fix resolves the issue.