google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

Currently the test ReshapeTest.test_reshape_convert and Jax2TfTest #623

Closed: copybara-service[bot] closed this 1 year ago

copybara-service[bot] commented 1 year ago

Currently, when running on TPU, the tests ReshapeTest.test_reshape_convert and Jax2TfTest actually run the converted TF code on CPU. I have verified that this happens because missing build dependencies prevent TF from seeing the TPUs.
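A minimal sketch of how this silent fallback could be detected, assuming it is enough to compare the platform JAX runs on with the device types TensorFlow reports. The helper names `tf_falls_back_to_cpu` and `query_platforms` are hypothetical and not part of Haiku, JAX or TF; only `jax.default_backend()` and `tf.config.list_logical_devices()` are real calls. This is not the fix proposed in this PR. On some TPU setups TF lists TPU logical devices only after the TPU system has been initialized, so treat a positive result as a hint rather than proof.

```python
from __future__ import annotations


def tf_falls_back_to_cpu(jax_platform: str, tf_device_types: set[str]) -> bool:
    """Hypothetical helper: True if JAX uses an accelerator that TF cannot see.

    jax_platform: a backend name such as "tpu", "gpu" or "cpu"
        (as returned by jax.default_backend()).
    tf_device_types: device types reported by TF, e.g. {"CPU", "TPU"}.
    """
    if jax_platform == "cpu":
        # JAX itself is on CPU, so there is no accelerator to fall back from.
        return False
    # TF reports upper-case device types ("TPU", "GPU").
    return jax_platform.upper() not in tf_device_types


def query_platforms() -> tuple[str, set[str]]:
    """Hypothetical helper: query the live JAX backend and TF logical devices.

    Imports are deferred so the pure check above can be used without JAX or TF.
    """
    import jax
    import tensorflow as tf

    jax_platform = jax.default_backend()
    tf_types = {d.device_type for d in tf.config.list_logical_devices()}
    return jax_platform, tf_types
```

In a test, one could evaluate `tf_falls_back_to_cpu(*query_platforms())` and skip or fail when it is true, instead of silently comparing TF-on-CPU results against JAX-on-TPU results.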

The fix proposed here is only meant to unblock the jax2tf migration to native lowering (see go/jax2tf-native-migration). It is unsafe in general, but it is OK for this test. Someone may want to fix the test so that the TF code also runs on the TPUs.