lightyrs closed 2 years ago
I had the same error, but it does not seem to be caused by the TensorFlow version: https://stackoverflow.com/questions/69937218/attributeerror-module-jaxlib-xla-extension-has-no-attribute-pmapfunction/69937986#69937986
I solved the error by downgrading dm-haiku to 0.0.5, as suggested in the Stack Overflow answer above.
You should change this line https://github.com/kingoflolz/mesh-transformer-jax/blob/master/requirements.txt#L8 to the following:
dm-haiku==0.0.5
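Before running CausalTransformer(params) again, it may help to confirm that the runtime really has the pinned version. The following is a minimal sketch, not part of mesh-transformer-jax; the helpers parse_pin and pin_satisfied are hypothetical. It only handles exact == pins and compares version strings literally. Note that importlib.metadata reads the installed package metadata on disk, so after reinstalling it may already report 0.0.5 while the old haiku module is still imported; restarting the Colab runtime is still needed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Tuple


def parse_pin(requirement: str) -> Tuple[str, str]:
    """Split an exact pin such as 'dm-haiku==0.0.5' into (name, version)."""
    name, sep, pinned = requirement.strip().partition("==")
    if not sep or not name.strip() or not pinned.strip():
        raise ValueError(f"not an exact '==' pin: {requirement!r}")
    return name.strip(), pinned.strip()


def pin_satisfied(requirement: str, lookup: Callable[[str], str] = version) -> bool:
    """Return True if the installed distribution matches the exact pin.

    `lookup` defaults to importlib.metadata.version; it is a parameter only
    so the check can be exercised without the package being installed.
    """
    name, pinned = parse_pin(requirement)
    try:
        installed = lookup(name)
    except PackageNotFoundError:
        # Package not installed at all: the pin is not satisfied.
        return False
    return installed == pinned


if __name__ == "__main__":
    req = "dm-haiku==0.0.5"
    status = "OK" if pin_satisfied(req) else "MISMATCH (reinstall, then restart the runtime)"
    print(req, status)
```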
Solved in #151
Running the Colab notebook now with the latest requirements.txt (tensorflow-cpu~=2.6.0), the line
network = CausalTransformer(params)
fails with:
AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'