artemmavrin / focal-loss

TensorFlow implementation of focal loss
https://focal-loss.readthedocs.io
Apache License 2.0

SparseCategoricalFocalLoss cannot compile with TPUs #19

Open lmassaron opened 2 years ago

lmassaron commented 2 years ago

I am trying to run a model with SparseCategoricalFocalLoss: it works fine on GPUs, but it cannot compile on TPUs. Is this a known issue?

You can see the code here: https://www.kaggle.com/lucamassaron/tfkeras-dnn-with-multiclass-focal-loss
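Roughly, the relevant part of the setup looks like this (a simplified sketch; the exact model, feature count, and gamma value are in the notebook, the numbers below are just placeholders):

```python
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

# Standard TPU setup on Kaggle
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

num_features, num_classes = 75, 9  # placeholder values, not the notebook's

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(num_features,)),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss=SparseCategoricalFocalLoss(gamma=2.0),  # gamma=2.0 is a placeholder
        metrics=["accuracy"],
    )

# model.fit(...) then fails during TPU compilation with the traceback below
```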

artemmavrin commented 2 years ago

Thanks for bringing this up, I haven't seen this kind of issue before. Did you confirm that the error is specific to SparseCategoricalFocalLoss? I.e., does the error occur when using a loss like keras.losses.SparseCategoricalCrossentropy?

lmassaron commented 2 years ago

Yes, if I replace it with tf.keras.losses.SparseCategoricalCrossentropy(), the code works fine.
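The only change is the loss passed to model.compile, e.g. (the gamma value is again just a placeholder):

```python
# Fails to compile on the TPU:
model.compile(optimizer="adam",
              loss=SparseCategoricalFocalLoss(gamma=2.0),
              metrics=["accuracy"])

# Works fine on the TPU:
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=["accuracy"])
```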

Here is the complete error traceback:

InternalError                             Traceback (most recent call last)
/tmp/ipykernel_42/3395278551.py in <module>
     45             callbacks=[
     46                 early_stopping,
---> 47                 reduce_lr
     48             ]
     49         )

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1103               logs = tmp_logs  # No error, now safe to assign to logs.
   1104               end_step = step + data_handler.step_increment
-> 1105               callbacks.on_train_batch_end(end_step, logs)
   1106               if self.stop_training:
   1107                 break

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self, batch, logs)
    452     """
    453     if self._should_call_train_batch_hooks:
--> 454       self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
    455 
    456   def on_test_batch_begin(self, batch, logs=None):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _call_batch_hook(self, mode, hook, batch, logs)
    294       self._call_batch_begin_hook(mode, batch, logs)
    295     elif hook == 'end':
--> 296       self._call_batch_end_hook(mode, batch, logs)
    297     else:
    298       raise ValueError('Unrecognized hook: {}'.format(hook))

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _call_batch_end_hook(self, mode, batch, logs)
    314       self._batch_times.append(batch_time)
    315 
--> 316     self._call_batch_hook_helper(hook_name, batch, logs)
    317 
    318     if len(self._batch_times) >= self._num_batches_for_timing_check:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _call_batch_hook_helper(self, hook_name, batch, logs)
    354       hook = getattr(callback, hook_name)
    355       if getattr(callback, '_supports_tf_logs', False):
--> 356         hook(batch, logs)
    357       else:
    358         if numpy_logs is None:  # Only convert once.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self, batch, logs)
   1018 
   1019   def on_train_batch_end(self, batch, logs=None):
-> 1020     self._batch_update_progbar(batch, logs)
   1021 
   1022   def on_test_batch_end(self, batch, logs=None):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _batch_update_progbar(self, batch, logs)
   1082     if self.verbose == 1:
   1083       # Only block async when verbose = 1.
-> 1084       logs = tf_utils.to_numpy_or_python_type(logs)
   1085       self.progbar.update(self.seen, list(logs.items()), finalize=False)
   1086 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in to_numpy_or_python_type(tensors)
    512     return t  # Don't turn ragged or sparse tensors to NumPy.
    513 
--> 514   return nest.map_structure(_to_single_numpy_or_python_type, tensors)
    515 
    516 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func, *structure, **kwargs)
    657 
    658   return pack_sequence_as(
--> 659       structure[0], [func(*x) for x in entries],
    660       expand_composites=expand_composites)
    661 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
    657 
    658   return pack_sequence_as(
--> 659       structure[0], [func(*x) for x in entries],
    660       expand_composites=expand_composites)
    661 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
    508   def _to_single_numpy_or_python_type(t):
    509     if isinstance(t, ops.Tensor):
--> 510       x = t.numpy()
    511       return x.item() if np.ndim(x) == 0 else x
    512     return t  # Don't turn ragged or sparse tensors to NumPy.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in numpy(self)
   1069     """
   1070     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
-> 1071     maybe_arr = self._numpy()  # pylint: disable=protected-access
   1072     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
   1073 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _numpy(self)
   1037       return self._numpy_internal()
   1038     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1039       six.raise_from(core._status_to_exception(e.code, e.message), None)  # pylint: disable=protected-access
   1040 
   1041   @property

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InternalError: 9 root error(s) found.
  (0) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_succeeded_assert/_12573499378960158217/_5/_185]]
  (1) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_succeeded_assert/_12573499378960158217/_5/_245]]
  (2) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_succeeded_assert/_12573499378960158217/_5/_209]]
  (3) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_succeeded_assert/_12573499378960158217/_5/_257]]
  (4) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[cluster_train_function/control_after/_1/_297]]
  (5) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_succeeded_assert/_12573499378960158217/_5/_233]]
  (6) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_succeeded_assert/_12573499378960158217/_5/_197]]
  (7) Internal: {{function_node __inference_train_function_4956}} Compilation failure: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:715) dynamic_reshape->operand(i)->shape().element_type() == S32 
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_12573499378960158217/_5}}]]
     [[tpu_compile_ ... [truncated]

I also tried removing the callbacks, but the error still appears.

artemmavrin commented 2 years ago

Thanks for sharing the traceback. From some Googling, it seems that others have encountered similar errors when trying to use "unsupported" ops on TPUs. The error suggests that the reshape ops used in the focal loss code might have something to do with it, although the traceback contains no line numbers pointing into the focal-loss source to pin down the problematic op. Not very helpful, but that's all I know for now. Sorry about that.
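One way to narrow it down without a TPU might be to XLA-compile just the loss computation on CPU/GPU, since the TPU compilation error comes from XLA. Something like this (an untested sketch with arbitrary dummy shapes; on TF versions before 2.5 the flag is experimental_compile instead of jit_compile):

```python
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

loss_fn = SparseCategoricalFocalLoss(gamma=2.0)

@tf.function(jit_compile=True)  # force XLA compilation, similar to the TPU path
def compiled_loss(y_true, y_pred):
    return loss_fn(y_true, y_pred)

# Dummy batch: 3 examples, 4 classes (shapes are arbitrary)
y_true = tf.constant([0, 2, 1], dtype=tf.int64)
y_pred = tf.random.uniform((3, 4))

print(compiled_loss(y_true, y_pred))  # if XLA rejects an op, it should fail here too
```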

lmassaron commented 2 years ago

So would it be enough to just change those reshape operations?

chrischang80 commented 2 years ago

I guess the problem may come from the tf.gather function. If I replace the code snippet below and just return the cross-entropy loss directly, everything works fine in the TPU environment. It then simply degenerates to the Keras SparseCategoricalCrossentropy loss.

https://github.com/artemmavrin/focal-loss/blob/7a1810a968051b6acfedf2052123eb76ba3128c4/src/focal_loss/_categorical_focal_loss.py#L173-L177

=> loss = xent_loss
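If tf.gather really is what XLA rejects, maybe the focal term could be kept by selecting the true-class probability with a one-hot mask instead of a gather. Just a rough sketch (not tested on a TPU, and not the repo's actual code):

```python
import tensorflow as tf

def sparse_focal_loss_no_gather(y_true, y_pred, gamma=2.0, from_logits=False):
    """Sketch of a sparse categorical focal loss that avoids tf.gather.

    y_true: integer labels with shape (batch,)
    y_pred: class scores with shape (batch, num_classes)
    """
    # Standard sparse cross-entropy term
    xent = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=from_logits)

    probs = tf.nn.softmax(y_pred, axis=-1) if from_logits else y_pred

    # Select p_t (probability of the true class) via a one-hot mask
    # instead of tf.gather(..., batch_dims=...).
    num_classes = tf.shape(y_pred)[-1]
    one_hot = tf.one_hot(tf.cast(y_true, tf.int32), depth=num_classes,
                         dtype=probs.dtype)
    p_t = tf.reduce_sum(one_hot * probs, axis=-1)

    # Focal modulation: (1 - p_t) ** gamma
    return (1.0 - p_t) ** gamma * xent
```

Whether this version actually compiles on a TPU would still need to be checked.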

KyloRen1 commented 2 years ago

Are there any updates regarding the issue?