huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

TFDebertaModel and TFDebertaV2Model throws TypeError when keras.fit with Mixed Precision #31989

Closed pinesnow72 closed 3 weeks ago

pinesnow72 commented 1 month ago

Who can help?

@ArthurZucker, @Rocketknight1

Reproduction

I am trying to fine-tune TFDebertaModel and TFDebertaV2Model for an NER task with mixed precision enabled:

policy = keras.mixed_precision.Policy('mixed_float16')
keras.mixed_precision.set_global_policy(policy)

model = TFDebertaModel.from_pretrained('deberta-base')
# or 
# model = TFDebertaV2Model.from_pretrained('deberta-v3-base')

....

model.fit(x=train_data, validation_data=valid_data, epochs=10)

However, training this model raises a TypeError in TFDebertaEmbeddings:

  TypeError: Exception encountered when calling layer 'embeddings' (type TFDebertaEmbeddings).

  in user code:

    File "/home/swlee/miniconda3/envs/tf216/lib/python3.12/site-packages/transformers/models/deberta/modeling_tf_deberta.py", line 929, in call
      final_embeddings = final_embeddings * mask

    TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float16 of argument 'x'.
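The failing multiplication can be reproduced outside the model. The snippet below is a minimal sketch, not code from transformers: the tensors are illustrative stand-ins for the float16 activations produced under mixed_float16 and for a mask cast to a hard-coded tf.float32.

```python
import tensorflow as tf

# Illustrative stand-ins: float16 activations and a float32 mask.
final_embeddings = tf.ones((1, 3, 4), dtype=tf.float16)
mask = tf.ones((1, 3, 1), dtype=tf.float32)

try:
    final_embeddings * mask
    mismatch_raised = False
except (TypeError, tf.errors.InvalidArgumentError):
    # Inside a traced graph (as in fit) this surfaces as the TypeError above;
    # in eager mode TensorFlow typically raises InvalidArgumentError instead.
    mismatch_raised = True

# Casting the mask to the activations' dtype makes the product well-typed.
fixed = final_embeddings * tf.cast(mask, final_embeddings.dtype)
print(mismatch_raised, fixed.dtype)
```

TensorFlow does not promote float16 and float32 operands automatically by default, so one side has to be cast explicitly.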

The same error occurs with TFDebertaV2Model. With mixed precision, TensorFlow and Keras expect a layer to use Layer.dtype for its weights and Layer.compute_dtype for its internal tensor computations. The current TFDebertaModel and TFDebertaV2Model code does not follow this requirement and assumes in several places that the dtype is always tf.float32.
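To make the distinction concrete, here is a small sketch using a stock Dense layer rather than the DeBERTa code. It assumes the policy can be passed to a single layer by name through the dtype argument, which avoids changing the global policy.

```python
import tensorflow as tf

# Apply the mixed precision policy to one layer only (no global state).
dense = tf.keras.layers.Dense(4, dtype="mixed_float16")
outputs = dense(tf.ones((2, 3), dtype=tf.float32))

print(dense.dtype)          # dtype used for the layer's weights
print(dense.compute_dtype)  # dtype used for the layer's computations
print(dense.kernel.dtype)   # the kernel variable itself
print(outputs.dtype)        # activations leaving the layer
```

The weights stay in float32 while the activations leave the layer in float16, which is why any tensor built inside call with a hard-coded tf.float32 can clash with its inputs.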

Expected behavior

I hope this bug can be fixed soon so that mixed precision is supported. I searched modeling_tf_deberta.py and modeling_tf_deberta_v2.py for error-prone code snippets and tried correcting them. Here is the list (I am not sure it is exhaustive):

[in modeling_tf_deberta.py]

(lines: 105, 106)

  output = tf.where(rmask, float("-inf"), inputs)
  output = stable_softmax(output, self.axis)

(correction would be)

  output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)  # mixed precision # float("-inf")
  output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)  # mixed precision # output

(lines: 133, 135, 139)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
  inputs = tf.where(mask, 0.0, inputs) * scale
  return tf.where(mask, 0.0, upstream) * scale

(correction would be)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)  # mixed precision # dtype=tf.float32)
  inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale  # mixed precision # 0.0
  return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale  # mixed precision # 0.0

(lines: 705, 707)

  qkvw = tf.TensorArray(dtype=tf.float32, size=3)
  qkvw_inside = tf.TensorArray(dtype=tf.float32, size=self.num_attention_heads)

(correction would be)

  qkvw = tf.TensorArray(dtype=self.dtype, size=3)  # mixed precision # tf.float32
  qkvw_inside = tf.TensorArray(dtype=self.dtype, size=self.num_attention_heads)  # mixed precision # tf.float32

(lines: 799)

  pos_query_layer /= tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=tf.float32))

(correction would be)

  pos_query_layer /= tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32

(lines: 927)

  mask = tf.cast(tf.expand_dims(mask, axis=2), tf.float32)

(correction would be)

  mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)
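The last correction above (the mask cast) is the one behind the reported traceback. Below is a minimal sketch of the same pattern in a standalone layer. MaskedEmbeddings and its attention_mask argument are hypothetical names for illustration and are not part of transformers.

```python
import tensorflow as tf


class MaskedEmbeddings(tf.keras.layers.Layer):
    """Hypothetical layer: zeroes out padded positions of an embedding tensor."""

    def call(self, embeddings, attention_mask):
        # Cast the mask to the layer's compute dtype instead of tf.float32,
        # so it matches the autocast embeddings under mixed precision.
        mask = tf.cast(
            tf.expand_dims(attention_mask, axis=2), dtype=self.compute_dtype
        )
        return embeddings * mask


layer = MaskedEmbeddings(dtype="mixed_float16")
embeddings = tf.ones((1, 3, 4), dtype=tf.float32)
attention_mask = tf.constant([[1, 1, 0]], dtype=tf.int32)

output = layer(embeddings, attention_mask)
print(output.dtype)
```

Because the cast follows self.compute_dtype, the same layer also works unchanged under a plain float32 policy.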

[in modeling_tf_deberta_v2.py]

(lines: 106, 107)

  output = tf.where(rmask, float("-inf"), inputs)
  output = stable_softmax(output, self.axis)

(correction would be)

  output = tf.where(rmask, tf.cast(float("-inf"), dtype=self.compute_dtype), inputs)  # mixed precision # float("-inf")
  output = stable_softmax(tf.cast(output, dtype=tf.float32), self.axis)  # mixed precision # output

(lines: 135, 137, 141)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=tf.float32)
  inputs = tf.where(mask, 0.0, inputs) * scale
  return tf.where(mask, 0.0, upstream) * scale

(correction would be)

  scale = tf.convert_to_tensor(1.0 / (1 - self.drop_prob), dtype=self.compute_dtype)  # mixed precision # dtype=tf.float32)
  inputs = tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), inputs) * scale  # mixed precision # 0.0
  return tf.where(mask, tf.cast(0.0, dtype=self.compute_dtype), upstream) * scale  # mixed precision # 0.0

(lines: 391, 404)

  out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
  input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), tf.float32)

(correction would be)

  out = tf.where(tf.broadcast_to(tf.expand_dims(rmask, -1), shape_list(out)), 0.0, out)
  input_mask = tf.cast(tf.expand_dims(input_mask, axis=2), dtype=self.compute_dtype)  # mixed precision # tf.float32)

(lines: 770)

  scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, tf.float32))

(correction would be)

  scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32))

(lines: 853, 867)

  scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, tf.float32))
  scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, tf.float32))

(correction would be)

  scale = tf.math.sqrt(tf.cast(shape_list(pos_key_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32))
  scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, dtype=self.compute_dtype))  # mixed precision # tf.float32))

(lines: 1034)

  mask = tf.cast(tf.expand_dims(mask, axis=2), tf.float32)

(correction would be)

  mask = tf.cast(tf.expand_dims(mask, axis=2), dtype=self.compute_dtype)  # mixed precision # tf.float32)
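
The softmax correction is the same in both files, so here is one standalone sketch of it. masked_softmax is a hypothetical helper, tf.nn.softmax stands in for the library's stable_softmax, and casting the result back to the compute dtype is my own assumption about what callers would want.

```python
import tensorflow as tf


def masked_softmax(inputs, mask, compute_dtype="float16", axis=-1):
    """Hypothetical helper: softmax over `axis`, ignoring positions where mask is 0."""
    rmask = tf.logical_not(tf.cast(mask, tf.bool))
    # Build -inf in the compute dtype so it matches `inputs`.
    neg_inf = tf.cast(float("-inf"), dtype=compute_dtype)
    output = tf.where(rmask, neg_inf, inputs)
    # Run the softmax in float32 for numerical stability.
    output = tf.nn.softmax(tf.cast(output, tf.float32), axis=axis)
    # Assumption: return to the compute dtype for the following ops.
    return tf.cast(output, compute_dtype)


scores = tf.constant([[1.0, 2.0, 3.0]], dtype=tf.float16)
mask = tf.constant([[1, 1, 0]], dtype=tf.int32)
probs = masked_softmax(scores, mask)
print(probs.dtype, probs.numpy())
```

Rows in which every position is masked would produce NaN here, so a real implementation needs to handle that case separately.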

amyeroberts commented 1 month ago

Hi @pinesnow72, thanks for raising an issue!

As you're trying something with custom code, specifically training with mixed precision, this is a question best placed in our forums. We try to reserve GitHub issues for feature requests and bug reports.

cc @Rocketknight1

ArthurZucker commented 1 month ago

(But @pinesnow72, your intuition is correct: creating the -inf in the mask with a fixed float dtype is problematic here. Feel free to open a PR with your proposed changes; I am certain @Rocketknight1 will be able to review!)

Rocketknight1 commented 1 month ago

Yes, agreed - if you're willing to open the PR, I think it would be a good change!