allenai / unified-io-2

Apache License 2.0

run demo load checkpoint error #6

Closed kelisiya closed 9 months ago

kelisiya commented 10 months ago

When I download large-3m to my local path and run the demo on an A100 GPU, with FULL_CKPT_PATH = './experiment/unified-io-2/large-3m' and MODEL_TYPE = "large", I get:

TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray
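
For context, the failing step is the demo's checkpoint-loading cell. A minimal sketch of that cell (names match the tracebacks posted later in this thread; the imports and the construction of model come from the demo notebook's earlier cells and are omitted here):

vocab = get_default_vocabulary()
partitioner = partitioning.PjitPartitioner(num_partitions=8)
# The TypeError is raised while get_parameters restores the bfloat16 checkpoint.
parameters, param_axes = uio_utils.get_parameters(model, FULL_CKPT_PATH, partitioner)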

FangruiZeng commented 10 months ago

same error

Luh1124 commented 10 months ago

I also hit the same error with large-3m and xl-3m on an A5000 GPU (Ubuntu 18.04 image).

chrisc36 commented 10 months ago

Looks like an issue with the checkpoints on GPUs; I will try to reproduce it.

hudbrog commented 10 months ago

Same issue on A6000

File /usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py:2026, in asarray(a, dtype, order)
   2024 @_wraps(np.asarray, lax_description=_ARRAY_DOC)
   2025 def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Optional[str] = None) -> Array:
-> 2026   lax_internal._check_user_dtype_supported(dtype, "asarray")
   2027   dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
   2028   return array(a, dtype=dtype, copy=False, order=order)

File /usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py:4812, in _check_user_dtype_supported(dtype, fun_name)
   4810   msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
   4811   msg += f" in {fun_name}" if fun_name else ""
-> 4812   raise TypeError(msg)
   4813 if dtype is not None and np_dtype != dtypes.canonicalize_dtype(dtype):
   4814   msg = ("Explicitly requested dtype {} {} is not available, "
   4815          "and will be truncated to dtype {}. To enable more dtypes, set the "
   4816          "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell "
   4817          "environment variable. "
   4818          "See https://github.com/google/jax#current-gotchas for more.")

TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray

robooootx commented 10 months ago

same error

robooootx commented 10 months ago

Same issue on A6000

How can this be resolved?

robooootx commented 10 months ago

I changed dtype to float32 in t5x/examples/unified_io/t5_1_1/{model_size}.gin and t5x/examples/unified_io.config.py, and I got a new error:


AssertionError                            Traceback (most recent call last)
Cell In[11], line 18
     16 vocab = get_default_vocabulary()
     17 partitioner = partitioning.PjitPartitioner(num_partitions=8)
---> 18 parameters, param_axes = uio_utils.get_parameters(model, FULL_CKPT_PATH, partitioner)

in get_parameters(model, model_checkpoint, partitioner, rng)
     83 input_shapes, input_types = get_input_spec(1)
     84 if partitioner is not None:
---> 85     train_state_initializer = TrainStateInitializer(
     86         optimizer_def=None,
     87         init_fn=model.get_initial_variables,
     88         input_shapes=input_shapes,
     89         input_types=input_types,
     90         partitioner=partitioner
     91     )
     92     param_axes = train_state_initializer.train_state_axes.params
     93     params = LegacyCheckpointManager(
     94         restore_cfg=RestoreCheckpointConfig(model_checkpoint),
     95         train_state_shape=train_state_initializer.global_train_state_shape,
     96         partitioner=partitioner
     97     ).restore([model_checkpoint], RestoreCheckpointConfig(model_checkpoint)).params

File unified_io_2/t5x/utils.py:958, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, model, input_types)
    955     return train_state_lib.InferenceState.create(initial_variables)
    957 self._partitioner = partitioner
--> 958 self.global_train_state_shape = jax.eval_shape(
    959     initialize_train_state, rng=jax.random.PRNGKey(0))
    961 self.train_state_axes = partitioner.get_mesh_axes(
    962     self.global_train_state_shape)
    963 self._initialize_train_state = initialize_train_state

File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/_src/api.py:3201, in eval_shape(fun, *args, **kwargs)
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:660, in abstract_eval_fun(fun, *avals, debug_info, **params)
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/_src/profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1981, in trace_to_jaxpr_dynamic(fun, in_avals, debug_info, keep_inputs)
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1998, in trace_to_subjaxpr_dynamic(fun, main, in_avals, keep_inputs, debug_info)
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/linear_util.py:167, in WrappedFun.call_wrapped(self, *args, **kwargs)
File /home/pai/envs/unified_io/lib/python3.9/site-packages/jax/linear_util.py:167, in WrappedFun.call_wrapped(self, *args, **kwargs)

in TargetSequence.__post_init__(self)
     90 assert self.position_id.shape[:2] in [(1, seq_len), (bs, seq_len)]
     92 assert self.modality_id.shape in [(), (1, seq_len), (bs, seq_len)]
---> 93 assert self.modality_id.dtype == jnp.int32
     95 if self.target_tokens is not None:
     96     assert self.target_tokens.shape == (bs, seq_len)

AssertionError:

chrisc36 commented 10 months ago

I have added the ability to load the model in float32; I was able to run the demo with the XL model using one A6000, and the XXL model with two A6000s.

I am not sure about the assertion error, but feel free to create a new issue with instructions to reproduce it.

KevinKokinda commented 9 months ago

I changed dtype to float32 in t5x/examples/unified_io/t5_1_1/{model_size}.gin and t5x/examples/unified_io.config.py and hit the same AssertionError as reported above (the assert self.modality_id.dtype == jnp.int32 check in TargetSequence.__post_init__ fails).

@chrisc36 @robooootx did you ever find a solution to this?

zcczhang commented 9 months ago

For the error TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in asarray: it is mostly caused by importing orbax.checkpoint. Since this project has been in development for quite a while, some of the packages we use are older versions. We recently found that installing the dependencies with pip under Python 3.9 does cause a conflict between JAX and orbax.checkpoint when dtype="bfloat16" is specified, but everything still works under Python 3.8 (e.g., 3.8.10, the default on TPU VMs). After downgrading to Python 3.8, please also pin pyglove==0.4.3, which is required by seqio; the latest pyglove release (from three weeks ago) only supports Python 3.9. We'll look into this dependency issue more deeply, but feel free to use this workaround for now!
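
A quick sanity check for that environment, as a sketch (it only verifies the interpreter and pyglove versions described above):

import sys
from importlib.metadata import version  # stdlib in Python 3.8+

# The jax / orbax.checkpoint conflict shows up with pip installs under Python 3.9.
assert sys.version_info[:2] == (3, 8), "use Python 3.8 (e.g. 3.8.10, the TPU VM default)"
# Newer pyglove releases only support Python 3.9, so pin the last 3.8-compatible one.
assert version("pyglove") == "0.4.3", "pin pyglove==0.4.3 (required by seqio)"
print("environment matches the bfloat16 workaround")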

As for the AssertionError, I haven't hit it in recent debugging. Could you please share more details, or a minimal script to reproduce it? With the change above there is no need to set dtype to float32 in t5x/examples/unified_io/t5_1_1/{model_size}.gin and t5x/examples/unified_io.config.py; you can directly set supports_bfloat16 = False or True when running on GPU.
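
A minimal sketch of that flag in use (just an illustration; where exactly supports_bfloat16 is set, in a demo cell or in the config, may differ in your copy of the code):

# Keep the shipped gin/config dtype and only toggle this flag when running on GPU.
supports_bfloat16 = False   # load parameters in float32 on GPUs where bfloat16 restore fails
# supports_bfloat16 = True  # bfloat16 path (TPUs, or GPUs where it works)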

KevinKokinda commented 9 months ago

I only seem to get the AssertionError when attempting to use float32. With your solution above, using bfloat16 worked for me.