google / objax

Apache License 2.0

Regression of JAX duck typing. #220

Open AlexeyKurakin opened 3 years ago

AlexeyKurakin commented 3 years ago

When Objax variables are used without .value inside jit, gradient, and other transformations, it leads to the following error:

TypeError: Value Traced<...> with type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> is not a valid JAX type

This is a recent regression caused by underlying changes in JAX.

AlexeyKurakin commented 3 years ago

Temporary workaround: either access .value explicitly or roll back to JAX version 0.2.10.
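A minimal sketch of the .value workaround in plain JAX. It assumes a hypothetical Var class that only imitates an Objax variable by storing an array in .value. It is not the Objax API, and the TypeError it triggers is not word-for-word the one reported above. It illustrates the general point: a wrapper that JAX does not recognise is rejected, while the unwrapped array works under jit and grad.

```python
import jax
import jax.numpy as jnp


class Var:
    """Hypothetical stand-in for an Objax variable (NOT objax.TrainVar).

    It only wraps an array in .value, for illustration.
    """

    def __init__(self, value):
        self.value = jnp.asarray(value)


def loss(w, x):
    # Plain JAX function: expects arrays, not wrapper objects.
    return jnp.sum((w * x) ** 2)


w = Var([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Passing the wrapper itself relies on duck typing; for a class JAX does not
# recognise, jax.jit rejects the argument with a TypeError.
raised = False
try:
    jax.jit(loss)(w, x)
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Workaround: unwrap with .value before handing data to the transformation.
value = jax.jit(loss)(w.value, x)            # sum([3, 8] ** 2) = 73
grads = jax.jit(jax.grad(loss))(w.value, x)  # 2 * w * x**2 = [18, 64]
print(float(value), grads)
```

Because the traced function only ever sees plain arrays, this pattern does not depend on how a given JAX release handles duck-typed objects.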

yetiansh commented 2 years ago

Hi, I have also encountered this issue when running the HuggingFace T5 examples with JAX and Flax. Are there any updates?