google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

Cannot export model from trax to TF #1718

Open: AndriCcos opened this issue 3 years ago

AndriCcos commented 3 years ago

Description

I would like to export a Trax-trained model as a TensorFlow object so that I can serve it with TensorFlow Serving.

Environment information

Google Colab

$ pip freeze | grep trax
trax==1.4.1

$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensor2tensor==1.15.7
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow @ file:///tensorflow-2.7.0-cp37-cp37m-linux_x86_64.whl
tensorflow-addons==0.15.0
tensorflow-datasets==4.0.1
tensorflow-estimator==2.7.0
tensorflow-gan==2.1.0
tensorflow-gcs-config==2.7.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.22.0
tensorflow-metadata==1.4.0
tensorflow-probability==0.7.0
tensorflow-text==2.7.3

$ pip freeze | grep jax
jax==0.2.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.74+cuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl

$ python -V
Python 3.7.12

For bugs: reproduction and error logs


# Steps to reproduce:
I followed this guide to build a neural machine translation (NMT) model with a Transformer:
https://colab.research.google.com/github/OmarAlsaqa/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb

After training the model, I tried to save it by following the "Convert Trax to Keras" section of this guide:
https://trax-ml.readthedocs.io/en/latest/notebooks/tf_numpy_and_keras.html#2.-Convert-Trax-to-Keras

However, this raised the following error:

# Error logs:

---------------------------------------------------------------------------
StagingError                              Traceback (most recent call last)
<ipython-input-103-7c737325b026> in <module>()
      1 # Create a full Keras  model using the layer from Trax.
      2 inputs = tf.keras.Input(shape=(None,), dtype='int32')
----> 3 hidden = keras_layer(inputs)
      4 # You can add other Keras layers here operating on hidden.
      5 outputs = hidden

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697       except Exception as e:  # pylint:disable=broad-except
    698         if hasattr(e, 'ag_error_metadata'):
--> 699           raise e.ag_error_metadata.to_exception(e)
    700         else:
    701           raise

StagingError: Exception encountered when calling layer "as_keras_3" (type AsKeras).

in user code:

    File "/usr/local/lib/python3.7/dist-packages/trax/trax2keras.py", line 184, in call  *
        outputs, new_state = self._trax_layer.pure_fn(inputs, weights=weights,
    File "/usr/local/lib/python3.7/dist-packages/trax/layers/base.py", line 605, in pure_fn  *
        raise LayerError(name, 'pure_fn',

    LayerError: Exception passing through layer Serial (in pure_fn):
      layer created in file [...]/trax/models/transformer.py, line 390
      layer input shapes: ShapeDtype{shape:(None, None), dtype:<class 'numpy.int32'>}

      File [...]/autograph/operators/control_flow.py, line 1324, in if_stmt
        _py_if_stmt(cond, body, orelse)

      File [...]/autograph/operators/control_flow.py, line 1377, in _py_if_stmt
        return body() if cond else orelse()

      File [...]//tmp/__autograph_generated_file26d10zeu.py, line 61, in if_body_2
        outputs = ag__.converted_call(ag__.ld(self).forward, (ag__.ld(x),), None, fscope)

      File [...]/autograph/impl/api.py, line 447, in converted_call
        result = converted_f(*effective_args)

      File [...]//tmp/__autograph_generated_fileiyqcj2t8.py, line 11, in tf__forward
        ag__.converted_call(ag__.ld(self)._validate_forward_inputs, (ag__.ld(xs),), None, fscope)

      File [...]/autograph/impl/api.py, line 447, in converted_call
        result = converted_f(*effective_args)

      File [...]//tmp/__autograph_generated_files2ty70o0.py, line 20, in tf___validate_forward_inputs
        ag__.if_stmt(ag__.and_((lambda : ag__.not_(ag__.converted_call(ag__.ld(isinstance), (ag__.ld(xs), (ag__.ld(tuple), ag__.ld(list))), None, fscope))), (lambda : (ag__.ld(self)._n_in != 1))), if_body, else_body, get_state, set_state, (), 0)

      File [...]/autograph/operators/control_flow.py, line 1324, in if_stmt
        _py_if_stmt(cond, body, orelse)

      File [...]/autograph/operators/control_flow.py, line 1377, in _py_if_stmt
        return body() if cond else orelse()

      File [...]//tmp/__autograph_generated_files2ty70o0.py, line 16, in if_body
        raise ag__.converted_call(ag__.ld(TypeError), (f'Serial.forward input must be a tuple or list; instead got {ag__.converted_call(ag__.ld(type), (ag__.ld(xs),), None, fscope)}.',), None, fscope)

    TypeError: Serial.forward input must be a tuple or list; instead got <class 'tensorflow.python.framework.ops.Tensor'>.

Call arguments received:
  • inputs=tf.Tensor(shape=(None, None), dtype=int32)
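The traceback points at the cause: in _validate_forward_inputs, Serial rejects an input that is not a tuple or list whenever the layer takes more than one input (self._n_in != 1). The plain-Python sketch below is a simplified, hypothetical re-creation of that check (MiniSerial is not a Trax class, and this is not Trax's actual code). It shows why calling a multi-input layer on a single tensor fails while a tuple of inputs passes.

```python
# Simplified, hypothetical re-creation of the input check shown in the
# traceback above; MiniSerial is NOT part of Trax.


class MiniSerial:
    """Toy stand-in for a combinator layer that declares how many inputs it takes."""

    def __init__(self, n_in):
        self._n_in = n_in

    def _validate_forward_inputs(self, xs):
        # Same condition as in the log: a bare (non-tuple/list) input is only
        # acceptable when the layer takes exactly one input.
        if not isinstance(xs, (tuple, list)) and self._n_in != 1:
            raise TypeError(
                'Serial.forward input must be a tuple or list; '
                f'instead got {type(xs)}.')

    def forward(self, xs):
        self._validate_forward_inputs(xs)
        return xs


# A layer with two inputs, like a Transformer taking (source, target).
two_input_layer = MiniSerial(n_in=2)

try:
    # A single input, analogous to keras_layer(inputs) in the failing call.
    two_input_layer.forward('source_tokens')
except TypeError as err:
    print(err)

# A tuple of two inputs, analogous to keras_layer((inputs, inputs)).
print(two_input_layer.forward(('source_tokens', 'target_tokens')))
```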
AndriCcos commented 3 years ago

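To answer my own question: this can be done with the code below. The fix is to call the Keras layer on a tuple (inputs, inputs) rather than on a single tensor, because trax.models.Transformer takes two inputs (source and target tokens). Also make sure to set mode='eval'; otherwise it will not work as intended.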

import tensorflow as tf
import trax

model = trax.models.Transformer(
    input_vocab_size=32768, d_model=512, d_ff=1024, n_heads=8,
    n_encoder_layers=6, n_decoder_layers=6, max_len=1024, mode='eval')
model.init_from_file(model_folder + '/model.pkl.gz')  # model_folder: training output dir

keras_layer = trax.AsKeras(model, batch_size=1)
inputs = tf.keras.Input(shape=(1024,), dtype='int32')
# The Transformer takes two inputs (source, target), so pass a tuple.
hidden = keras_layer((inputs, inputs))
outputs = hidden
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
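Because this Keras model has a fixed input shape of (1024,) and the Trax layer was wrapped with batch_size=1, every request to it has to carry exactly one sequence of 1024 token ids. The sketch below assumes that keras_model has already been exported as a SavedModel and loaded into TensorFlow Serving (a step not shown or verified here) and that 0 is the padding id; the helper names to_fixed_length and predict_request_body are hypothetical. It pads or truncates a token list and wraps it in the "instances" JSON format used by TensorFlow Serving's REST predict endpoint (POST /v1/models/<model_name>:predict).

```python
import json

SEQ_LEN = 1024  # matches tf.keras.Input(shape=(1024,)) in the code above
PAD_ID = 0      # assumption: 0 is the padding token id


def to_fixed_length(token_ids, length=SEQ_LEN, pad_id=PAD_ID):
    """Truncate or right-pad a sequence of token ids to exactly `length` entries."""
    ids = list(token_ids)[:length]
    return ids + [pad_id] * (length - len(ids))


def predict_request_body(token_ids):
    """Build a JSON body in TensorFlow Serving's REST "instances" (row) format.

    The model was wrapped with batch_size=1, so the body holds a single instance.
    """
    return json.dumps({"instances": [to_fixed_length(token_ids)]})


body = predict_request_body([17, 42, 7])
print(len(json.loads(body)["instances"][0]))  # 1024
```

Padding on the client keeps the request consistent with the static shape baked into the exported model, rather than relying on the server to reshape inputs.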