tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Does Universal Transformer work with TPUs? #1250

Open cbockman opened 5 years ago

cbockman commented 5 years ago

Description

Should Universal Transformer work with TPU? I took a spin at getting it to work, and it doesn't.

Model + hparams below.

I do see that there are _tpu-specific hparam sets for Transformer and none for UT, which might be a sign that UT is not functional on TPU; on the other hand, I see that greedy_infer (https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/universal_transformer.py#L223) does notionally have some TPU support.

Offhand, it doesn't look like any of the Transformer _tpu hparam changes should be required just to run (https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py#L1991), at least not for the errors I see below.
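(For reference, the Transformer _tpu sets are essentially registered hparams that apply update_hparams_for_tpu; a UT equivalent would presumably look something like the sketch below. The universal_transformer_base_tpu name is my guess, not an existing symbol in t2t:)

    from tensor2tensor.models import transformer
    from tensor2tensor.models.research import universal_transformer
    from tensor2tensor.utils import registry

    @registry.register_hparams
    def universal_transformer_base_tpu():
      # Hypothetical: start from the UT base set and apply the same TPU
      # tweaks that the Transformer _tpu hparams sets apply.
      hparams = universal_transformer.universal_transformer_base()
      transformer.update_hparams_for_tpu(hparams)
      return hparams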

No problem, obviously, if TPU isn't supported; we would just like to know whether we're doing something wrong here.

NOTE: other models (e.g., the base Transformer) apparently run fine on TPU, and UT works great on GPU.

Thanks!

Environment information

tf 1.12
latest PyPI t2t (1.11)
TPUv2 on GKE

For bugs: reproduction and error logs

        'problem': 'translate_ende_wmt32k_packed',
        'model': 'universal_transformer',
        'hparams_set': 'universal_transformer_base',
        'eval_steps': 3,
        'train_steps': 1000,

t2t-trainer \
  --data_dir="t2t_data_dir" \
  --eval_steps="3" \
  --generate_data="True" \
  --hparams="" \
  --hparams_set="universal_transformer_base" \
  --model="universal_transformer" \
  --output_dir="output_dir" \
  --problem="translate_ende_wmt32k_packed" \
  --tmp_dir="t2t_tmp_dir" \
  --train_steps="1000" \
  --use_tpu="True" \
  --use_tpu_estimator="True"

Error logs:

WARNING:tensorflow:... and 26 more
INFO:tensorflow:Error recorded from training_loop: Cannot create a gradient accumulator for tensor 'universal_transformer/parallel_0_9/universal_transformer/universal_transformer/body/decoder/universal_transformer_basic/foldl/while/ffn/layer_postprocess/dropout/Floor:0' inside XLA while_loop because maximum_iterations was not passed to the tf.while_loop call ('universal_transformer/parallel_0_9/universal_transformer/universal_transformer/body/decoder/universal_transformer_basic/foldl/while/while_context').
INFO:tensorflow:training_loop marked as finished
/usr/local/lib/python3.6/site-packages/tensorflow/python/platform/tf_logging.py:120: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
  _get_logger().warn(msg, *args, **kwargs)
WARNING:tensorflow:Reraising captured error
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2323, in get_attr
    c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Operation 'universal_transformer/parallel_0_9/universal_transformer/universal_transformer/body/decoder/universal_transformer_basic/foldl/while/ffn/layer_postprocess/dropout/mul' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 403, in _MaybeCompile
    xla_compile = op.get_attr("_XlaCompile")
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2327, in get_attr
    raise ValueError(str(e))
ValueError: Operation 'universal_transformer/parallel_0_9/universal_transformer/universal_transformer/body/decoder/universal_transformer_basic/foldl/while/ffn/layer_postprocess/dropout/mul' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/t2t-trainer", line 33, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/usr/local/bin/t2t-trainer", line 28, in main
    t2t_trainer.main(argv)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/bin/t2t_trainer.py", line 387, in main
    execute_schedule(exp)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/bin/t2t_trainer.py", line 349, in execute_schedule
    getattr(exp, FLAGS.schedule)()
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/trainer_lib.py", line 438, in continuous_train_and_eval
    self._eval_spec)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 471, in train_and_evaluate
    return executor.run()
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 610, in run
    return self.run_local()
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 711, in run_local
    saving_listeners=saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2409, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/error_handling.py", line 128, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python3.6/site-packages/six.py", line 693, in reraise
    raise value
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2403, in train
    saving_listeners=saving_listeners
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1237, in _train_model_default
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2195, in _call_model_fn
    features, labels, mode, config)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1195, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2502, in _model_fn
    _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2828, in _train_on_tpu_system
    device_assignment=ctx.device_assignment)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu.py", line 881, in shard
    name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu.py", line 507, in replicate
    device_assignment, name)[1]
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu.py", line 684, in split_compile_and_replicate
    outputs = computation(*computation_inputs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 2821, in multi_tpu_train_steps_on_single_shard
    single_tpu_train_step, [_INITIAL_LOSS])
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/training_loop.py", line 207, in repeat
    cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/training_loop.py", line 169, in while_loop
    name="")
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3291, in while_loop
    return_same_structure)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3004, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2939, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/training_loop.py", line 120, in body_wrapper
    outputs = body(*(inputs + dequeue_ops))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/training_loop.py", line 203, in body_wrapper
    return [i + 1] + _convert_to_list(body(*args))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 1295, in train_step
    self._call_model_fn(features, labels))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py", line 1533, in _call_model_fn
    estimator_spec = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py", line 1264, in wrapping_model_fn
    use_tpu=use_tpu)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py", line 1386, in estimator_model_fn
    loss, num_async_replicas=num_async_replicas, use_tpu=use_tpu)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py", line 1411, in estimator_spec_train
    use_tpu=use_tpu)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/t2t_model.py", line 587, in optimize
    train_op = optimize.optimize(loss, lr, self.hparams, use_tpu=use_tpu)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/optimize.py", line 81, in optimize
    colocate_gradients_with_ops=True)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/layers/python/layers/optimizers.py", line 239, in optimize_loss
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py", line 140, in compute_gradients
    return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensor2tensor/utils/optimize.py", line 149, in compute_gradients
    gradients = self._opt.compute_gradients(loss, var_list, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 519, in compute_gradients
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 630, in gradients
    gate_gradients, aggregation_method, stop_gradients)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 814, in _GradientsHelper
    lambda: grad_fn(op, *out_grads))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 408, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 814, in <lambda>
    lambda: grad_fn(op, *out_grads))
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py", line 936, in _MulGrad
    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 5042, in mul
    "Mul", x=x, y=y, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3274, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1807, in __init__
    self._control_flow_post_processing()
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1818, in _control_flow_post_processing
    self._control_flow_context.AddOp(self)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2481, in AddOp
    self._AddOpInternal(op)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2502, in _AddOpInternal
    real_x = self.AddValue(x)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2434, in AddValue
    real_val = grad_ctxt.grad_state.GetRealValue(val)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1146, in GetRealValue
    history_value = cur_grad_state.AddForwardAccumulator(cur_value)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1011, in AddForwardAccumulator
    value, self.forward_context)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 741, in GetMaxSizeFromNestedMaximumIterations
    "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
ValueError: Cannot create a gradient accumulator for tensor 'universal_transformer/parallel_0_9/universal_transformer/universal_transformer/body/decoder/universal_transformer_basic/foldl/while/ffn/layer_postprocess/dropout/Floor:0' inside XLA while_loop because maximum_iterations was not passed to the tf.while_loop call ('universal_transformer/parallel_0_9/universal_transformer/universal_transformer/body/decoder/universal_transformer_basic/foldl/while/while_context').

...
afrozenator commented 5 years ago

Paging the helpful @MostafaDehghani :)

MostafaDehghani commented 5 years ago

Hi @afrozenator, Hi @cbockman :) Sorry for my late reaction to this!

I actually never tried UT on TPU. I talked to @lukaszkaiser about this and there were two problems:

First, TF did not correctly pass maximum_iterations through foldl. Lukasz sent a CL to fix that (it is needed for any foldl on TPU), and the good news is that the change is already in: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/functional_ops.py#L147

Second, XLA compilation requires that operator arguments representing shapes or dimensions evaluate to concrete values at compile time. Currently we have add_step_timing_signal, which gets a step value that is non-static. Avoiding this is pretty simple in the UT base model (e.g., by replacing the foldl with a for loop with shared parameters), but it would be more difficult for UT with ACT.

To run on TPUs for now, we can just disable the step embedding: with tf-nightly, you can use the TPU config that @lukaszkaiser added here: https://github.com/tensorflow/tensor2tensor/commit/7a2f3114a60a82a5f97e6a2660d9510689d2f061
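To illustrate the "for loop with shared parameters" idea, here is a minimal sketch (not the actual t2t implementation; the step_fn signature and names are assumptions for illustration):

    import tensorflow as tf

    def ut_steps_unrolled(x, num_steps, step_fn):
      # Unroll the UT recurrence as a Python for loop instead of a foldl.
      # AUTO_REUSE makes every iteration share the same weights, and `step`
      # is a Python int, so XLA only ever sees compile-time-constant steps.
      for step in range(num_steps):
        with tf.variable_scope("ut_step", reuse=tf.AUTO_REUSE):
          x = step_fn(x, step)
      return x

And assuming the hparam is named add_step_timing_signal (as in the UT hparams sets), disabling the step embedding from the command line would be an override along the lines of:

    t2t-trainer ... --hparams="add_step_timing_signal=False" ...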

cbockman commented 5 years ago

Thank you @MostafaDehghani! I figured there was something related to static steps... assuming it wasn't just us doing something dumb on our end.

Re: disabling the step embedding, should we expect that to have a significant performance impact? Or is this unknown?

Just trying to forecast performance: when results drop, it can be very murky to figure out whether we did something wrong or whether it is just inherent to the problem/data/model/...