tensorflow / hub

A library for transfer learning by reusing parts of TensorFlow models.
https://tensorflow.org/hub
Apache License 2.0

SentencePiece BPE text embedding feature column #833

Closed vitalyli closed 2 years ago

vitalyli commented 2 years ago

For your review: a SentencePiece BPE text embedding feature column.

It's formulated as a generic solution; it's not a model hosted by Hub, but it falls into the Hub realm, or into base TF layers. https://github.com/google/sentencepiece/blob/master/tensorflow/README.md The input is a bytes_list inside an ELWC (ExampleListWithContext) tfrecord; see the fragment example below:

feature {
  key: "prefix"
  value {
    bytes_list {
      value: "coffee in mountain view"
    }
  }
}
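
For reference, a minimal sketch (not part of the pipeline above; names are illustrative) of writing such a context feature with the tf.train protos:

import tensorflow as tf

# Context Example carrying the "prefix" bytes_list shown above.
context = tf.train.Example(features=tf.train.Features(feature={
    'prefix': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'coffee in mountain view']))
}))
serialized_context = context.SerializeToString()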

Usage example:

feature_columns["prefix_emb"] = SPMEmbeddingColumn('prefix', 'spm.model', dimension=10, combiner='mean',
        initializer=init_ops.glorot_normal_initializer())
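
A rough sketch of how the column could then be consumed on the v1/Estimator feature-column path (the features dict below is assumed to be the parsed ELWC context, which is not shown here):

# 'features' is assumed to be the dict of parsed context tensors, e.g.
# features = {'prefix': <string tensor of shape (batch, 1)>}
prefix_emb = tf.compat.v1.feature_column.input_layer(
    features, [feature_columns['prefix_emb']])   # -> (batch, dimension)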

import collections
import math

import tensorflow as tf
import tensorflow_text as text

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_v2 as fc2
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow_hub import module

class DenseFeatureColumn(fc2.DenseColumn):
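    """Thin base column that reports a float32 dtype."""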

    @property
    def dtype(self):
        return tf.float32

class SPMEmbeddingColumn(
    DenseFeatureColumn,
    fc2.SequenceDenseColumn,
    collections.namedtuple(
        'SPMEmbeddingColumn',
        ('key', 'spm_model_file', 'dimension', 'combiner', 'initializer'))):
    """See `embedding_column`."""

    def __init__(self, key, spm_model_file, dimension, combiner, initializer):
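        # The namedtuple fields (key, spm_model_file, ...) are already set by
        # __new__; __init__ only loads the SentencePiece model and attaches
        # the tokenizer.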
        print("Init SPMEmbeddingColumn for: ", self.key)

        with open(spm_model_file,'rb') as spm_file:
            spm_model = spm_file.read()
            self.spm1 = text.SentencepieceTokenizer(model=spm_model)

            print("Loaded SPM model from ", spm_model_file)

        super().__init__()

    @property
    def _is_v2_column(self):
        return True

    @property
    def parents(self):
        return [self.key]

    @property
    def name(self):
        return self.key

    def _transform_feature(self, inputs):
        """Returns intermediate representation (usually a `Tensor`)."""
        return inputs.get(self.key)

    def transform_feature(self, transformation_cache, state_manager):
        return transformation_cache.get(self.key, state_manager)

    @property
    def _parse_example_spec(self):
        """Returns a `tf.Example` parsing spec as dict."""
        return {self.key: tf.compat.v1.FixedLenFeature([1], tf.string)}

    @property
    def parse_example_spec(self):
        """Returns a `tf.Example` parsing spec as dict."""
        return {self.key: tf.compat.v1.FixedLenFeature([1], tf.string)}

    @property
    def variable_shape(self):
        """`TensorShape` of `_get_dense_tensor`, without batch dimension."""
        return tf.TensorShape([self.dimension])

    def create_state(self, state_manager):
        """Creates the embedding lookup variable."""
        vocab_size = None #30000
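        # SentencepieceTokenizer.vocab_size() returns a scalar tensor, so it is
        # evaluated once here to give the embedding table a static shape.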
        with tf.compat.v1.Session() as sess:
            vocab_size = sess.run([self.spm1.vocab_size()])

        embedding_shape = (vocab_size[0], self.dimension)

        state_manager.create_variable(
            self,
            name='embedding_weights',
            shape=embedding_shape,
            dtype=tf.float32,
            trainable=True,
            use_resource=False,
            initializer=init_ops.glorot_normal_initializer() if self.initializer is None else self.initializer)

    def _get_dense_tensor_internal_helper(self, sparse_tensors, embedding_weights):

        sparse_ids = sparse_tensors

        # Return embedding lookup result.
        return embedding_ops.safe_embedding_lookup_sparse(
            embedding_weights=embedding_weights,
            sparse_ids=sparse_ids,
            sparse_weights=None,
            combiner=self.combiner,
            name='%s_weights' % self.name,
            max_norm=None)

    def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
        """Private method that follows the signature of get_dense_tensor."""
        embedding_weights = state_manager.get_variable(
            self, name='embedding_weights')
        return self._get_dense_tensor_internal_helper(sparse_tensors, embedding_weights)

    def get_dense_tensor(self, transformation_cache, state_manager):
        # Get sparse IDs and weights.
        prefixes = transformation_cache.get(self.key, state_manager)

        ragged_tensor_id_all = self.spm1.tokenize(prefixes)
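        # `ragged_tensor_id_all` is a RaggedTensor of BPE ids with shape
        # (batch, 1, None); it is converted to sparse for the embedding lookup.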

        sparse_ids = ragged_tensor_id_all.to_sparse()

        return self._get_dense_tensor_internal(sparse_ids, state_manager)

    def get_sequence_dense_tensor(self, transformation_cache, state_manager):
        """See `SequenceDenseColumn` base class."""
        prefixes = transformation_cache.get(self.key, state_manager)
        sparse_ids = self.spm1.tokenize(prefixes).to_sparse()

        dense_tensor = self._get_dense_tensor_internal(sparse_ids, state_manager)
        # Per-example lengths derived from the sparse token ids.
        sequence_length = fc_utils.sequence_length_from_sparse_tensor(sparse_ids)

        return fc2.SequenceDenseColumn.TensorSequenceLengthPair(
            dense_tensor=dense_tensor, sequence_length=sequence_length)

    def _get_config(self):
        """See `FeatureColumn` base class."""
        config = dict(zip(self._fields, self))
        config['initializer'] = tf.keras.initializers.serialize(self.initializer)
        return config

    @classmethod
    def _from_config(cls, config, custom_objects=None, columns_by_name=None):
        """See `FeatureColumn` base class."""
        kwargs = config.copy()
        if config['initializer']:
            kwargs['initializer'] = tf.keras.initializers.deserialize(
                config['initializer'], custom_objects=custom_objects)
        else:
            kwargs['initializer'] = None
        return cls(**kwargs)
vitalyli commented 2 years ago

Tensors flow as follows, given a batch size of 128 prefixes:

  1. Prefixes of shape (128, 1), pulled from the ELWC context.
  2. spm1.tokenize(prefixes) returns a RaggedTensor of BPE ids with shape (128, 1, None).
  3. ragged_tensor.to_sparse() gives shape (128, 1, None), e.g. (128, 1, 6), where the last dim depends on the longest tokenization in the batch.
  4. The embedding lookup pulls out (128, 1, 6, D), where D is the embedding dim; e.g. given D=10 the concrete case would be (128, 1, 6, 10).
  5. A mean over the Nx10 token embeddings is applied, giving (128, 1, 10), i.e. a 10-dim mean vector for each prefix string in the batch.
  6. A Tensorboard PCA projection of the 30k x 10 embedding table is attached as a screenshot; a standalone shape-check sketch follows below.
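
A minimal standalone sketch of the same flow (assumes a local spm.model file; the batch, vocab size, and D=10 are illustrative):

import tensorflow as tf
import tensorflow_text as text

with open('spm.model', 'rb') as f:
    sp = text.SentencepieceTokenizer(model=f.read())

prefixes = tf.constant([['coffee in mountain view']] * 128)   # (128, 1)
ragged_ids = sp.tokenize(prefixes)                            # (128, 1, None)
sparse_ids = ragged_ids.to_sparse()                           # (128, 1, max_tokens)

vocab_size = int(sp.vocab_size())
embedding_weights = tf.Variable(tf.random.normal([vocab_size, 10]))
embedded = tf.nn.safe_embedding_lookup_sparse(
    embedding_weights, sparse_ids, combiner='mean')           # (128, 1, 10)
print(embedded.shape)
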
maringeo commented 2 years ago

Hi @vitalyli, thank you for filing this issue! Since I can't spot any errors, I guess the issue is not about a bug in TF Hub? Are you trying to get feedback on the SPMEmbeddingColumn code?

maringeo commented 2 years ago

Hi @vitalyli, since the SPMEmbeddingColumn does not depend on TF Hub, https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/feature_column is probably the most suitable location for this contribution. Could you try opening a pull request to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/feature_column ?

I'll close the TF Hub issue for now, but please reopen if you have any questions.