google-research / t5x

Apache License 2.0

module 'jax' has no attribute 'tree' #1533

Open jntdst opened 3 months ago

jntdst commented 3 months ago

---> 25 utils.log_model_info(log_file,
     26                      train_state_initializer.global_train_state_shape,
     27                      partitioner)

File /kaggle/working/t5x/t5x/utils.py:1391, in log_model_info(log_file, full_train_state, partitioner)
   1387     return
   1389 state_dict = full_train_state.state_dict()
   1390 total_num_params = jax.tree_util.tree_reduce(
-> 1391     np.add, jax.tree.map(np.size, state_dict['target'])
   1392 )
   1394 logical_axes = partitioner.get_logical_axes(full_train_state).state_dict()
   1396 mesh_axes = jax.tree.map(
   1397     lambda x: tuple(x) if x is not None else None,
   1398     partitioner.get_mesh_axes(full_train_state).state_dict(),
   1399 )

File /opt/conda/lib/python3.10/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51     warnings.warn(message, DeprecationWarning, stacklevel=2)
     52     return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax' has no attribute 'tree'

s1ddok commented 1 week ago

same here
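
A possible explanation, going by the traceback: `utils.py` calls `jax.tree.map`, but the installed JAX does not expose a `jax.tree` namespace (it appears to have been introduced in JAX 0.4.25, per the JAX changelog, so older installs would fail exactly like this). Upgrading JAX is probably the cleanest fix. Where that is not an option, a small fallback to `jax.tree_util.tree_map` might work. The sketch below is only an illustration: `resolve_tree_map` is a hypothetical helper, not part of t5x or JAX, and it assumes the two functions are interchangeable for this call.

```python
def resolve_tree_map(jax_module):
    """Return a tree-map callable for the given JAX-like module.

    Hypothetical helper (not part of t5x or JAX). Prefers
    `jax_module.tree.map` and falls back to
    `jax_module.tree_util.tree_map` when the `tree` namespace is missing.
    """
    # On older JAX, accessing `jax.tree` raises AttributeError;
    # getattr with a default turns that into None instead.
    tree_ns = getattr(jax_module, "tree", None)
    if tree_ns is not None and hasattr(tree_ns, "map"):
        return tree_ns.map
    return jax_module.tree_util.tree_map


# Intended usage with a real installation (not executed here):
#   import jax
#   import numpy as np
#   tree_map = resolve_tree_map(jax)
#   sizes = tree_map(np.size, state_dict['target'])
```

With a helper like this, the two `jax.tree.map` calls shown in the traceback (lines 1391 and 1396 of `utils.py`) could be pointed at the resolved function in a local checkout, without changing what they compute.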