google / flaxformer

Apache License 2.0
321 stars 31 forks source link

Failed to map logical axes for target/decoder/logits... #2

Open ibulu opened 2 years ago

ibulu commented 2 years ago

I am getting the following error when fine-tuning the LongT5 model:

` ValueError Traceback (most recent call last) Input In [16], in <cell line: 21>() 14 gin_utils.parse_gin_flags( 15 # User-provided gin paths take precedence if relative paths conflict. 16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, 17 FLAGS.gin_file, 18 FLAGS.gin_bindings) 19 train_using_gin() ---> 21 gin_utils.run(main_train)

File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main) 103 def run(main): 104 """Wrapper for app.run that rewrites gin args before parsing.""" --> 105 app.run( 106 main, 107 flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser) 310 callback() 311 try: --> 312 _run_main(main, args) 313 except UsageError as error: 314 usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)

File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv) 256 sys.exit(retval) 257 else: --> 258 sys.exit(main(argv))

Input In [15], in main_train(argv) 1 def main_train(argv: Sequence[str]): 2 """Wrapper for pdb post mortems.""" ----> 3 _main(argv)

Input In [16], in _main(argv) 12 train_using_gin = gin.configurable(train) 14 gin_utils.parse_gin_flags( 15 # User-provided gin paths take precedence if relative paths conflict. 16 FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, 17 FLAGS.gin_file, 18 FLAGS.gin_bindings) ---> 19 train_using_gin()

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs) 1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else '' 1604 err_str = err_str.format(name, fn_or_cls, scope_info) -> 1605 utils.augment_exception_message_and_reraise(e, err_str)

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message) 39 proxy = ExceptionProxy() 40 ExceptionProxy.__qualname__ = type(exception).__qualname__ ---> 41 raise proxy.with_traceback(exception.__traceback__) from None

File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs) 1579 new_kwargs.update(kwargs) 1581 try: -> 1582 return fn(*new_args, **new_kwargs) 1583 except Exception as e: # pylint: disable=broad-except 1584 err_str = ''

Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda) 224 input_types = { 225 k: v.dtype.as_numpy_dtype() for k, v in train_ds.element_spec.items() 226 } 227 init_or_restore_tick = time.time() --> 228 train_state_initializer = utils.TrainStateInitializer( 229 optimizer_def=model.optimizer_def, 230 init_fn=model.get_initial_variables, 231 input_shapes=input_shapes, 232 input_types=input_types, 233 partitioner=partitioner) 234 # 3. From scratch using init_fn. 235 train_state = train_state_initializer.from_checkpoint_or_scratch( 236 restore_cfgs, init_rng=init_rng, ds_iter=checkpointable_train_iter)

File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types) 365 self._partitioner = partitioner 366 self.global_train_state_shape = jax.eval_shape( 367 initialize_train_state, rng=jax.random.PRNGKey(0)) --> 368 self.train_state_axes = partitioner.get_mesh_axes( 369 self.global_train_state_shape) 370 self._initialize_train_state = initialize_train_state 372 # Currently scanned layers require passing annotations through to the 373 # point of the scan transformation to resolve an XLA SPMD issue. 374 375 # init_fn is always(?) equal to model.get_initial_variables, fetch the model 376 # instance from the bound method.

File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state) 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e 890 flat_logical_axes = traverse_util.flatten_dict( 891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/') --> 892 flat_mesh_axes = { 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items() 894 } 896 return logical_axes.restore_state( 897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))

File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0) 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e 890 flat_logical_axes = traverse_util.flatten_dict( 891 logical_axes.state_dict(), keep_empty_nodes=True, sep='/') 892 flat_mesh_axes = { --> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items() 894 } 896 return logical_axes.restore_state( 897 traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))

File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes) 885 return flax_partitioning.logical_to_mesh_axes(logical_axes, 886 self._logical_axis_rules) 887 except ValueError as e: --> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e

ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel In call to configurable 'train' (<function train at 0x2b751e160>)

`