tensorflow / tfx

TFX is an end-to-end platform for deploying production ML pipelines
https://tensorflow.github.io/tfx/
Apache License 2.0

[Discussion] Random Under-sampling for Imbalanced Datasets #3831

Closed RossKohler closed 1 year ago

RossKohler commented 3 years ago

I haven't seen this discussed anywhere, so perhaps someone could give me some insight here. When performing classification tasks, it's very rare to have a balanced dataset (where every class is equally represented). When training a model, we can mitigate the model's bias toward the majority classes through random under-sampling: in every 'epoch' we ensure that each class is equally represented by under-sampling every class down to the size of the least-represented class. This means that examples from the under-represented classes are passed through the model multiple times over the course of training, but the model doesn't develop a bias toward predicting the majority class at inference.

The imbalanced-learn library is useful for achieving this. How can the same be achieved in TFX? I'm not entirely sure I understand how the Trainer component selects examples at the moment. Is it at random, or sequentially? How might random under-sampling be implemented in TFX?
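
For reference, the imbalanced-learn pattern I have in mind looks roughly like this (X and y here are just an illustrative feature matrix and label vector):

from imblearn.under_sampling import RandomUnderSampler

# Randomly drop examples from the majority classes so that every class ends up
# with as many examples as the least-represented class.
rus = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = rus.fit_resample(X, y)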

Any insight would be greatly appreciated. Thanks!

arghyaganguly commented 3 years ago

Related #1549

axeltidemann commented 3 years ago

This is a very common pattern for real-life datasets. It would be great to have a solution to this without having to write a custom Trainer or oversample the data to balance it out.

axeltidemann commented 3 years ago

It feels like the mechanism for reading spans should be able to provide this functionality.

axeltidemann commented 3 years ago

I am seeing two possible ways to solve this:

  1. Implement something similar to a dataset factory (used in the Chicago example pipeline) that does balancing on the fly.
  2. Implement a custom component that oversamples the train set and undersamples the eval and test sets. The output artifacts (Examples) are then used by downstream components to train and evaluate.

Option 1 would be more elegant and less wasteful in terms of storage, but option 2 is probably easier to implement and doesn't require touching the framework at all.
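
As a very rough sketch (not an existing TFX API), option 1 could boil down to something like this inside the Trainer's input function, assuming the transformed examples have already been split into one tf.data dataset per class:

import tensorflow as tf

def balanced_dataset(per_class_datasets):
    # Repeat each per-class dataset indefinitely and draw from them with equal
    # probability, so every batch is approximately class-balanced on the fly.
    n = len(per_class_datasets)
    repeated = [ds.repeat() for ds in per_class_datasets]
    return tf.data.experimental.sample_from_datasets(
        repeated, weights=[1.0 / n] * n)

The per-class datasets themselves could be built by filtering on the label column, which avoids materializing a rebalanced copy of the data.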

Any questions or comments would be appreciated.

kindalime commented 3 years ago

Hello! I would like to tackle this issue.

rcrowe-google commented 3 years ago

Could I suggest that this be created as a TFX-Addons project, so that the wider community can benefit?

casassg commented 3 years ago

Idea for 1:

axeltidemann commented 3 years ago

This is how I implemented option 2:

import os

import tensorflow_data_validation as tfdv
import tensorflow_transform as tft
from tfx.dsl.component.experimental.annotations import OutputDict
from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component
from tfx.types.standard_artifacts import Examples
from tfx.types.standard_artifacts import Schema
from tfx.types.standard_artifacts import ExampleStatistics
from tfx.types.standard_artifacts import ExternalArtifact
from tfx.types import artifact_utils
import tensorflow as tf
import tfx
import numpy as np

def get_single_file(path):
    files = tf.io.gfile.listdir(path)
    assert len(files) == 1, f'{path} has more than 1 file: {files}'
    return os.path.join(path, files.pop())

def read_data(examples, schema, split):
    raw_schema = tfdv.load_schema_text(get_single_file(schema.uri))
    parsed_schema = tft.tf_metadata.schema_utils.schema_as_feature_spec(raw_schema).feature_spec

    def decode(record_bytes):
        return tf.io.parse_single_example(record_bytes, parsed_schema)

    uri = tfx.types.artifact_utils.get_split_uri([examples], split)
    dataset = tf.data.TFRecordDataset(tf.data.Dataset.list_files(f'{uri}/*'),
                                      compression_type='GZIP').map(decode)

    return dataset, parsed_schema

def write_data(dataset, out_examples, split, parsed_schema, filename):

    def _bytes_feature(value):
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.squeeze(value).numpy()]))

    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    func_mapper = {
        tf.int64: _int64_feature,
        tf.float32: _float_feature,
        tf.string: _bytes_feature
    } 

    # Fix the key ordering up front, since the map step produces a new dict for every element.
    keys = parsed_schema.keys()

    def serialize(*args):
        feature = { key: func_mapper[tensor.dtype](tensor.numpy())
                    for key, tensor in zip(keys, args) }

        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def tf_serialize(x):
        tensors = [ x[key] for key in keys ]

        return tf.py_function(serialize,
                              tensors,
                              tf.string)

    dataset = dataset.map(tf_serialize)

    # TODO: sharding. Make guesses about file size. Should be at least in the 10MB range each.
    uri = tfx.types.artifact_utils.get_split_uri([out_examples], split)
    path = f'{uri}/{filename}.gz'
    writer = tf.data.experimental.TFRecordWriter(path, compression_type='GZIP')
    writer.write(dataset)

@component
def Balancer(
        examples: InputArtifact[Examples],
        schema: InputArtifact[Schema],
        statistics: InputArtifact[ExampleStatistics],
        balanced_examples: OutputArtifact[Examples],
        column: Parameter[str]
        ) -> None:

    splits_list = artifact_utils.decode_split_names(
        split_names=examples.split_names)

    balanced_examples.split_names = artifact_utils.encode_split_names(
        splits=splits_list)

    for split in splits_list:

        uri = tfx.types.artifact_utils.get_split_uri([statistics], split)
        stats = tfdv.load_stats_binary(get_single_file(uri))

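        # Pull per-class sample counts from the standard histogram that
        # StatisticsGen computed over the (integer-valued) label column.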
        for dataset in stats.datasets:
            for feature in dataset.features:
                if feature.path.step == [column]:
                    for histogram in feature.num_stats.histograms:
                    if histogram.type == histogram.HistogramType.STANDARD:
                        print(histogram)
                        sample_counts = [ bucket.sample_count for bucket in histogram.buckets ]
                        original_size = feature.num_stats.common_stats.tot_num_values

        max_count = max(sample_counts)
        max_category = np.argmax(sample_counts) 
        min_count = min(sample_counts)
        min_category = np.argmin(sample_counts)
        n_categories = len(sample_counts)

        print(f'Biggest category: {max_category}, count: {max_count}, smallest category: {min_category}, count: {min_count}')

        new_oversampled_size = int(max_count*n_categories)
        new_undersampled_size = int(min_count*n_categories)
        oversampled_size_increase = new_oversampled_size/original_size
        undersampled_size_decrease = new_undersampled_size/original_size

        dataset, parsed_schema = read_data(examples, schema, split)

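        # Build one infinitely repeated dataset per class, then sample from them
        # with equal weights so that every class is equally represented.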
        targets_only = dataset.map(lambda x: tf.squeeze(x[column]))

        uniques = targets_only.apply(tf.data.experimental.unique())
        datasets = []

        for u in uniques:
            print(f'Filtering class {u}')
            datasets.append(dataset.filter(lambda x: tf.squeeze(x[column]) == u).repeat())

        weights = np.ones(n_categories)/n_categories
        sampled = tf.data.experimental.sample_from_datasets(datasets, weights)

        if 'train' in split:
            print(f'{split}: size increase from {original_size} to {new_oversampled_size} '
                  f'({oversampled_size_increase:.1f} times)')
            sampled = sampled.take(new_oversampled_size)            
        else:
            print(f'{split}: size decrease from {original_size} to {new_undersampled_size} '
                  f'({100*undersampled_size_decrease:.1f}%)')
            sampled = sampled.take(new_undersampled_size)

        write_data(sampled, balanced_examples, split, parsed_schema, 'balanced')

It is used in the following way, provided that there is a Transform component by the name of transform earlier in the pipeline:

transformed_statistics_gen = StatisticsGen(examples=transform.outputs.transformed_examples)
transformed_statistics_gen.id = 'statisticsgen-transformed-data'

transformed_schema_gen = SchemaGen(
  statistics=transformed_statistics_gen.outputs.statistics, infer_feature_shape=True)
transformed_schema_gen.id = 'schemagen-transformed-data'

# Balances training data (oversampling) and eval/test data (undersampling)                                                                                                                                                                                  
balancer = Balancer(examples=transform.outputs.transformed_examples,
                      schema=transformed_schema_gen.outputs.schema,
                      statistics=transformed_statistics_gen.outputs.statistics,
                      column=TARGET-TO-BALANCE-ON)

where TARGET-TO-BALANCE-ON is the name of an integer-valued feature. The output can then be fed into the Trainer component like this:

Trainer(...
examples=balancer.outputs.balanced_examples,
...)

Comments and criticism are welcome. I hope it can be of use to others; I like the TFX-Addons idea.

rcrowe-google commented 3 years ago

This is currently a proposed project in TFX-Addons.

kindalime commented 3 years ago

Hello! I've completed my own initial implementation of the project, which you can find here. I know that @axeltidemann has completed their own version of the undersampling component above; because our versions use different backends (yours uses pure TF while mine uses an Apache Beam pipeline), would you be open to somehow merging our implementations?
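
(For context, the general idea on the Beam side is per-class sampling over keyed examples. A minimal sketch of that idea, not the component's actual code, with target_count as a placeholder:)

import apache_beam as beam
import tensorflow as tf

def undersample_per_class(serialized_examples, label_key, target_count):
    # serialized_examples: a PCollection of serialized tf.train.Example bytes.
    # Key each example by its label, randomly keep at most target_count
    # examples per class, then drop the keys again.
    def key_by_label(serialized):
        example = tf.train.Example.FromString(serialized)
        label = example.features.feature[label_key].int64_list.value[0]
        return label, serialized

    return (serialized_examples
            | 'KeyByLabel' >> beam.Map(key_by_label)
            | 'SamplePerClass' >> beam.combiners.Sample.FixedSizePerKey(target_count)
            | 'DropKeys' >> beam.FlatMap(lambda kv: kv[1]))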

axeltidemann commented 3 years ago

@kindalime I think your approach of using Apache Beam is definitely the way forward. My component does oversampling of the train set and undersampling of the test and eval sets through tf.data.experimental.sample_from_datasets(datasets, weights), since all datasets are repeated indefinitely via repeat(). I am not entirely sure whether your question was directed at me or @rcrowe-google, but I unfortunately cannot contribute to merging our approaches in the immediate future, due to other priorities. I will for sure contribute with comments, though, and maybe even code later down the line.

kindalime commented 3 years ago

Just a quick update: a finished initial version of the component can be found here.

mshearer0 commented 3 years ago

In my toy project I chose option 1, down-sampling the training dataset in a follow-on component to ExampleGen, based on https://blog.tensorflow.org/2020/01/creating-custom-tfx-component.html. As a TFX beginner I had to fully parse the ExampleGen dataset to access the label and then rebuild it; I'm sure there must be a better way.
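
A minimal sketch of that parse-and-rebuild pattern (the feature name 'label' and the keep probability are illustrative, not from the actual project):

import tensorflow as tf

def downsample_majority(serialized_dataset, majority_label, keep_prob):
    # Parse just enough of each serialized tf.train.Example to read the label,
    # then randomly drop a fraction of the majority-class examples. The kept
    # elements stay serialized, so they can be written straight back out.
    def keep(serialized):
        parsed = tf.io.parse_single_example(
            serialized, {'label': tf.io.FixedLenFeature([], tf.int64)})
        is_majority = tf.equal(parsed['label'], majority_label)
        return tf.logical_or(tf.logical_not(is_majority),
                             tf.random.uniform([]) < keep_prob)
    return serialized_dataset.filter(keep)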

rcrowe-google commented 3 years ago

@mshearer0 Did you evaluate the Addons component for sampling for this? Is there some reason why that didn't work for you?

mshearer0 commented 3 years ago

Thanks @rcrowe-google - I wasn't aware of this Addon until a few days ago, so I'll look into it. The project is really just a vehicle for me to learn TFX, so I appreciate the guidance.

singhniraj08 commented 1 year ago

@RossKohler,

Please try using the Sampler component from TFX-Addons for random sampling of data. Thank you!

github-actions[bot] commented 1 year ago

This issue has been marked stale because it has had no recent activity for 7 days. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 1 year ago

This issue was closed due to lack of activity after being marked stale for the past 7 days.