tensorflow / text

Making text a first-class citizen in TensorFlow.
https://www.tensorflow.org/beta/tutorials/tensorflow_text/intro
Apache License 2.0

Error when checkpointing a dataset that uses SentencepieceTokenizer #1289

Open · chrisc36 opened this issue 1 week ago

chrisc36 commented 1 week ago

I am running into an error when checkpointing a tf.data.Dataset iterator that uses a SentencepieceTokenizer for tokenization. It fails with:

tensorflow.python.framework.errors_impl.FailedPreconditionError: {{function_node __wrapped__SerializeIterator_device_/job:localhost/replica:0/task:0/device:CPU:0}} SentencepieceTokenizeOp is stateful. [Op:SerializeIterator] name:

As a result, I cannot checkpoint datasets that use SentencepieceTokenizer. Is there a fix or workaround that would resolve the issue for me? I saw https://github.com/tensorflow/text/blob/dd919053e7d3e7fddc2dd8e8abccdd74d259a7a1/tensorflow_text/core/kernels/sentencepiece_kernels.cc#L404, which makes it look like this is supposed to be possible.
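One possible workaround, untested against the colab below: tf.data decides how to treat ops that hold external state when it checkpoints an iterator through `tf.data.Options.experimental_external_state_policy` (values `IGNORE`, `WARN`, `FAIL` in `tf.data.experimental.ExternalStatePolicy`). The error above is consistent with `FAIL` being applied at checkpoint time. The sketch below relaxes the policy to `WARN`, assuming the tokenizer's only state is the immutable SentencePiece model, which the pipeline reloads from the same file on restore anyway. Whether this is sufficient for `SentencepieceTokenizeOp` may depend on the TF / TF Text versions. `with_external_state_policy` and `checkpoint_and_resume` are hypothetical helper names.

```python
import os
import tempfile

import tensorflow as tf


def with_external_state_policy(ds, policy=tf.data.experimental.ExternalStatePolicy.WARN):
    """Hypothetical helper: return `ds` with a relaxed external-state policy.

    With WARN (or IGNORE), checkpointing the iterator ignores ops that hold
    external state instead of failing; WARN also logs a warning. That state
    is NOT saved, so the pipeline must be rebuilt identically on restore.
    """
    options = tf.data.Options()
    options.experimental_external_state_policy = policy
    return ds.with_options(options)


def checkpoint_and_resume(ds, consumed):
    """Hypothetical helper: consume `consumed` elements, checkpoint the
    iterator, restore it into a fresh iterator, and return the next element."""
    it = iter(ds)
    for _ in range(consumed):
        next(it)
    path = tf.train.Checkpoint(iterator=it).write(
        os.path.join(tempfile.mkdtemp(), "iterator"))
    resumed = iter(ds)
    tf.train.Checkpoint(iterator=resumed).read(path)
    return next(resumed)
```

In the reproduction code that follows, this would amount to wrapping the mapped dataset, e.g. `ds = with_external_state_policy(ds.map(_map))`, before calling `iter(ds)`. Options set anywhere in the pipeline are merged into the final dataset's options, so the position of `with_options` should not matter.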

Code to reproduce the issue:

import tensorflow as tf
import tensorflow_text as tf_text

# Load a serialized SentencePiece model (replace with a real path).
with open("/path/to/tokenizer.model", "rb") as f:
    sp_model = f.read()
tokenizer = tf_text.SentencepieceTokenizer(sp_model)
ds = tf.data.Dataset.from_tensor_slices(dict(data=["ex1", "ex2", "ex3"]))

def _map(ex):
    # Tokenizing inside the pipeline puts a SentencepieceTokenizeOp in the map function.
    return dict(data=tokenizer.tokenize(ex["data"]))

ds = ds.map(_map)
iterator = iter(ds)
ckpt = tf.train.Checkpoint(iterator=iterator)
# Raises FailedPreconditionError: SentencepieceTokenizeOp is stateful.
ckpt.write("/tmp/iterator")

Colab: https://colab.research.google.com/drive/1kGYP4GJ2YVGBVQaxNzcIm1M3VxO9yRse?authuser=1#scrollTo=nZ5PVQk-BRP7
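A different workaround sidesteps iterator serialization entirely: checkpoint only a counter of consumed elements and rebuild the pipeline with `Dataset.skip` on restore. This is a sketch, assuming a deterministic pipeline (no unseeded shuffling). `build_pipeline` and `save_and_resume` are hypothetical helpers, and the `tf.strings.upper` map in the usage is only a stand-in for `tokenizer.tokenize` so the example runs without a SentencePiece model; in the reproduction you would pass `_map` instead.

```python
import os
import tempfile

import tensorflow as tf


def build_pipeline(examples, map_fn, start):
    """Hypothetical helper: rebuild the pipeline, skipping `start` raw examples.

    `skip` runs before `map`, so examples consumed before the checkpoint
    are not tokenized a second time.
    """
    ds = tf.data.Dataset.from_tensor_slices(dict(data=examples))
    return ds.skip(start).map(map_fn)


def save_and_resume(examples, map_fn, consumed):
    """Hypothetical helper: consume `consumed` elements, checkpoint only a
    counter (not the iterator), then restore it and return the next element."""
    step = tf.Variable(0, dtype=tf.int64)  # elements consumed so far
    ckpt = tf.train.Checkpoint(step=step)

    it = iter(build_pipeline(examples, map_fn, start=0))
    for _ in range(consumed):
        next(it)
        step.assign_add(1)
    path = ckpt.write(os.path.join(tempfile.mkdtemp(), "step"))

    # Simulate a restart: restore the counter into a fresh variable and
    # rebuild the pipeline from that position.
    restored = tf.Variable(0, dtype=tf.int64)
    tf.train.Checkpoint(step=restored).read(path)
    resumed = iter(build_pipeline(examples, map_fn, start=int(restored.numpy())))
    return next(resumed)


# Stand-in for tokenizer.tokenize so the sketch runs without a model file.
upper = lambda ex: dict(data=tf.strings.upper(ex["data"]))
```

Because no tf.data iterator state is serialized, the stateful tokenizer op never reaches `SerializeIterator`. The trade-off is that restoring a large `skip` still has to read past the skipped raw examples, although they are not tokenized.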