zihangdai / xlnet

XLNet: Generalized Autoregressive Pretraining for Language Understanding
Apache License 2.0
6.18k stars, 1.18k forks

How to export? #113

Closed: jinamshah closed this issue 5 years ago

jinamshah commented 5 years ago

I need to know how to write the serving input function to export a trained XLNet model. This is what I have right now:

    def serving_input_fn():
        with tf.variable_scope("model"):
            feature_spec = {
                "input_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
                "input_mask": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
                "segment_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
                "label_ids": tf.FixedLenFeature([], tf.int64),
            }
            serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[None], name='input_example_tensor')
            receiver_tensors = {'examples': serialized_tf_example}
            features = tf.parse_example(serialized_tf_example, feature_spec)
            return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

    EXPORT_DIR = 'gs://{}/export/{}'.format(BUCKET, TASK_VERSION)
    estimator._export_to_tpu = False  # this is important
    path = estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)

This throws a type mismatch error. Note that this is the same function I used for BERT; as I am no TensorFlow expert, I don't understand why it doesn't work here.

lukemelas commented 5 years ago

Hello @jinamshah, have you resolved this issue? If not, let me know and I can help debug.

jinamshah commented 5 years ago

Hey @lukemelas, I was able to resolve this issue by writing my own serving function. Thanks!

lukemelas commented 5 years ago

Great!

lukemelas commented 5 years ago

I believe this issue can be closed.

ashgorithm commented 5 years ago

I was also getting type mismatch errors while exporting the model:

TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [int32, int64] that don't all match.

How did you resolve this issue?

jinamshah commented 5 years ago

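@ashgorithm Add this right after the tf.parse_example call in your serving_input_fn:

    for name in list(features.keys()):
        t = features[name]
        if t.dtype == tf.int64:
            t = tf.cast(t, tf.int32)
            features[name] = t

I believe this will help. Basically, the model requires its integer inputs to be int32, whereas tf.parse_example returns them as int64.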
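
The cast in this loop narrows int64 to int32, so it only preserves values that fit in int32. That holds for token ids, segment ids and small indices. If you build requests yourself, though, a cheap client-side check can catch bad inputs before they reach the cast. Here is a minimal sketch, where `check_int32_range` is a hypothetical helper and the features are passed as plain Python lists:

```python
INT32_MIN, INT32_MAX = -(2 ** 31), 2 ** 31 - 1


def check_int32_range(features: dict) -> None:
    """Hypothetical helper: raise ValueError if an int value does not fit in int32.

    Float values (e.g. float32 masks) are skipped, mirroring the loop
    above, which only casts int64 tensors.
    """
    for name, values in features.items():
        for v in values:
            if isinstance(v, int) and not INT32_MIN <= v <= INT32_MAX:
                raise ValueError(f"feature {name!r}: {v} does not fit in int32")
```

Call it on the feature dict before serializing each example.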

AndrewPelton commented 5 years ago

@jinamshah would you be able to share your solution? I am having the same problem

hexiaoyupku commented 5 years ago

@jinamshah would you be able to share your solution? I am having the same problem

+1, please!

kobkrit commented 5 years ago

@jinamshah would you be able to share your solution? I am having the same problem.

+1, please!!

kobkrit commented 5 years ago

Hey @AndrewPelton @hexiaoyupku, I solved it :) Thanks to @jinamshah. Here is my serving input function and export code:

    def serving_input_fn():
      with tf.variable_scope("foo"):  # the scope name is arbitrary
        # Must match the features your model_fn expects; this spec is for
        # SQuAD-style QA. Note that input_mask and p_mask are float32, not int64.
        feature_spec = {
            "unique_ids": tf.FixedLenFeature([], tf.int64),
            "input_ids": tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([FLAGS.max_seq_length], tf.float32),
            "segment_ids": tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
            "cls_index": tf.FixedLenFeature([], tf.int64),
            "p_mask": tf.FixedLenFeature([FLAGS.max_seq_length], tf.float32),
        }

        # Fixed batch size; use shape=[None] to accept any number of examples.
        serialized_tf_example = tf.placeholder(
            dtype=tf.string,
            shape=[FLAGS.predict_batch_size],
            name="input_example_tensor")

        receiver_tensors = {"examples": serialized_tf_example}
        features = tf.parse_example(serialized_tf_example, feature_spec)

        # tf.Example stores integers as int64, but the model expects int32.
        for name in list(features.keys()):
          t = features[name]
          if t.dtype == tf.int64:
            features[name] = tf.cast(t, tf.int32)

        return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

    if FLAGS.do_export:
      estimator._export_to_tpu = False  # important: do not export a TPU serving graph
      print("Estimator save model..")
      estimator.export_savedmodel("export_t", serving_input_fn)

rob-nn commented 4 years ago

Could someone please send me an example of an HTTP POST request to this served model? Thanks!
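
One option is TensorFlow Serving's REST predict endpoint, `POST /v1/models/<name>:predict`. Because the signature input is the `examples` string tensor, each serialized `tf.train.Example` has to be base64-encoded and wrapped as `{"b64": ...}` in the JSON body. Below is a minimal sketch using only the standard library. It assumes the SavedModel is served under the hypothetical name `xlnet` on `localhost:8501` (the REST port used in the TensorFlow Serving examples) with the default `serving_default` signature. Note that kobkrit's placeholder has a fixed batch dimension of `FLAGS.predict_batch_size`, so each request would need exactly that many instances; with `shape=[None]`, any number works.

```python
import base64
import json
import urllib.request


def build_predict_request(serialized_examples, model_name="xlnet",
                          host="http://localhost:8501"):
    """Build (but do not send) a TF Serving REST predict request.

    `model_name` and `host` are assumptions; adjust them to your server.
    """
    body = {
        "signature_name": "serving_default",
        # Binary inputs must be base64-encoded and wrapped as {"b64": ...}.
        "instances": [
            {"examples": {"b64": base64.b64encode(ex).decode("ascii")}}
            for ex in serialized_examples
        ],
    }
    return urllib.request.Request(
        f"{host}/v1/models/{model_name}:predict",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Usage, against a running server:
# with urllib.request.urlopen(build_predict_request(examples)) as resp:
#     predictions = json.load(resp)["predictions"]
```

Building the request separately from sending it keeps the payload easy to inspect. The same JSON body can also be sent with curl.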