tensorflow / lingvo


Error when running GPipe with lm.one_billion_wds.OneBWdsGPipeTransformerWPM #167

Open gaokai0810 opened 4 years ago

gaokai0810 commented 4 years ago

When I run the GPipe example OneBWdsGPipeTransformerWPM, an error occurs:

Traceback (most recent call last):
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/trainer.py", line 1836, in <module>
    tf.app.run(main)
  File "/home/gaokai/miniconda3/envs/gpipe_tf2gpu/lib/python3.6/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/gaokai/miniconda3/envs/gpipe_tf2gpu/lib/python3.6/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/gaokai/miniconda3/envs/gpipe_tf2gpu/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/trainer.py", line 1831, in main
    RunnerManager(FLAGS.model).Start()
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/trainer.py", line 1824, in Start
    self.StartRunners(self.CreateRunners(FLAGS.job.split(','), FLAGS.logdir))
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/trainer.py", line 1570, in CreateRunners
    trial)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/trainer.py", line 1520, in _CreateRunner
    return self.Controller(cfg, *common_args)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/trainer.py", line 272, in __init__
    self._model.ConstructFPropBPropGraph()
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/core/base_model.py", line 1221, in ConstructFPropBPropGraph
    self._task.FPropDefaultTheta()
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/core/base_model.py", line 551, in FPropDefaultTheta
    return self.FProp(self.theta, input_batch)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/core/base_model.py", line 469, in FProp
    metrics, per_example = self._FPropSplitInputBatch(theta, input_batch)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/core/base_model.py", line 514, in _FPropSplitInputBatch
    metrics, per_example = self.FPropTower(theta_local, batch)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/tasks/lm/model.py", line 105, in FPropTower
    xent_output, _ = self.lm.FProp(theta.lm, ids, paddings, state0, labels)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/tasks/lm/layers.py", line 1335, in FProp
    labels.class_weights)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/core/layers_with_gpipe.py", line 846, in FProp
    source_task_id, target_task_id)
  File "/home/gaokai/lingvo_2/lingvo-master/bazel-bin/lingvo/trainer.runfiles/main/lingvo/core/gpipe.py", line 417, in FProp
    if p.num_micro_batches > mini_batch_size:
TypeError: '>' not supported between instances of 'int' and 'NoneType'
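
The last frame compares p.num_micro_batches with mini_batch_size, and the TypeError says the right-hand side is None. A minimal pure-Python sketch of that comparison, assuming mini_batch_size is read from a shape list whose batch entry is unknown (the helper name and the example shapes are hypothetical, not Lingvo code):

```python
def exceeds_mini_batch(num_micro_batches, static_shape, batch_dim=0):
    """Mimics the comparison in the last traceback frame.

    static_shape is a plain list standing in for a tensor's static shape,
    where None marks a dimension that is unknown at graph-construction time.
    """
    mini_batch_size = static_shape[batch_dim]
    return num_micro_batches > mini_batch_size


# Known batch dimension: the comparison is an ordinary int comparison.
ok = exceeds_mini_batch(4, [8, 20])

# Unknown batch dimension: Python 3 refuses to order an int against None.
try:
    exceeds_mini_batch(4, [None, 20])
    message = ""
except TypeError as err:
    message = str(err)

print(ok)
print(message)
```

This only reproduces the failure mode; it says nothing about where in the input pipeline the batch dimension loses its static value.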

jonathanasdf commented 4 years ago
input_tenors = _ToTuple(args)
mini_batch_size = input_tenors[0].get_shape().as_list()[p.batch_dim]

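It looks like that function expects the inputs to have a static shape for the batch dim. When that dim is unknown, as_list() returns None for it, which is what the comparison then trips over. Can you make sure the input has a known static batch dim (e.g. with a set_shape() call)?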
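
A minimal sketch of the set_shape() suggestion, using a standalone placeholder rather than the actual Lingvo input pipeline. The tensor name `ids`, the sequence length 20, and the batch size 8 are assumptions made up for illustration; where the call belongs in the real input generator is not stated in this thread.

```python
import tensorflow as tf


def static_batch_dim(tensor, batch_dim=0):
    """Reads the static size of the batch dimension, as the quoted code does."""
    return tensor.get_shape().as_list()[batch_dim]


# Graph mode, so the tensor can carry a partially unknown static shape.
with tf.Graph().as_default():
    # Hypothetical input whose batch dimension is unknown at graph-build time.
    ids = tf.compat.v1.placeholder(tf.int32, shape=[None, 20], name="ids")
    before = static_batch_dim(ids)  # unknown dimension is reported as None

    # Pin the batch dimension. set_shape() merges with the existing static
    # shape, so the already-known second dimension can be left as None here.
    ids.set_shape([8, None])
    after = static_batch_dim(ids)
    full_shape = ids.get_shape().as_list()

print(before, after, full_shape)
```

Note that set_shape() only asserts static shape information; it does not reshape data, so the value given must match the batch size the input actually produces.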