keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

tf.data.Dataset Pipeline Preprocessing Custom Layer Recommendation #20071

Closed: apage224 closed this issue 2 months ago

apage224 commented 2 months ago

I am looking to create a number of custom preprocessing layers to be used in a TensorFlow tf.data pipeline. I initially assumed I could subclass keras.Layer and simply use any keras.ops operations in call. I only use Python parameters statically for conditional statements and otherwise use keras.ops for all operations (e.g. keras.ops.fori_loop). I can run the pipeline on its own successfully (e.g. next(iter(train_ds))). However, when I try to train a model with the TensorFlow backend, it raises several errors as it tries to create a symbolic graph of my preprocessing layers. These layers are not attached to the model; they are attached to the data pipeline via map. I had assumed the dataset pipeline would run on the CPU, but it seems that my layers are being placed on the GPU in a TF graph. If I force everything to run on the CPU, everything runs fine, but ideally I want model training to happen on the GPU and the data pipeline to happen on the CPU.

Are there any reference examples I could follow?

When I looked at the preprocessing layers included in Keras, they all seemed to use self.backend.numpy for operations (rather than keras.ops). I also noticed that the tf.data-safe layers subclass TFDataLayer, which isn't exposed in the public API. Is there a way to tell Keras that I want to run the entire preprocessing pipeline on the CPU?

Any help would be greatly appreciated.

Below are some layers that I implemented for reference (based on what I could find in both keras and keras-cv):


from typing import Callable
import keras

from .defines import NestedTensorValue

def tf_keras_map(f, xs):
    # NOTE: Workaround until (https://github.com/keras-team/keras/issues/20048)
    import tensorflow as tf

    xs = keras.tree.map_structure(tf.convert_to_tensor, xs)

    def get_fn_output_signature(x):
        out = f(x)
        return keras.tree.map_structure(tf.TensorSpec.from_tensor, out)

    # Grab single element unpacking and repacking single element
    xe = tf.nest.pack_sequence_as(xs, [y[0] for y in tf.nest.flatten(xs)])
    fn_output_signature = get_fn_output_signature(xe)
    return tf.map_fn(f, xs, fn_output_signature=fn_output_signature)

class BaseAugmentation(keras.layers.Layer):
    SAMPLES = "data"
    LABELS = "labels"
    TARGETS = "targets"
    ALL_KEYS = (SAMPLES, LABELS, TARGETS)
    TRANSFORMS = "transforms"
    IS_DICT = "is_dict"
    BATCHED = "is_batched"
    USE_TARGETS = "use_targets"
    NDIMS = 4  # Modify in subclass (includes batch size)

    def __init__(
        self,
        seed: int | None = None,
        auto_vectorize: bool = True,
        data_format: str | None = None,
        name: str | None = None,
        **kwargs,
    ):
        """BaseAugmentation acts as a base class for various custom augmentation layers.
        This class provides a common interface for augmenting samples and labels. In the future, we will
        add support for segmentation and bounding boxes.

        The only method that needs to be implemented by the subclass is:

        - augment_sample: Augment a single sample during training.

        Optionally, you can implement the following methods:

        - augment_label: Augment a single label during training.
        - get_random_transformations: Returns a nested structure of random transformations that should be applied to the batch.
            This is required to have unique transformations for each sample in the batch and maintain the same transformations for samples and labels.
        - batch_augment: Augment a batch of samples and labels during training. Needed if layer requires access to all samples (e.g. CutMix).

        By default, this method will coerce the input into a batch as well as a nested structure of inputs.
        If auto_vectorize is set to True, the augment_sample and augment_label methods will be vectorized using keras.ops.vectorized_map.
        Otherwise, it will use keras.ops.map which runs sequentially.

        Args:
            seed (int | None): Random seed. Defaults to None.
            auto_vectorize (bool): If True, augment_sample and augment_label methods will be vectorized using keras.ops.vectorized_map.
                Otherwise, it will use keras.ops.map which runs sequentially. Defaults to True.
            data_format (str | None): Data format. Defaults to None. Will use keras.backend.image_data_format() if None.
            name (str | None): Layer name. Defaults to None.

        """
        super().__init__(name=name, **kwargs)
        self.seed = seed  # kept so get_config() can serialize it
        self._random_generator = keras.random.SeedGenerator(seed)
        self.data_format = data_format or keras.backend.image_data_format()
        self.built = True
        self.training = True
        self.auto_vectorize = auto_vectorize

    def _map_fn(
        self, func: Callable[[NestedTensorValue], keras.KerasTensor], inputs: NestedTensorValue
    ) -> keras.KerasTensor:
        """Calls appropriate mapping function with given inputs.

        Args:
            func (Callable): Function to be mapped.
            inputs (dict): Dictionary containing inputs.

        Returns:
            KerasTensor: Augmented samples or labels
        """
        if self.auto_vectorize:
            return keras.ops.vectorized_map(func, inputs)
        # NOTE: Workaround until (https://github.com/keras-team/keras/issues/20048)
        if keras.backend.backend() == "tensorflow":
            return tf_keras_map(func, inputs)
        return keras.ops.map(func, inputs)

    def call(self, inputs: NestedTensorValue, training: bool = True) -> NestedTensorValue:
        """This method will serve as the main entry point for the layer. It will handle the input formatting and output formatting.

        Args:
            inputs (NestedTensorValue): Inputs to be augmented.
            training (bool): Whether the model is training or not.

        Returns:
            NestedTensorValue: Augmented samples or labels.
        """
        self.training = training
        inputs, metadata = self._format_inputs(inputs)
        return self._format_outputs(self.batch_augment(inputs), metadata)

    def augment_sample(self, inputs: NestedTensorValue) -> keras.KerasTensor:
        """Augment a single sample during training.

        !!! note

                This method should be implemented by the subclass.

        Args:
            inputs (NestedTensorValue): Single sample.

        Returns:
            KerasTensor: Augmented sample.
        """
        return inputs[self.SAMPLES]

    def augment_samples(self, inputs: NestedTensorValue) -> keras.KerasTensor:
        """Augment a batch of samples during training.

        Args:
            inputs (NestedTensorValue): Batch of samples.

        Returns:
            KerasTensor: Augmented batch of samples.
        """
        return self._map_fn(self.augment_sample, inputs=inputs)

    def augment_label(self, inputs: NestedTensorValue) -> keras.KerasTensor:
        """Augment a single label during training.

        !!! note

            Implement this method if you need to augment labels.

        Args:
            inputs (NestedTensorValue): Single label.

        Returns:
            keras.KerasTensor: Augmented label.
        """
        return inputs[self.LABELS]

    def augment_labels(self, inputs: NestedTensorValue) -> keras.KerasTensor:
        """Augment a batch of labels during training.

        Args:
            inputs (NestedTensorValue): Batch of labels.

        Returns:
            keras.KerasTensor: Augmented batch of labels.
        """
        return self._map_fn(self.augment_label, inputs=inputs)

    def get_random_transformations(self, input_shape: tuple[int, ...]) -> NestedTensorValue:
        """Generates random transformations needed for augmenting samples and labels.

        Args:
            input_shape (tuple[int,...]): Shape of the input (N, ...).

        Returns:
            NestedTensorValue: Batch of random transformations.

        !!! note
                This method should be implemented by the subclass if the layer requires random transformations.
        """
        return keras.ops.arange(input_shape[0])

    def batch_augment(self, inputs: NestedTensorValue) -> NestedTensorValue:
        """Handles processing entire batch of samples and labels in a nested structure.
        Responsible for calling augment_samples and augment_labels.

        Args:
            inputs (NestedTensorValue): Batch of samples and labels.

        Returns:
            NestedTensorValue: Augmented batch of samples and labels.
        """
        samples = inputs.get(self.SAMPLES, None)
        labels = inputs.get(self.LABELS, None)
        result = {}

        transformations = self.get_random_transformations(input_shape=keras.ops.shape(samples))

        result[self.SAMPLES] = self.augment_samples(inputs={self.SAMPLES: samples, self.TRANSFORMS: transformations})

        if labels is not None:
            result[self.LABELS] = self.augment_labels(inputs={self.LABELS: labels, self.TRANSFORMS: transformations})
        # END IF

        # preserve any additional inputs unmodified by this layer.
        for key in inputs.keys() - result.keys():
            result[key] = inputs[key]
        return result

    def _format_inputs(self, inputs: NestedTensorValue) -> tuple[NestedTensorValue, dict[str, bool]]:
        """Validate and force inputs to be batched and placed in structured format.

        Args:
            inputs (NestedTensorValue): Inputs to be formatted.

        Returns:
            tuple[NestedTensorValue, dict[str, bool]]: Formatted inputs and metadata.

        """
        metadata = {self.IS_DICT: True, self.USE_TARGETS: False, self.BATCHED: True}
        if not isinstance(inputs, dict):
            inputs = {self.SAMPLES: inputs}
            metadata[self.IS_DICT] = False
        else:
            inputs = dict(inputs)  # shallow copy so the caller's dict is not mutated below

        samples = inputs.get(self.SAMPLES, None)
        if samples is None:
            raise ValueError(f"Expect the inputs to have key {self.SAMPLES}. Got keys: {list(inputs.keys())}")
        # END IF
        # len(shape) works for TF, NumPy, JAX, and torch tensors alike; `.shape.rank` is TF-only
        ndim = len(samples.shape)
        if ndim not in (self.NDIMS - 1, self.NDIMS):
            raise ValueError(f"Invalid input shape: {samples.shape}")
        # END IF
        if ndim == self.NDIMS - 1:
            metadata[self.BATCHED] = False
            # Expand dims to make it batched for keys of interest
            for key in set(self.ALL_KEYS).intersection(inputs.keys()):
                if inputs[key] is not None:
                    inputs[key] = keras.ops.expand_dims(inputs[key], axis=0)
                # END IF
            # END FOR
        # END IF
        return inputs, metadata

    def _format_outputs(self, output: NestedTensorValue, metadata: dict[str, bool]) -> NestedTensorValue:
        """Format the output to match the initial input format.

        Args:
            output: Output to be formatted.
            metadata: Metadata used for formatting.

        Returns:
            Output in the original format.
        """
        if not metadata[self.BATCHED]:
            for key in set(self.ALL_KEYS).intersection(output.keys()):
                if output[key] is not None:  # check if tensor
                    output[key] = keras.ops.squeeze(output[key], axis=0)
                # END IF
            # END FOR
        # END IF
        if not metadata[self.IS_DICT]:
            return output[self.SAMPLES]
        if metadata[self.USE_TARGETS]:
            output[self.TARGETS] = output[self.LABELS]
            del output[self.LABELS]
        return output

    def compute_output_shape(self, input_shape, *args, **kwargs):
        """By default assumes the shape of the input is the same as the output.

        Args:
            input_shape: Shape of the input.

        Returns:
            tuple: Shape of the output

        !!! note
                This method should be implemented by the subclass if the output shape is different from the input shape.
        """
        return input_shape

    def get_config(self):
        """Serialize the layer configuration."""
        config = super().get_config()
        config.update(
            {
                "seed": self.seed,
                "auto_vectorize": self.auto_vectorize,
                "data_format": self.data_format,
            }
        )
        return config

class BaseAugmentation1D(BaseAugmentation):
    NDIMS = 3  # (N, T, C) or (N, C, T)

    def __init__(self, **kwargs):
        """BaseAugmentation1D acts as a base class for various custom augmentation layers.
        This class provides a common interface for augmenting samples and labels. In the future, we will
        add support for segmentation and 1D bounding boxes.

        The only method that needs to be implemented by the subclass is:

        - augment_sample: Augment a single sample during training.

        Optionally, you can implement the following methods:

        - augment_label: Augment a single label during training.
        - get_random_transformations: Returns a nested structure of random transformations that should be applied to the batch.
            This is required to have unique transformations for each sample in the batch and maintain the same transformations for samples and labels.
        - batch_augment: Augment a batch of samples and labels during training. Needed if layer requires access to all samples (e.g. CutMix).

        By default, this method will coerce the input into a batch as well as a nested structure of inputs.
        If auto_vectorize is set to True, the augment_sample and augment_label methods will be vectorized using keras.ops.vectorized_map.
        Otherwise, it will use keras.ops.map which runs sequentially.

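        Example:
        ```python
        import keras
        import numpy as np

        class NormalizeLayer1D(BaseAugmentation1D):

            def __init__(self, epsilon=1e-6, **kwargs):
                super().__init__(**kwargs)
                self.epsilon = epsilon

            def augment_sample(self, inputs):
                sample = inputs["data"]
                # Per-channel statistics along the time axis of a single sample
                mu = keras.ops.mean(sample, axis=self.data_axis, keepdims=True)
                std = keras.ops.std(sample, axis=self.data_axis, keepdims=True)
                return (sample - mu) / (std + self.epsilon)

        x = np.random.rand(100, 3)
        lyr = NormalizeLayer1D()
        y = lyr(x, training=True)
        ```
        """
        super().__init__(**kwargs)

        if self.data_format == "channels_first":
            self.data_axis = -1
            self.ch_axis = -2
        else:
            self.data_axis = -2
            self.ch_axis = -1
        # END IF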

class BaseAugmentation2D(BaseAugmentation):
    NDIMS = 4  # (N, H, W, C) or (N, C, H, W)

    def __init__(self, **kwargs):
        """BaseAugmentation2D acts as a base class for various custom augmentation layers.
        This class provides a common interface for augmenting samples and labels. In the future, we will
        add support for segmentation and 2D bounding boxes.

        The only method that needs to be implemented by the subclass is:

        - augment_sample: Augment a single sample during training.

        Optionally, you can implement the following methods:

        - augment_label: Augment a single label during training.
        - get_random_transformations: Returns a nested structure of random transformations that should be applied to the batch.
            This is required to have unique transformations for each sample in the batch and maintain the same transformations for samples and labels.
        - batch_augment: Augment a batch of samples and labels during training. Needed if layer requires access to all samples (e.g. CutMix).

        By default, this method will coerce the input into a batch as well as a nested structure of inputs.
        If auto_vectorize is set to True, the augment_sample and augment_label methods will be vectorized using keras.ops.vectorized_map.
        Otherwise, it will use keras.ops.map which runs sequentially.

        Example:
        ```python
        import keras
        import numpy as np

        class NormalizeLayer2D(BaseAugmentation2D):

            def __init__(self, epsilon=1e-6, name=None, **kwargs):
                super().__init__(name=name, **kwargs)
                self.epsilon = epsilon

            def augment_sample(self, inputs):
                sample = inputs["data"]
                # Per-channel statistics over the spatial axes of a single image
                spatial_axes = (self.height_axis, self.width_axis)
                mu = keras.ops.mean(sample, axis=spatial_axes, keepdims=True)
                std = keras.ops.std(sample, axis=spatial_axes, keepdims=True)
                return (sample - mu) / (std + self.epsilon)

        x = np.random.rand(32, 32, 3)
        lyr = NormalizeLayer2D()
        y = lyr(x, training=True)
        ```
        """
        super().__init__(**kwargs)

        if self.data_format == "channels_first":
            self.ch_axis = -3
            self.height_axis = -2
            self.width_axis = -1
        else:
            self.ch_axis = -1
            self.height_axis = -3
            self.width_axis = -2
        # END IF

class RandomNoiseDistortion1D(BaseAugmentation1D):
    sample_rate: float
    frequency: tuple[float, float]
    amplitude: tuple[float, float]
    noise_type: str

    def __init__(
        self,
        sample_rate: float = 1,
        frequency: float | tuple[float, float] = 100,
        amplitude: float | tuple[float, float] = 0.1,
        noise_type: str = "normal",
        **kwargs,
    ):
        """Apply random noise distortion to the 1D input.
        Noise points are first generated at the given frequency resolution, with amplitudes drawn according to noise_type.
        The noise points are then interpolated to match the input duration and added to the input.

        Args:
            sample_rate (float): Sample rate of the input.
            frequency (float|tuple[float,float]): Frequency of the noise in Hz. If tuple, frequency is randomly picked between the values.
            amplitude (float|tuple[float,float]): Amplitude of the noise. If tuple, amplitude is randomly picked between the values.
            noise_type (str): Type of noise to generate. Currently only "normal" is supported.

        Example:
        ```python
        import numpy as np

        sample_rate = 100  # Hz
        duration = 3 * sample_rate  # 3 seconds
        sig_freq = 10  # Hz
        sig_amp = 1  # Signal amplitude
        noise_freq = (1, 2)  # Noise frequency range
        noise_amp = (1, 2)  # Noise amplitude range
        x = sig_amp * np.sin(2 * np.pi * sig_freq * np.arange(duration) / sample_rate).reshape(-1, 1)
        lyr = RandomNoiseDistortion1D(sample_rate=sample_rate, frequency=noise_freq, amplitude=noise_amp)
        y = lyr(x, training=True)
        ```
        """

        super().__init__(**kwargs)

        self.sample_rate = sample_rate
        # parse_factor is a helper from the author's package (not shown); it presumably
        # normalizes a scalar or (min, max) pair into a (min, max) tuple.
        self.frequency = parse_factor(frequency, min_value=None, max_value=sample_rate / 2, param_name="frequency")
        self.amplitude = parse_factor(amplitude, min_value=None, max_value=None, param_name="amplitude")
        self.noise_type = noise_type

    def get_random_transformations(self, input_shape: tuple[int, int, int]):
        """Generate the noise distortion tensor.

        Args:
            input_shape (tuple[int, ...]): Input shape.

        Returns:
            dict: Dictionary containing the noise tensor.
        """
        batch_size = input_shape[0]
        duration_size = input_shape[self.data_axis]
        ch_size = input_shape[self.ch_axis]

        # Add one period to the noise and clip later
        if self.frequency[0] == self.frequency[1]:
            frequency = self.frequency[0]
        else:
            frequency = keras.random.uniform(
                shape=(), minval=self.frequency[0], maxval=self.frequency[1], seed=self._random_generator
            )
        if self.amplitude[0] == self.amplitude[1]:
            amplitude = self.amplitude[0]
        else:
            amplitude = keras.random.uniform(
                shape=(), minval=self.amplitude[0], maxval=self.amplitude[1], seed=self._random_generator
            )

        noise_duration = keras.ops.cast((duration_size / self.sample_rate) * frequency + frequency, dtype="int32")

        # Lay the noise out as a one-pixel-high image so keras.ops.image.resize can interpolate it:
        # channels_first -> (N, C, H=1, W), channels_last -> (N, H=1, W, C)
        if self.data_format == "channels_first":
            noise_shape = (batch_size, ch_size, 1, noise_duration)
        else:
            noise_shape = (batch_size, 1, noise_duration, ch_size)

        if self.noise_type == "normal":
            noise_pts = keras.random.normal(noise_shape, stddev=amplitude, seed=self._random_generator)
        else:
            raise ValueError(f"Invalid noise type: {self.noise_type}")

        # keras.ops doesn't contain a low-level interpolate, so we leverage the
        # image module and fix the height to 1 as a workaround
        noise = keras.ops.image.resize(
            noise_pts,
            size=(1, duration_size),
            interpolation="bicubic",
            crop_to_aspect_ratio=False,
            data_format=self.data_format,
        )
        # Remove the dummy height dimension (axis 2 for channels_first, axis 1 for channels_last)
        height_axis = 2 if self.data_format == "channels_first" else 1
        noise = keras.ops.squeeze(noise, axis=height_axis)
        return {"noise": noise}

    def augment_samples(self, inputs) -> keras.KerasTensor:
        """Augment all samples in the batch at once, as it's faster."""
        samples = inputs[self.SAMPLES]
        if self.training:
            noise = inputs[self.TRANSFORMS]["noise"]
            return samples + noise
        return samples

    def get_config(self):
        """Serialize the layer configuration to a JSON-compatible dictionary."""
        config = super().get_config()
        config.update(
            {
                "sample_rate": self.sample_rate,
                "frequency": self.frequency,
                "amplitude": self.amplitude,
                "noise_type": self.noise_type,
            }
        )
        return config

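
The noise-generation steps described in the RandomNoiseDistortion1D docstring (coarse noise points at the noise frequency, interpolated up to the signal length, then added to the input) can be sanity-checked outside Keras. The NumPy sketch below assumes channels-last input of shape (T, C) and a fixed frequency and amplitude rather than sampled ranges; it uses linear interpolation (np.interp) in place of the layer's bicubic resize, and noise_distortion_1d is a hypothetical helper name.

```python
import numpy as np


def noise_distortion_1d(x, sample_rate, frequency, amplitude, rng):
    """Hypothetical helper: add smooth random noise to a channels-last signal of shape (T, C).

    Linear interpolation stands in for the layer's bicubic image resize.
    """
    duration_size, ch_size = x.shape
    # Number of coarse noise points, using the same formula as the layer.
    noise_duration = int((duration_size / sample_rate) * frequency + frequency)
    noise_pts = rng.normal(0.0, amplitude, size=(noise_duration, ch_size))
    # Stretch the coarse noise points across the full signal length, channel by channel.
    src = np.linspace(0.0, 1.0, noise_duration)
    dst = np.linspace(0.0, 1.0, duration_size)
    noise = np.stack([np.interp(dst, src, noise_pts[:, c]) for c in range(ch_size)], axis=-1)
    return x + noise


sample_rate = 100  # Hz
duration = 3 * sample_rate  # 3 seconds
x = np.sin(2 * np.pi * 10 * np.arange(duration) / sample_rate).reshape(-1, 1)
y = noise_distortion_1d(x, sample_rate, frequency=2, amplitude=0.5, rng=np.random.default_rng(0))
print(y.shape)  # (300, 1)
```

Because the noise is interpolated from only a handful of points, it varies slowly compared with the 10 Hz signal, which is the baseline-wander effect the layer is meant to simulate.
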
james77777778 commented 2 months ago

You might want to check my work: https://github.com/james77777778/keras-aug

I have borrowed the idea of DynamicBackend in Keras and extended it a lot for images and bounding boxes.

hertschuh commented 2 months ago

Hi @apage224 ,

The issue is that tensors do indeed end up being sent to the GPU because of some logic in the Layer class.

Here is how Keras NLP makes it work: https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/src/layers/preprocessing/preprocessing_layer.py#L35-L49

For each layer that you want to use in tf.data.Dataset, you need this in the __init__:

self._convert_input_args = False
self._allow_non_tensor_positional_args = True

And then, in call, you need to wrap the computations in a with tf.device("cpu"): block so that tensors stay on the CPU.

apage224 commented 2 months ago

Thanks @james77777778, I really like your augmentation package. I think I need to handle the dynamic backend as you've mentioned.

Hi @hertschuh, that explains why my GPU memory usage was so high. I think I can handle this by making a base layer specifically for tf.data pipelines. Two follow-up questions:

apage224 commented 2 months ago

Also, it might be useful for Keras to expose a base layer strictly for tf.data pipelines that handles all the underlying logic. keras.src already has something, but it isn't in the public API if I remember correctly.

hertschuh commented 2 months ago

@apage224 ,

> If my base layer is dynamically setting the backend to TensorFlow to execute the pipeline, is it safe to use keras.ops within the subclass layers, or do I need to use strictly backend operations?

I thought you were using TensorFlow as a backend (not just for preprocessing). If so, you can use keras.ops safely.

If you are using a different backend for the model itself, then you'll need something like TFDataLayer, which is based on DynamicBackend. But then you cannot use keras.ops; you have to use self.backend.numpy. As was noted, these classes are not public. We should think about whether, and what, to expose publicly.
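If the model runs on JAX or PyTorch and you would rather not depend on the private TFDataLayer/DynamicBackend classes, one alternative (a different technique from the one described above, not a Keras API) is to write the tf.data map function with plain TensorFlow ops, since tf.data executes TensorFlow ops regardless of the Keras backend. A minimal sketch, where random_gain and its gain range are hypothetical:

```python
import tensorflow as tf


def random_gain(features, labels, min_gain=0.5, max_gain=1.5):
    """Hypothetical augmentation that scales each signal by one random gain.

    Only TensorFlow ops are used, so the function runs inside tf.data no matter
    which Keras backend (JAX, PyTorch, TensorFlow) the model itself uses.
    """
    gain = tf.random.uniform([], minval=min_gain, maxval=max_gain)
    return features * gain, labels


signals = tf.ones((8, 100, 1))  # 8 signals, 100 time steps, 1 channel
labels = tf.range(8)
ds = (
    tf.data.Dataset.from_tensor_slices((signals, labels))
    .map(random_gain, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(4)
)
# as_numpy_iterator() shows the framework-neutral NumPy batches a model would receive.
batches = list(ds.as_numpy_iterator())
```

Keras 3's model.fit accepts a tf.data.Dataset on every backend, so the same pipeline can feed a JAX or PyTorch model; the trade-off is that the augmentation code cannot be shared with keras.ops-based layers.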

> I think I'm missing something basic, but if my pipeline layers are all wrapped in tf.device("cpu"), do I need anything to handle the transition to the GPU when the data is then fed into a model on the GPU? Something like tf.data.Dataset prefetch?

The tensors will be moved to the GPU when fed to the model, so I don't think anything is strictly needed. You can add tf.data.Dataset.prefetch, but that stays purely on the CPU, or you can try tf.data.experimental.prefetch_to_device.
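A minimal sketch of the end of such a pipeline, assuming (features, labels) tensors and a hypothetical preprocess step. Dataset.prefetch overlaps CPU preprocessing with training; tf.data.experimental.prefetch_to_device additionally stages batches on the GPU and, according to the TensorFlow documentation, has to be the last transformation in the pipeline. The sketch only uses it when TensorFlow can see a GPU; make_pipeline and the "/gpu:0" device string are assumptions.

```python
import tensorflow as tf


def make_pipeline(features, labels, batch_size=4):
    """Hypothetical helper: CPU-pinned preprocessing followed by prefetching."""

    def preprocess(x, y):
        # Hypothetical preprocessing step; the device scope keeps it on the CPU.
        with tf.device("cpu"):
            return tf.cast(x, tf.float32) / 255.0, y

    ds = (
        tf.data.Dataset.from_tensor_slices((features, labels))
        .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
    )
    if tf.config.list_physical_devices("GPU"):
        # Stage upcoming batches in GPU memory; this must be the final transformation.
        return ds.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
    # CPU-only machine: plain prefetch still overlaps preprocessing with training.
    return ds.prefetch(tf.data.AUTOTUNE)


ds = make_pipeline(tf.fill((10, 3), 255), tf.range(10))
batches = list(ds)
```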

apage224 commented 2 months ago

Thanks @hertschuh,

Yes, currently I am using TensorFlow as my main backend, but I'd like to future-proof for if/when I migrate to JAX or PyTorch. I went ahead and made my own version of TFDataLayer. Everything seems to be working okay so far. Thanks for the help.
