tensorflow / lingvo

Lingvo
Apache License 2.0

Enable fp16 while running transformers with gpipe #98

Open msharmavikram opened 5 years ago

msharmavikram commented 5 years ago

Looking at the logs, I noticed that the Lingvo GPipe LM implementation uses the fp32 data type by default (on GPUs). Is fp16 supported for this model? If so, how should I enable it?

bignamehyp commented 5 years ago

No, it's neither supported nor tested.

To enable fp16, set p.fprop_dtype = tf.float16 in Train(cls) and Task(cls) in lingvo/tasks/lm/params/one_billion_wds.py.

You may also need to cast the input paddings to tf.float16 in lingvo/tasks/lm/layers.py. For the softmax, I believe you need to cast everything back to tf.float32 before computing the logits.
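One reason the softmax in particular is kept in fp32 is that its normalizer sums over the whole vocabulary. The sketch below uses only the Python standard library: `struct`'s `'e'` format rounds a float to IEEE 754 half precision, and the helper names `to_fp16` and `softmax_denominator` are illustrative, not part of Lingvo. It assumes a naive running sum that rounds after every addition, which is a simplification of what a real GPU reduction kernel does, and 32000 is an arbitrary vocabulary size, not the one used by the 1B-words model.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct's 'e' format packs to binary16 with round-half-to-even and raises
    OverflowError for values too large for fp16 (largest finite value 65504).
    """
    return struct.unpack('<e', struct.pack('<e', x))[0]


def softmax_denominator(logits, half=False):
    """Return sum(exp(z - max(z))), the softmax normalizer.

    With half=True every intermediate value is rounded to fp16, mimicking a
    naive fp16 running sum; otherwise plain Python floats (fp64) are used.
    """
    rnd = to_fp16 if half else (lambda v: v)
    values = [rnd(z) for z in logits]
    m = max(values)
    total = 0.0
    for z in values:
        total = rnd(total + rnd(math.exp(z - m)))
    return total


# Illustrative: 32000 equal logits, so the exact normalizer is 32000.
vocab_logits = [0.0] * 32000
print(softmax_denominator(vocab_logits))             # 32000.0
print(softmax_denominator(vocab_logits, half=True))  # 2048.0
```

Once the fp16 running sum reaches 2048, adding 1.0 rounds back to 2048 (fp16 has only 11 significant bits), so the normalizer ends up off by a factor of about 15.6. fp32 has 24 significant bits, so the same stall would only appear around 2^24, far beyond typical vocabulary sizes.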

msharmavikram commented 5 years ago

Thanks for the response. Where in layers.py do you suggest I cast to fp16? Here is what I did: I cast the paddings to tf.float16 just before the line linked below (casting just after it makes no difference): https://github.com/tensorflow/lingvo/blob/master/lingvo/tasks/lm/layers.py#L1321

However, I get the following error:

```
I0701 23:47:20.677966 140208556619520 gpipe.py:380] cell 0 input [<tf.Tensor 'fprop/1bwds_wpm_level_lm/tower_0_0/recurrent_cellfn_extras/GatherV2_1:0' shape=(1024, 1) dtype=int64>, <tf.Tensor 'fprop/1bwds_wpm_level_lm/tower_0_0/recurrent_cellfn_extras/GatherV2_2:0' shape=(1024, 1) dtype=float16>, None, None, None, None, <tf.Tensor 'fprop/1bwds_wpm_level_lm/tower_0_0/recurrent_cellfn_extras/GatherV2_3:0' shape=(1024, 1) dtype=float16>, <tf.Tensor 'fprop/1bwds_wpm_level_lm/tower_0_0/recurrent_cellfn_extras/GatherV2_4:0' shape=(1024, 1) dtype=float32>, None, None]
I0701 23:47:23.199995 140208556619520 gpipe.py:380] cell 0 input [<tf.Tensor 'arg293:0' shape=(1024, 1) dtype=int64>, <tf.Tensor 'arg294:0' shape=(1024, 1) dtype=float16>, None, None, None, None, <tf.Tensor 'arg295:0' shape=(1024, 1) dtype=float16>, <tf.Tensor 'arg296:0' shape=(1024, 1) dtype=float32>, None, None]
Traceback (most recent call last):
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/trainer.py", line 1698, in <module>
    tf.app.run(main)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python2.7/dist-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/usr/local/lib/python2.7/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/trainer.py", line 1694, in main
    RunnerManager(FLAGS.model).Start()
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/trainer.py", line 1687, in Start
    self.StartRunners(self.CreateRunners(FLAGS.job.split(','), FLAGS.logdir))
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/trainer.py", line 1455, in CreateRunners
    trial)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/trainer.py", line 1409, in _CreateRunner
    return self.Controller(cfg, *common_args)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/trainer.py", line 236, in __init__
    self._model.ConstructFPropBPropGraph()
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/base_model.py", line 1139, in ConstructFPropBPropGraph
    self._task.FPropDefaultTheta()
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/base_model.py", line 505, in FPropDefaultTheta
    return self.FProp(self.theta, input_batch)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/base_model.py", line 424, in FProp
    metrics, per_example = self._FPropSplitInputBatch(theta, input_batch)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/base_model.py", line 470, in _FPropSplitInputBatch
    metrics, per_example = self.FPropTower(theta_local, batch)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/tasks/lm/model.py", line 90, in FPropTower
    xent_output, _ = self.lm.FProp(theta.lm, ids, paddings, state0, labels)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/tasks/lm/layers.py", line 1325, in FProp
    tf.cast(labels.class_ids, py_utils.FPropDtype(p)), labels.class_weights)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/layers_with_gpipe.py", line 712, in FProp
    target_pos_id)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/gpipe.py", line 453, in FProp
    unused_acc_state=True)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/recurrent.py", line 1613, in StackedRecurrent
    accumulator_layer=accumulator_layers[0])
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/recurrent.py", line 1222, in Recurrent
    implicit_captures=implicit_captures).Compute()
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/recurrent.py", line 822, in Compute
    *Flatten([self._theta, self._state, self._inputs, self._extras]))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 563, in __call__
    self.add_to_graph(ops.get_default_graph())
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 544, in add_to_graph
    self._create_definition_if_needed()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 376, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 407, in _create_definition_if_needed_impl
    capture_resource_var_by_value=self._capture_resource_var_by_value)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 943, in func_graph_from_py_func
    outputs = func(*func_graph.inputs)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/recurrent.py", line 583, in Forward
    body=ForwardLoopBody)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/ops/functional_ops.py", line 646, in While
    if body.captured_inputs:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 360, in captured_inputs
    self._create_definition_if_needed()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 376, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 407, in _create_definition_if_needed_impl
    capture_resource_var_by_value=self._capture_resource_var_by_value)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/function.py", line 943, in func_graph_from_py_func
    outputs = func(*func_graph.inputs)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/recurrent.py", line 435, in ForwardLoopBody
    acc_state = _Update(acc_state, state1, t)
  File "/tmp/lingvo/bazel-bin/lingvo/trainer.runfiles/__main__/lingvo/core/recurrent.py", line 116, in _Update
    lst += [inplace_ops.alias_inplace_update(acc, t, tf.expand_dims(x, 0))]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/ops/inplace_ops.py", line 90, in alias_inplace_update
    return _inplace_helper(x, i, v, gen_array_ops.inplace_update)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/ops/inplace_ops.py", line 54, in _inplace_helper
    v = ops.convert_to_tensor(v, x.dtype)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/ops.py", line 1087, in convert_to_tensor
    return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/ops.py", line 1145, in convert_to_tensor_v2
    as_ref=False)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_core/python/framework/ops.py", line 1174, in internal_convert_to_tensor
    (dtype.name, value.dtype.name, value))
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float16: <tf.Tensor 'ExpandDims_2:0' shape=<unknown> dtype=float16>
```

I suspect I'm casting in the wrong place. Any suggestions, @bignamehyp? (I noticed you made a few changes to support dtype in layers.py; thanks for that.)
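The `gpipe.py:380] cell 0 input [...]` log lines printed just before the traceback list every tensor handed to the first pipeline cell, so one quick check is to flag the floating-point inputs whose dtype differs from the intended fprop dtype. The helper below is a hypothetical, stdlib-only sketch (not part of Lingvo) that parses lines in exactly that format. It only narrows down where to look: the ValueError itself is raised while updating a recurrent accumulator in `recurrent._Update`, where the accumulator is float32 and the new value is float16, so the culprit may be a state derived from an input rather than the input itself. It is also an unverified guess that the float32 input corresponds to `labels.class_weights`, which the `layers.py` call in the traceback passes without the cast that `labels.class_ids` gets.

```python
import re

# Matches entries such as: <tf.Tensor 'arg296:0' shape=(1024, 1) dtype=float32>
# The non-greedy "shape=.*?" stops at the first " dtype=" of the same tensor.
_TENSOR_RE = re.compile(r"<tf\.Tensor '([^']+)' shape=.*? dtype=(\w+)>")
_FLOAT_DTYPES = {'float16', 'bfloat16', 'float32', 'float64'}


def mismatched_float_inputs(log_line, fprop_dtype='float16'):
    """Return (tensor_name, dtype) pairs for floating-point tensors in a
    GPipe 'cell N input [...]' log line whose dtype is not fprop_dtype.

    Integer tensors (e.g. ids) and None entries are ignored.
    """
    return [(name, dtype)
            for name, dtype in _TENSOR_RE.findall(log_line)
            if dtype in _FLOAT_DTYPES and dtype != fprop_dtype]


# The second 'cell 0 input' line from the log above.
line = ("I0701 23:47:23.199995 140208556619520 gpipe.py:380] cell 0 input "
        "[<tf.Tensor 'arg293:0' shape=(1024, 1) dtype=int64>, "
        "<tf.Tensor 'arg294:0' shape=(1024, 1) dtype=float16>, "
        "None, None, None, None, "
        "<tf.Tensor 'arg295:0' shape=(1024, 1) dtype=float16>, "
        "<tf.Tensor 'arg296:0' shape=(1024, 1) dtype=float32>, None, None]")
print(mismatched_float_inputs(line))  # [('arg296:0', 'float32')]
```

Running the same check with `fprop_dtype='float32'` instead lists the two float16 inputs, which is a handy way to confirm that the parser picked up every floating-point entry on the line.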