tensorflow / model-remediation

Model Remediation is a library that provides solutions for machine learning practitioners working to create and train models in a way that reduces or eliminates user harm resulting from underlying performance biases.
https://www.tensorflow.org/responsible_ai/model_remediation?hl=en
Apache License 2.0

Keyword arguments not supported by original model: `['mask']` #24

Closed matanhalevy closed 2 years ago

matanhalevy commented 3 years ago

Hi, I've been trying to debug this for a few days. I'm using MinDiff remediation much as it is used in the tutorial. In the debugger, I can see that the `mask` value passed in by the caller is None. The base model I'm remediating is TFBertForSequenceClassification from HuggingFace's transformers library.

Relevant versions: Python 3.6.9, TensorFlow 2.3.1, tensorflow-model-remediation 0.1.3, transformers 4.3.2

Stacktrace:

Epoch 1/1000
Traceback (most recent call last):
  File "*/src/benchmarking/run_model.py", line 450, in <module>
    main()
  File "*/src/benchmarking/run_model.py", line 446, in main
    eval_min_diff_bert(**kwargs)
  File "*/src/benchmarking/run_model.py", line 211, in eval_min_diff_bert
    model.fit(dataset, epochs=epochs)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\def_function.py", line 697, in _initialize
    *args, **kwds))
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\framework\func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\eager\def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "*\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\framework\func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    *\anaconda3\envs\mindiff\lib\site-packages\tensorflow\python\keras\engine\training.py:806 train_function  *
        return step_function(self, iterator)
    *\anaconda3\envs\mindiff\lib\site-packages\tensorflow_model_remediation\min_diff\keras\models\min_diff_model.py:473 call  *
        min_diff_loss = self.compute_min_diff_loss(
    *\anaconda3\envs\mindiff\lib\site-packages\tensorflow_model_remediation\min_diff\keras\models\min_diff_model.py:371 compute_min_diff_loss  *
        predictions = self._call_original_model(x, training=training, mask=mask)
    *\anaconda3\envs\mindiff\lib\site-packages\tensorflow_model_remediation\min_diff\keras\models\min_diff_model.py:234 _call_original_model  *
        return self.original_model(inputs, **kwargs)
    *\anaconda3\envs\mindiff\lib\site-packages\transformers-4.2.2-py3.8.egg\transformers\models\bert\modeling_tf_bert.py:1405 call  *
        inputs = input_processing(
    *\anaconda3\envs\mindiff\lib\site-packages\transformers-4.2.2-py3.8.egg\transformers\modeling_tf_utils.py:345 input_processing  *
        raise ValueError(

    ValueError: The following keyword arguments are not supported by this model: ['mask'].
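The traceback shows `MinDiffModel` forwarding `mask=mask` to the wrapped model, and the transformers input processing then rejecting a keyword argument it does not recognize, even though its value is None. The pure-Python sketch below models that call chain without TensorFlow. `wrapper_call`, `strict_call`, and `SUPPORTED_KWARGS` are hypothetical stand-ins chosen for illustration; they are not the real MinDiff or transformers code, and the supported-argument set is an assumed subset.

```python
# Hypothetical, pure-Python model of the call chain in the traceback.
# None of these names come from MinDiff or transformers.

# Assumed subset of keyword arguments the stand-in "model" accepts.
SUPPORTED_KWARGS = {"attention_mask", "token_type_ids", "training"}


def strict_call(inputs, **kwargs):
    """Stand-in for a model whose input processing rejects unknown kwargs."""
    # The check looks at which keywords were passed, not at their values.
    unsupported = [name for name in kwargs if name not in SUPPORTED_KWARGS]
    if unsupported:
        raise ValueError(
            "The following keyword arguments are not supported by this model: "
            f"{unsupported}."
        )
    return inputs


def wrapper_call(model_fn, inputs, training=None, mask=None):
    """Stand-in for a wrapper that always forwards `mask`, even when it is None."""
    return model_fn(inputs, training=training, mask=mask)


try:
    wrapper_call(strict_call, [1, 2, 3])
except ValueError as err:
    print(err)
```

In this sketch an explicit `mask=None` is enough to trigger the error, which matches the observation above that `mask` was None when the real check failed.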

Code to reproduce (I can't share my data). I've also tried loading the BERT weights from a TF-trained Transformers model and hit the same issue. I've also commented out the wrapping of TFBertForSequenceClassification in a tf.keras model, but it makes no difference to the stacktrace.

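import tensorflow as tf
import tensorflow_model_remediation.min_diff as md
from transformers import TFBertForSequenceClassification

# path_to_bert, learning_rate, epochs and the train_ds_* datasets are defined elsewhere.
bert_model = TFBertForSequenceClassification.from_pretrained(path_to_bert, from_pt=True)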
# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# loss = tf.keras.losses.BinaryCrossentropy()
# bert_model.compile(optimizer=optimizer, loss=loss)

min_diff_weight = 1.5 

# Create the dataset that will be passed to the MinDiffModel during training.
dataset = md.keras.utils.input_utils.pack_min_diff_data(
    train_ds_main, train_ds_unpriv, train_ds_priv)

# Wrap the original model in a MinDiffModel, passing in one of the MinDiff
# losses and using the set loss_weight.
min_diff_loss = md.losses.MMDLoss()
model = md.keras.MinDiffModel(bert_model,
                              min_diff_loss,
                              min_diff_weight)

# Compile the model normally after wrapping the original model.  Note that
# this means we use the baseline model's loss here.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
loss = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

model.fit(dataset, epochs=epochs)

Thank you!
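A possible workaround, offered as an untested assumption rather than a confirmed fix, is to place a thin adapter between the MinDiff wrapper and the Hugging Face model that discards `mask` before forwarding the call. In Keras, that would roughly mean a small `tf.keras.Model` subclass whose `call` accepts `mask` but does not pass it on. The sketch below shows only the filtering logic in plain Python; `drop_kwargs`, `strict_call`, `wrapper_call`, and `SUPPORTED_KWARGS` are hypothetical names, not library APIs.

```python
import functools

# Assumed subset of keyword arguments the stand-in "model" accepts (hypothetical).
SUPPORTED_KWARGS = {"attention_mask", "token_type_ids", "training"}


def strict_call(inputs, **kwargs):
    """Hypothetical stand-in for a model that rejects unknown keyword arguments."""
    unsupported = [name for name in kwargs if name not in SUPPORTED_KWARGS]
    if unsupported:
        raise ValueError(
            "The following keyword arguments are not supported by this model: "
            f"{unsupported}."
        )
    return inputs


def drop_kwargs(model_fn, names=("mask",)):
    """Return a callable that removes `names` from kwargs before calling model_fn."""
    @functools.wraps(model_fn)
    def adapted(inputs, **kwargs):
        for name in names:
            kwargs.pop(name, None)  # discard the argument whether or not it is None
        return model_fn(inputs, **kwargs)
    return adapted


def wrapper_call(model_fn, inputs, training=None, mask=None):
    """Hypothetical stand-in for a wrapper that always forwards `mask`."""
    return model_fn(inputs, training=training, mask=mask)


# The adapted model now accepts the wrapper's call.
print(wrapper_call(drop_kwargs(strict_call), [1, 2, 3]))
```

Dropping `mask` is only reasonable when the wrapper has nothing meaningful to pass, as here where it is None; a model that relies on a Keras-propagated mask would silently lose it. BERT's padding information travels separately, through its `attention_mask` input.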

matanhalevy commented 3 years ago

Hi, is this repo actively maintained? Will this issue be assigned?

seano314 commented 3 years ago

Hey Matan,

Sorry for the delay on this. I've looked into it a bit, and it seems like it might be a limitation of Huggingface. If I'm not mistaken, Huggingface doesn't allow abstract classes for whatever reason. The specific ValueError you're hitting was added via #8602, and there is a comment on that line.

I'm happy to help dig into this more if you're able to help me reproduce the error.

seano314 commented 2 years ago

Closing this issue out, given that the original ValueError in this issue has been removed from huggingface.

https://github.com/huggingface/transformers/blob/master/src/transformers/models/albert/modeling_tf_albert.py