Currently, the tests ReshapeTest.test_reshape_convert and Jax2TfTest, when run on TPU, actually execute the converted TF code on CPU. I have verified that this is because missing build dependencies prevent TF from seeing the TPUs.
The fix I propose here is only meant to unblock the jax2tf migration to native lowering (see go/jax2tf-native-migration). It is unsafe in general, but it is OK for this test. Somebody may want to fix the test so that the TF code also runs on TPU.