Alpslee closed this issue 1 year ago.
I believe this is related to the jax version. Can you check that you have 0.4.13?
I checked the jax and jaxlib versions: jax is 0.4.13 and jaxlib is 0.4.14. Could that mismatch be causing the conflict?
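A quick way to see which jax and jaxlib versions are actually installed, and whether they differ from the 0.4.13 suggested above, is to read the package metadata. This is a minimal sketch using only the standard library; the helper names are hypothetical, and it assumes the goal is to have both packages at exactly 0.4.13 (jax does not strictly require jaxlib to carry the identical version number).

```python
from importlib.metadata import PackageNotFoundError, version

# Version suggested in this thread; assumed to be the target for both packages.
EXPECTED = "0.4.13"


def installed_version(pkg):
    """Return the installed version string of pkg, or None if it is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def check_jax_versions(expected=EXPECTED, get_version=installed_version):
    """Return a list of human-readable problems; an empty list means both match."""
    problems = []
    for pkg in ("jax", "jaxlib"):
        found = get_version(pkg)
        if found != expected:
            problems.append(f"{pkg} is {found}, expected {expected}")
    return problems


if __name__ == "__main__":
    issues = check_jax_versions()
    print("\n".join(issues) if issues else f"jax and jaxlib are both {EXPECTED}")
```

With jax 0.4.13 and jaxlib 0.4.14 installed, as reported above, this would flag only jaxlib. The `get_version` parameter is there so the check can be exercised without the packages installed.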
Fixed it, thanks!
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
AttributeError: module 'jax.interpreters.pxla' has no attribute 'ShardedDeviceArray'
I just tried to run the training code:

python experiments/train.py \
    --config experiments/configs/train_config.py:METHOD \
    --bridgedata_config experiments/configs/data_config.py:all \
    --name NAME
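The fix in this thread was to get the jax version right. If pinning jax is not an option, one possible workaround (not something suggested in the thread) is to stop hard-coding `ShardedDeviceArray`, which newer jax releases no longer expose after moving to the unified `jax.Array` type. The sketch below is hypothetical and assumes the training code only uses `ArraySharded` for type checks such as `isinstance`; the helper name is made up for illustration.

```python
def resolve_sharded_array_type(pxla_module, fallback):
    """Return pxla_module.ShardedDeviceArray if it exists, otherwise fallback.

    Older jax releases exposed jax.interpreters.pxla.ShardedDeviceArray;
    newer ones do not, so the caller supplies a replacement type
    (assumed here to be jax.Array).
    """
    return getattr(pxla_module, "ShardedDeviceArray", fallback)
```

In the failing line this would become something like `ArraySharded = resolve_sharded_array_type(jax.interpreters.pxla, jax.Array)`. Whether `jax.Array` is a correct substitute depends on how `ArraySharded` is used elsewhere in the code, so treat this as a sketch rather than a verified fix.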