atong01 closed this issue 1 year ago
Thank you, Alexander! It seems like Flax just stopped working on Google Colab. Simply running
import jax
!pip install --quiet flax
import flax
yields the error you have.
I solved it for now by downgrading Flax to 0.6.4.
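For anyone hitting the same thing, here is a rough sketch of how the workaround could be made conditional. The traceback in this thread suggests the failure happens when the installed JAX has no `jax.sharding.PartitionSpec`, so the check below looks for that attribute and only falls back to the 0.6.4 pin when it is missing. The helper names `has_partition_spec` and `flax_requirement` are made up for illustration and are not part of JAX or Flax.

```python
from types import SimpleNamespace


def has_partition_spec(jax_module) -> bool:
    """Return True if `jax_module.sharding` exposes `PartitionSpec`."""
    sharding = getattr(jax_module, "sharding", None)
    return sharding is not None and hasattr(sharding, "PartitionSpec")


def flax_requirement(jax_module, fallback: str = "flax==0.6.4") -> str:
    """Choose a pip requirement string for Flax (hypothetical helper).

    Assumption: if PartitionSpec is missing from jax.sharding, the newest
    Flax will fail on import, so we pin the version reported to work here.
    """
    return "flax" if has_partition_spec(jax_module) else fallback


# Stand-in objects so the sketch runs without JAX installed.
old_jax = SimpleNamespace(sharding=SimpleNamespace())
new_jax = SimpleNamespace(sharding=SimpleNamespace(PartitionSpec=object))

print(flax_requirement(old_jax))
print(flax_requirement(new_jax))
```

In a Colab cell you would pass the real module (`import jax; req = flax_requirement(jax)`) and then run `!pip install --quiet {req}`. If you just want the fix, `!pip install --quiet flax==0.6.4` does the same as the downgrade described above.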
Very cool, love the update!
Two notes:
ModuleNotFoundError: No module named 'flax'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
5 frames
/usr/local/lib/python3.8/dist-packages/flax/core/meta.py in Partitioned()
    263     return self.replace(names=tuple(names))
    264
--> 265   def get_partition_spec(self) -> jax.sharding.PartitionSpec:
    266     """Returns the Partitionspec for this partitioned value."""
    267     return jax.sharding.PartitionSpec(*self.names)

AttributeError: module 'jax.sharding' has no attribute 'PartitionSpec'