openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

How to catch XlaRuntimeError in Python #23

Status: Open. Opened by vinayburugu 1 year ago.

vinayburugu commented 1 year ago

We are trying to catch XlaRuntimeError in Python, but neither the exception nor the module that defines it is visible, so we cannot import it. We presume the defining module is tensorflow.compiler.xla.xla_client.

More generally, there currently seems to be no way to import the exceptions raised by XLA in Python. We would like front-end frameworks to be able to import XLA exceptions directly, e.g. from xla.exceptions import *

hawkinsp commented 1 year ago

Currently, the XLA Python bindings are something of an internal detail of JAX, and they aren't exposed at all as part of TensorFlow. Whether they will remain part of XLA itself or move into JAX is still to be determined.

Can you say more about the use case? That might help answer the question about public APIs.

vinayburugu commented 1 year ago

Understood. Here is the use case: we would like our framework to handle exceptions raised by XLA and report them with custom messages. However, XlaRuntimeError is not visible in the TensorFlow Python layer. What steps are needed to make it visible in TensorFlow?
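For the custom-message use case, one option avoids importing the non-public type entirely. It catches broadly and matches on the exception's class name. This is a sketch under stated assumptions. The class names XlaRuntimeError and JaxRuntimeError are assumptions about what the bindings use. FrameworkXlaError, translate_xla_errors, and compile_step are hypothetical names. The local XlaRuntimeError class is a stand-in defined only to demonstrate the behavior.

```python
"""Sketch: re-raise XLA runtime errors with framework-specific messages."""

import functools

# Assumption: class names the XLA bindings may use, depending on version.
_XLA_ERROR_NAMES = {"XlaRuntimeError", "JaxRuntimeError"}


class FrameworkXlaError(RuntimeError):
    """Hypothetical framework-level error that wraps an XLA runtime error."""


def _is_xla_runtime_error(exc: BaseException) -> bool:
    # Match by class name anywhere in the MRO, so the (non-public)
    # defining module never has to be imported.
    return any(cls.__name__ in _XLA_ERROR_NAMES for cls in type(exc).__mro__)


def translate_xla_errors(message: str):
    """Decorator that re-raises XLA runtime errors as FrameworkXlaError."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Exception as exc:
                if not _is_xla_runtime_error(exc):
                    raise  # Non-XLA errors propagate unchanged.
                # `from exc` chains the original error so its traceback is kept.
                raise FrameworkXlaError(f"{message}: {exc}") from exc
        return wrapper
    return decorator


# --- Demonstration only: a local stand-in for the real XLA exception type. ---
class XlaRuntimeError(RuntimeError):
    """Stand-in used to exercise the decorator without JAX installed."""


@translate_xla_errors("Model compilation failed")
def compile_step():
    # Hypothetical framework step that fails inside XLA.
    raise XlaRuntimeError("RESOURCE_EXHAUSTED: out of memory")
```

Calling compile_step() raises FrameworkXlaError("Model compilation failed: RESOURCE_EXHAUSTED: out of memory"), with the original error available as __cause__. Matching by name is brittle if the type is renamed again. When an importable public type is available, an except clause on that type is the cleaner choice.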