tensorflow / transform

Input pipeline framework
Apache License 2.0

Transform graph returns an empty dictionary at serve_tf_examples_fn function #303

Closed: marlon-shiftone closed this issue 1 year ago

marlon-shiftone commented 1 year ago

I have created a preprocessing_fn function for processing audio data, along with several helper functions. I have put a lot of effort into making all of the functions graph-compatible, but I'm still getting the error: ValueError: Missing data for input "audio_xf". You passed a data dictionary with keys []. Expected the following keys: ['audio_xf']. The transform graph loaded after preprocessing the data, transformed_features = model.tft_layer(parsed_features), returns an empty dictionary. I'm not sure whether this is caused by the way my functions are defined, which is why I'm opening this issue. I would like to at least know the reason for the bug so I can try to solve it, even if that means spending more hours on it.
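The error means the dictionary passed to the model lacks the key that the model's input is named after ('audio_xf'), because the transform layer returned no keys at all. A minimal, hypothetical diagnostic (not part of TFX or Keras) that fails early and names the missing keys, which could be called on transformed_features just before model(...):

```python
from typing import Iterable, List, Mapping


def missing_model_inputs(features: Mapping[str, object],
                         expected_keys: Iterable[str]) -> List[str]:
    """Return the expected model input names absent from `features`, sorted."""
    return sorted(set(expected_keys) - set(features))


def check_transformed_features(features: Mapping[str, object],
                               expected_keys: Iterable[str]):
    """Hypothetical helper: raise a descriptive KeyError before calling the model."""
    missing = missing_model_inputs(features, expected_keys)
    if missing:
        raise KeyError(
            f"transform output lacks model inputs {missing}; "
            f"got keys {sorted(features)}"
        )
    return features
```

With the failure in this issue, check_transformed_features({}, ["audio_xf"]) would raise immediately, pointing at the transform output rather than at Keras input validation.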

Here is my preprocessing_fn function:


import tensorflow as tf
from tensorflow.python.framework import ops
from util import transformed_name
from tensorflow_transform import common
from tensorflow_transform import common_types
from typing import Dict, Optional, Any, Union

@common.log_api_use(common.ANALYZER_COLLECTION)
def normalize_tensorflow(
    S: common_types.TensorType,
    hparams: Dict[str, Any],
    name: Optional[str] = None) -> tf.Tensor:
  """Normalizes the input tensor `S`.

  This function subtracts `min_level_db` from `S` and then divides by the
  negative of `min_level_db`. The resulting values are then clipped to the
  range [0, 1].

  Args:
    S: A `Tensor`, `SparseTensor`, or `RaggedTensor`.
    hparams: A dictionary containing hyperparameters. Must include
      'min_level_db'.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` with the same type as `S`.

  Raises:
    KeyError: If 'min_level_db' is not found in `hparams`.
    TypeError: If the type of `S` is not supported.
  """
  with tf.compat.v1.name_scope(name, 'normalize'):
    min_level_db = hparams['min_level_db']
    return tf.clip_by_value((S - min_level_db) / -min_level_db, 0, 1)

@common.log_api_use(common.ANALYZER_COLLECTION)
def tf_log10(x: Union[ops.Tensor, tf.SparseTensor, tf.RaggedTensor], 
              name: Optional[str] = None) -> tf.Tensor:
    """
    Computes the base 10 logarithm of `x`.

    Args:
        x: A `Tensor`, `SparseTensor`, or `RaggedTensor`.
        name: (Optional) A name for this operation.

    Returns:
        A `Tensor` with the same type as `x`.

    Raises:
        TypeError: If the type of `x` is not supported.
    """
    with tf.compat.v1.name_scope(name, 'tf_log10'):
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
        return numerator / denominator

@common.log_api_use(common.ANALYZER_COLLECTION)
def amp_to_db_tensorflow(x: Union[ops.Tensor, tf.SparseTensor, tf.RaggedTensor], 
                          name: Optional[str] = None) -> tf.Tensor:
    """
    Converts amplitude to decibels for `x`.

    Args:
        x: A `Tensor`, `SparseTensor`, or `RaggedTensor`.
        name: (Optional) A name for this operation.

    Returns:
        A `Tensor` with the same type as `x`.

    Raises:
        TypeError: If the type of `x` is not supported.
    """
    with tf.compat.v1.name_scope(name, 'amp_to_db_tensorflow'):
        return 20 * tf_log10(tf.clip_by_value(tf.abs(x), 1e-5, 1e100))

@common.log_api_use(common.ANALYZER_COLLECTION)
def stft_tensorflow(signals: Union[ops.Tensor, tf.SparseTensor, tf.RaggedTensor], 
                    hparams: Dict[str, int],
                     name: Optional[str] = None) -> tf.Tensor:
    """
    Computes the short-time Fourier transform of `signals`.

    Args:
        signals: A `Tensor`, `SparseTensor`, or `RaggedTensor`.
        hparams: A dictionary with keys 'win_length', 'hop_length', and 'n_fft'.
        name: (Optional) A name for this operation.

    Returns:
        A `Tensor` with the same type as `signals`.

    Raises:
        TypeError: If the type of `signals` is not supported.
    """
    with tf.compat.v1.name_scope(name, 'stft_tensorflow'):
        return tf.signal.stft(
            signals,
            hparams['win_length'],
            hparams['hop_length'],
            hparams['n_fft'],
            pad_end=True,
            window_fn=tf.signal.hann_window,
        )

@common.log_api_use(common.ANALYZER_COLLECTION)
def mel_spectrogram(tensor: Union[ops.Tensor, tf.SparseTensor, tf.RaggedTensor], 
                    hparams: Dict[str, Union[int, float]],
                    name: Optional[str] = None) -> tf.Tensor:
    """
    Computes the Mel Spectrogram of `tensor`.

    Args:
        tensor: A `Tensor`, `SparseTensor`, or `RaggedTensor`.
        hparams: A dictionary with keys 'ref_level_db', 'num_mel_bins', 'sample_rate', 
                 'mel_lower_edge_hertz', 'mel_upper_edge_hertz'.
        name: (Optional) A name for this operation.

    Returns:
        A `Tensor` with the same type as `tensor`.

    Raises:
        TypeError: If the type of `tensor` is not supported.
    """
    with tf.compat.v1.name_scope(name, 'mel_spectrogram'):
        # Process the audio
        D = stft_tensorflow(tensor, hparams)
        S = amp_to_db_tensorflow(tf.abs(D)) - hparams['ref_level_db']
        S = normalize_tensorflow(S, hparams)

        # Calculate the mel weight matrix
        mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=hparams['num_mel_bins'],
            num_spectrogram_bins=S.shape[-1],
            sample_rate=hparams['sample_rate'],
            lower_edge_hertz=hparams['mel_lower_edge_hertz'],
            upper_edge_hertz=hparams['mel_upper_edge_hertz'],
            dtype=S.dtype,
        )

        # Apply the mel weight matrix to the spectrogram
        mel_spectrogram = tf.tensordot(S, mel_weight_matrix, 1)

        return tf.reshape(mel_spectrogram, [-1, 1067, 128])

@common.log_api_use(common.ANALYZER_COLLECTION)
def convert_labels(label: tf.Tensor, name: Optional[str] = None) -> tf.Tensor:
    """
    Converts the input 'label' tensor into a binary label and one-hot encodes it.

    Args:
        label: A `Tensor` of labels where 'human' corresponds to 1 and any other value to 0.
        name: (Optional) A name for this operation.

    Returns:
        A `Tensor` that is a one-hot encoding of the input `label`.

    Raises:
        TypeError: If the type of `label` is not supported.
    """
    with tf.compat.v1.name_scope(name, 'convert_labels'):
        binary_label = tf.where(label == 'human', 1, 0)
        label_one_hot = tf.one_hot(binary_label, depth=2)
        label_one_hot = tf.reshape(label_one_hot, [-1, 2])

        return label_one_hot

@common.log_api_use(common.ANALYZER_COLLECTION)
def decode_and_pad_audio(audio: Union[ops.Tensor, tf.SparseTensor, tf.RaggedTensor], 
                         hparams: Dict[str, int],
                         name: Optional[str] = None) -> tf.Tensor:
    """
    Decodes and pads `audio`.

    Args:
        audio: A `Tensor`, `SparseTensor`, or `RaggedTensor`.
        hparams: A dictionary with key 'sample'.
        name: (Optional) A name for this operation.

    Returns:
        A `Tensor` with the same type as `audio`.

    Raises:
        TypeError: If the type of `audio` is not supported.
    """
    with tf.compat.v1.name_scope(name, 'decode_and_pad_audio'):
        audio_decoded = tf.io.decode_raw(audio, tf.float32)
        audio_tensor = tf.expand_dims(audio_decoded, axis=0)
        max_padding_size = hparams['sample']

        def pad_audio():
            num_samples = tf.shape(audio_tensor)[1]
            padding_size = tf.math.abs(max_padding_size - num_samples)
            padded = tf.concat([audio_tensor, tf.zeros((1, padding_size), dtype=audio_tensor.dtype)], axis=1)
            return padded[:, :hparams['sample']]

        def slice_audio():
            return audio_tensor[:, :hparams['sample']]

        padded_audio = tf.cond(tf.math.greater_equal(hparams['sample'], tf.shape(audio_tensor)[1]), pad_audio, slice_audio)
        return padded_audio

def preprocessing_fn(inputs):
    @tf.function
    def process_example(example):

        hparams = {  
            # spectrogramming
            'sample':320000,
            'win_length' : 2048,
            'n_fft' : 2048,
            'hop_length': 300,
            'ref_level_db' : 50,
            'min_level_db' : -100,
            # mel scaling
            'num_mel_bins' : 128,
            'mel_lower_edge_hertz' : 20.0,
            'mel_upper_edge_hertz' : 4000.0,
            # inversion
            'power' : 1.5, # for spectral inversion
            'griffin_lim_iters' : 50,
            'pad':True,
            'sample_rate':8000,  # Sample rate parameter added
            #
        }

        audio, label, sample_rate, fmax, n_mels, hop_length, n_fft, fmin, sample = example

        #Decode and pad the vectors
        audio_reshaped = tf.map_fn(lambda x: decode_and_pad_audio(x, hparams), audio, dtype=tf.float32)
        print(audio_reshaped)
        # Encode the label
        one_hot = convert_labels(label)
        # Serialize the labels
        label_serialized = tf.io.serialize_tensor(one_hot)

        # Compute the mel spectrogram of each decoded, padded clip
        mels = tf.map_fn(lambda x: mel_spectrogram(x, hparams), audio_reshaped, dtype=tf.float32)

        # Serialize the audios
        mels_serialized = tf.io.serialize_tensor(mels)

        return mels_serialized, label_serialized

    # Apply process_example to each input example
    audio_xf, label_xf = tf.map_fn(process_example, (inputs['audio'], inputs['label'], inputs['sample_rate'], inputs['fmax'],
                                       inputs['n_mels'], inputs['hop_length'], inputs['n_fft'], inputs['fmin'],
                                       inputs['sample']), dtype=(tf.string, tf.string))

    return {transformed_name('audio'): audio_xf, transformed_name('label'): label_xf}
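The reshape to [-1, 1067, 128] in mel_spectrogram hard-codes the number of STFT frames. The pure-Python check below reproduces that arithmetic from the hparams values, assuming TensorFlow's pad_end=True framing rule of ceil(samples / frame_step) frames, and evaluates the normalize_tensorflow formula on a few dB values (already shifted by ref_level_db) to show the clipping to [0, 1]:

```python
import math

# Values copied from the hparams dictionary in preprocessing_fn.
SAMPLE = 320000
HOP_LENGTH = 300
N_FFT = 2048
NUM_MEL_BINS = 128
MIN_LEVEL_DB = -100.0

# With pad_end=True, tf.signal.stft produces ceil(samples / hop) frames.
num_frames = math.ceil(SAMPLE / HOP_LENGTH)
# A real FFT of length n_fft has n_fft // 2 + 1 frequency bins.
num_spectrogram_bins = N_FFT // 2 + 1


def normalize(s_db: float, min_level_db: float = MIN_LEVEL_DB) -> float:
    """Plain-float version of normalize_tensorflow's formula."""
    return min(max((s_db - min_level_db) / -min_level_db, 0.0), 1.0)


print(num_frames, num_spectrogram_bins)  # 1067 1025
print([normalize(v) for v in (-120.0, -100.0, -50.0, 0.0, 10.0)])
# [0.0, 0.0, 0.5, 1.0, 1.0]
```

So each clip becomes a (1067, 128) mel spectrogram only as long as 'sample' and 'hop_length' stay at these values; changing either silently breaks the hard-coded reshape.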

Here are the error logs:

Traceback (most recent call last):

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/Área de Trabalho/telnyx/audio_classification/data_ingestion.py:78
    context.run(trainer)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tfx/orchestration/experimental/interactive/notebook_utils.py:31 in run_if_ipython
    return fn(*args, **kwargs)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tfx/orchestration/experimental/interactive/interactive_context.py:164 in run
    execution_id = launcher.launch().execution_id

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tfx/orchestration/launcher/base_component_launcher.py:206 in launch
    self._run_executor(execution_decision.execution_id,

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tfx/orchestration/launcher/in_process_component_launcher.py:73 in _run_executor
    executor.Do(

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tfx/components/trainer/executor.py:178 in Do
    run_fn(fn_args)

  File ~/Área de Trabalho/telnyx/audio_classification/module.py:96 in run_fn
    _get_serve_tf_examples_fn(

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:1258 in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:1238 in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:763 in _initialize
    self._variable_creation_fn    # pylint: disable=protected-access

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py:171 in _get_concrete_function_internal_garbage_collected
    concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py:166 in _maybe_define_concrete_function
    return self._maybe_define_function(args, kwargs)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py:396 in _maybe_define_function
    concrete_function = self._create_concrete_function(

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py:300 in _create_concrete_function
    func_graph_module.func_graph_from_py_func(

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1214 in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:667 in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1200 in autograph_handler
    raise e.ag_error_metadata.to_exception(e)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1189 in autograph_handler
    return autograph.converted_call(

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:439 in converted_call
    result = converted_f(*effective_args, **kwargs)

  File /tmp/__autograph_generated_filek1qsmo_7.py:16 in tf__serve_tf_examples_fn
    outputs = ag__.converted_call(ag__.ld(model), (ag__.ld(transformed_features),), None, fscope)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:377 in converted_call
    return _call_unconverted(f, args, kwargs, options)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459 in _call_unconverted
    return f(*args)

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/keras/utils/traceback_utils.py:70 in error_handler
    raise e.with_traceback(filtered_tb) from None

  File ~/anaconda3/envs/tfx/lib/python3.9/site-packages/keras/engine/input_spec.py:197 in assert_input_compatibility
    raise ValueError(

ValueError: in user code:

    File "/home/marlon/Área de Trabalho/telnyx/audio_classification/module.py", line 57, in serve_tf_examples_fn  *
        outputs = model(transformed_features)
    File "/home/marlon/anaconda3/envs/tfx/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/home/marlon/anaconda3/envs/tfx/lib/python3.9/site-packages/keras/engine/input_spec.py", line 197, in assert_input_compatibility
        raise ValueError(

    ValueError: Missing data for input "audio_xf". You passed a data dictionary with keys []. Expected the following keys: ['audio_xf']

Thanks

singhniraj08 commented 1 year ago

@marlon-shiftone,

From the code, I can see you are using custom Python functions for your data transformations. You need to use tft.apply_pyfunc to apply custom functions to the input features. This function is meant to be used inside a preprocessing_fn and is a wrapper around tf.py_func.

Functions added this way can run in Transform and during training, but TensorFlow Serving will not be able to serve this graph. We already have an issue for that: #tfx/3178.

Note: This API can only be used when TF2 is disabled or tft_beam.Context.force_tf_compat_v1=True.

marlon-shiftone commented 1 year ago

@singhniraj08, thanks for the comment. I have already fixed it and moved forward with the code. For the full discussion, see this link:

https://discuss.tensorflow.org/t/tfx-transform-layer-returning-an-empty-dictionary/16781/9

singhniraj08 commented 1 year ago

@marlon-shiftone,

Please close this issue if it's resolved for you. Thanks.

marlon-shiftone commented 1 year ago

The issue has been solved. Moral: to save the graph (or to use it at the serving stage, which amounts to the same thing here), there should be no Python-side operations in the preprocessing_fn function. All of the Python methods have to be replaced by TensorFlow tf_compat_v1 methods. Also, the variables apparently have to be declared within the scope of the preprocessing_fn function. tf.py_func will work in the training phase, but not in the serving phase, for the first reason.
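The thread does not include the author's final code. As a rough sketch only, assuming inputs['audio'] holds raw little-endian float32 bytes and inputs['label'] holds strings (as the original code implies), here is the same pipeline written as batched TensorFlow ops directly inside preprocessing_fn, with the hyperparameters declared in its scope and no tf.function, tf.map_fn, print, or tf.py_func. transformed_name is a hypothetical stand-in for the author's util helper. The fixed_length argument of tf.io.decode_raw zero-pads or truncates each clip, replacing the tf.cond logic in decode_and_pad_audio. It returns float tensors directly instead of serialized strings, on the assumption that the Keras model consumes dense inputs.

```python
import tensorflow as tf


def transformed_name(key: str) -> str:
    # Hypothetical stand-in for util.transformed_name; the "_xf" suffix
    # matches the "audio_xf" key in the error message.
    return key + "_xf"


def preprocessing_fn(inputs):
    # Hyperparameters declared inside preprocessing_fn, as the thread suggests.
    sample = 320000
    win_length, hop_length, n_fft = 2048, 300, 2048
    ref_level_db, min_level_db = 50.0, -100.0
    num_mel_bins, sample_rate = 128, 8000
    lower_hz, upper_hz = 20.0, 4000.0

    # Decode raw float32 bytes; fixed_length zero-pads or truncates every
    # element to `sample` values (4 bytes per float32).
    audio_bytes = tf.reshape(inputs["audio"], [-1])
    audio = tf.io.decode_raw(audio_bytes, tf.float32, fixed_length=sample * 4)

    # Batched STFT over the last axis -> magnitude in dB -> shift and normalize.
    stft = tf.signal.stft(audio, win_length, hop_length, n_fft,
                          window_fn=tf.signal.hann_window, pad_end=True)
    db = 20.0 * tf.math.log(tf.maximum(tf.abs(stft), 1e-5)) / tf.math.log(10.0)
    s = tf.clip_by_value(
        (db - ref_level_db - min_level_db) / -min_level_db, 0.0, 1.0)

    # Project the linear-frequency bins onto the mel scale.
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, n_fft // 2 + 1, sample_rate, lower_hz, upper_hz)
    mels = tf.tensordot(s, mel_matrix, 1)

    # One-hot labels: "human" -> [0, 1], anything else -> [1, 0].
    labels = tf.reshape(inputs["label"], [-1])
    one_hot = tf.one_hot(tf.cast(tf.equal(labels, "human"), tf.int32), depth=2)

    return {transformed_name("audio"): mels, transformed_name("label"): one_hot}
```

Because every op here works on the whole batch, the function can also be called eagerly on a small batch of tensors to check output shapes before running it in a Transform component.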