rail-berkeley / bridge_data_v2

MIT License

AttributeError #2

Closed Alpslee closed 1 year ago

Alpslee commented 1 year ago

ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
AttributeError: module 'jax.interpreters.pxla' has no attribute 'ShardedDeviceArray'

I hit this error when I tried to run the training code:

python experiments/train.py \
    --config experiments/configs/train_config.py:METHOD \
    --bridgedata_config experiments/configs/data_config.py:all \
    --name NAME

HomerW commented 1 year ago

I believe this is related to the jax version. Can you check that you have 0.4.13?
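For anyone checking the same thing, here is a minimal sketch that prints the installed jax and jaxlib versions and tests whether `jax.interpreters.pxla` still exposes `ShardedDeviceArray`. The helper names `installed_version` and `has_attribute` are hypothetical and not part of this repository; the script only reports what is installed and does not try to fix anything.

```python
from importlib import import_module
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def has_attribute(module_name, attribute):
    """Report whether `module_name` can be imported and exposes `attribute`."""
    try:
        module = import_module(module_name)
    except Exception:
        # The import itself can fail (missing package, incompatible install),
        # in which case the attribute is certainly not usable.
        return False
    return hasattr(module, attribute)


if __name__ == "__main__":
    for package in ("jax", "jaxlib"):
        print(package, installed_version(package) or "not installed")
    print(
        "ShardedDeviceArray available:",
        has_attribute("jax.interpreters.pxla", "ShardedDeviceArray"),
    )
```

Run it in the same environment you use for `experiments/train.py`, so that it reports the packages the training script actually imports.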

Alpslee commented 1 year ago

> I believe this is related to the jax version. Can you check that you have 0.4.13?

I checked the jax and jaxlib versions: jax is 0.4.13 and jaxlib is 0.4.14. Could that mismatch be causing the conflict?

Alpslee commented 1 year ago

> I believe this is related to the jax version. Can you check that you have 0.4.13?

Fixed it, thanks!