tensorflow / model-remediation

Model Remediation is a library that provides solutions for machine learning practitioners working to create and train models in a way that reduces or eliminates user harm resulting from underlying performance biases.
https://www.tensorflow.org/responsible_ai/model_remediation?hl=en
Apache License 2.0

TF Dataset to TF Examples List in FDW Utils: Can only handle flat datasets #30

Open CLSchmitz opened 1 year ago

CLSchmitz commented 1 year ago

The tf_dataset_to_tf_examples_list function in the FDW utils can only handle datasets where each element is a flat, single-level dict of the form {feature_name: tf.Tensor}. The easiest way to generate one of these is from a DataFrame, e.g. tf.data.Dataset.from_tensor_slices(dict(df)).

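Specifically, this means it fails on tf.data.Datasets whose elements have either of two properties: a nested dict among the feature values (first example below), or a tuple rather than a dict as the element itself (second example below).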

import tensorflow_datasets as tfds
from tensorflow_model_remediation.experimental import fair_data_reweighting as fdw

# celeb_a elements contain nested dicts (e.g. 'attributes', 'landmarks')
ds = tfds.load('celeb_a')
ex = fdw.utils.tf_dataset_to_tf_examples_list(ds['train'])
next(ex)

throws AttributeError: 'dict' object has no attribute 'numpy'.

import tensorflow_datasets as tfds
from tensorflow_model_remediation.experimental import fair_data_reweighting as fdw

# as_supervised=True makes each element a (features, label) tuple
ds = tfds.load('diamonds', as_supervised=True)
ex = fdw.utils.tf_dataset_to_tf_examples_list(ds['train'])
next(ex)

throws AttributeError: 'tuple' object has no attribute 'items'.
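
One possible workaround, until the function handles these cases itself, is to flatten each element into a single-level dict before passing the dataset in. The following is only a sketch of that idea: flatten_features, its sep parameter, the "value" fallback key and the "parent/child" key naming are all hypothetical and not part of tensorflow_model_remediation. It is shown on plain Python values so it runs without TensorFlow, and it has not been verified that the flattened output is accepted by tf_dataset_to_tf_examples_list for every dtype or shape (image tensors, for instance).

```python
def flatten_features(element, prefix="", sep="/"):
    """Recursively flatten nested dicts/tuples into one single-level dict.

    Hypothetical helper, not part of the library. Nested keys are joined
    with `sep`; tuple or list positions are used as keys ("0", "1", ...).
    """
    if isinstance(element, dict):
        items = element.items()
    elif isinstance(element, (tuple, list)):
        items = ((str(i), v) for i, v in enumerate(element))
    else:
        # Leaf value: a bare top-level leaf gets the placeholder key "value".
        return {prefix or "value": element}

    flat = {}
    for key, value in items:
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        flat.update(flatten_features(value, name, sep))
    return flat


# Nested-dict case, shaped loosely like a celeb_a element.
nested = {"image": "img", "attributes": {"Smiling": True, "Young": False}}
print(flatten_features(nested))

# Tuple case, shaped loosely like an as_supervised element.
supervised = ({"carat": 0.3, "cut": 2}, 500.0)
print(flatten_features(supervised))
```

With a real tf.data.Dataset this could presumably be applied through Dataset.map before calling the FDW function. Note that map unpacks tuple elements into separate positional arguments, so a wrapper along the lines of lambda *args: flatten_features(args[0] if len(args) == 1 else args) would be needed; that usage is an untested assumption.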