poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License
469 stars, 32 forks

AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'optimization_barrier' #239

Status: Open. mwitiderrick opened this issue 2 years ago.

mwitiderrick commented 2 years ago

Running `import elegy as eg` on Google Colab raises the error in the title. Full traceback:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-cdd8b33baa61> in <module>()
      2 import jax
      3 
----> 4 import elegy as eg
      5 
      6 class MLP(eg.Module):

6 frames
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/jax2tf.py in <module>()
   2388                     extra_name_stack="checkpoint")
   2389 
-> 2390 tf_impl[lax_control_flow.optimization_barrier_p] = tfxla.optimization_barrier
   2391 
   2392 def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:

AttributeError: module 'tensorflow.compiler.tf2xla.python.xla' has no attribute 'optimization_barrier'
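The last frame suggests a version mismatch: the installed `jax` ships a `jax2tf` module that registers `tfxla.optimization_barrier` at import time, but the installed TensorFlow's `tensorflow.compiler.tf2xla.python.xla` module does not define that name. The snippet below is a minimal diagnostic sketch, not part of elegy, for confirming this in the same runtime. The helper names `installed_version` and `tf_has_optimization_barrier` are hypothetical. It deliberately does not import `elegy`, because that import is what fails.

```python
import importlib
from typing import Optional


def installed_version(name: str) -> Optional[str]:
    """Return the module's __version__ if it can be imported, else None."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def tf_has_optimization_barrier() -> bool:
    """True if the installed TensorFlow exposes the symbol jax2tf needs."""
    try:
        xla = importlib.import_module("tensorflow.compiler.tf2xla.python.xla")
    except ImportError:
        # TensorFlow is not installed, or this internal module path is missing.
        return False
    return hasattr(xla, "optimization_barrier")


if __name__ == "__main__":
    # Skip `elegy` on purpose: importing it is what triggers the AttributeError.
    for pkg in ("jax", "jaxlib", "tensorflow"):
        print(f"{pkg}: {installed_version(pkg)}")
    print("tf2xla optimization_barrier available:", tf_has_optimization_barrier())
```

If the last line prints `False` while `jax` imports fine, the installed `jax` and TensorFlow versions are probably out of step. Aligning them, rather than changing elegy itself, is the likely direction to investigate.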