keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Distributed training not working (batch size calculation) #1630

Status: Open. natbprice opened this issue 6 months ago.

natbprice commented 6 months ago

Describe the bug: This is an issue I am having with keras-nlp, but I am not sure whether it can be solved here or should be reported under keras or tensorflow.

Currently, the batch size is not calculated correctly when performing multi-worker distributed training with the JAX backend:

Traceback (most recent call last):
  File "mycode.py", line 293, in <module>
    history = classifier.fit(
  File "/usr/local/lib/python3.10/dist-packages/keras_nlp/src/utils/pipeline_model.py", line 194, in fit
    return super().fit(
  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/keras/src/distribution/distribution_lib.py", line 467, in distribute_dataset
    raise ValueError(
ValueError: The batch size of the input dataset is unknown. Please config the batch size for the input dataset, e.g via `dataset.batch(batch_size)`

To Reproduce: Run distributed training (possibly only multi-worker) with the JAX backend.

The issue seems to stem from https://github.com/keras-team/keras-nlp/blob/778ccd72fe5d74e8eedc7d38dfb57561821b7851/keras_nlp/src/utils/pipeline_model.py#L181, where mapping a preprocessor over the dataset leads to a failure at https://github.com/keras-team/keras/blob/3105247028bb0a7e6d2f05f5daa44c9cfafd3e67/keras/src/distribution/distribution_lib.py#L465.
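To make the failure mode concrete, here is a simplified sketch of the kind of check that produces the ValueError above. It assumes (based on the traceback and the linked line) that Keras treats a negative result from compute_batch_size() as "unknown"; check_batch_size is a hypothetical name, not the Keras function, and the map/prefetch chain mirrors the one in the minimal example below.

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute


def check_batch_size(dataset):
    # Hypothetical, simplified version of the check in distribution_lib.py:
    # compute_batch_size() returns a negative value when it cannot infer the size.
    batch_size = int(tf_data_distribute.compute_batch_size(dataset).numpy())
    if batch_size < 0:
        raise ValueError("The batch size of the input dataset is unknown.")
    return batch_size


# Batching is the last operation, so the batch size can be recovered.
batched = tf.data.Dataset.range(8).batch(3)
print("batched:", check_batch_size(batched))

# Mapping after batching, like the pipeline model does with its preprocessor.
mapped = batched.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
try:
    check_batch_size(mapped)
except ValueError as err:
    print("mapped:", err)
```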

Here is a minimal example in which tensorflow.python.data.experimental.ops.distribute.compute_batch_size() returns -1 after mapping:

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute
from keras_nlp.src.utils.keras_utils import pack_x_y_sample_weight

# 8 elements in batches of 3 -> batches of size 3, 3, 2 (the last one is partial).
ds = tf.data.Dataset.range(8)
ds = ds.batch(3)

# Size of the first batch vs. what compute_batch_size() infers.
print(f"True batch size (before): {next(ds.as_numpy_iterator()).shape[0]}")
print(f"Calculated batch size (before): {tf_data_distribute.compute_batch_size(ds).numpy()}")

# The same kind of map/prefetch that pipeline_model.py applies with its preprocessor.
ds = ds.map(pack_x_y_sample_weight, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

# The mapped element may be a tensor or a tuple, so flatten it before reading dim 0.
print(f"True batch size (after): {tf.nest.flatten(next(ds.as_numpy_iterator()))[0].shape[0]}")
print(f"Calculated batch size (after): {tf_data_distribute.compute_batch_size(ds).numpy()}")

Expected behavior: A batched tf.data.Dataset object is recognized as batched.

natbprice commented 6 months ago

Here is a Colab notebook with a reproducible example:

https://colab.research.google.com/drive/1aCJVUNfro68fek-o0i_7Iojl6Qtix6NK?usp=sharing

natbprice commented 6 months ago

I can reproduce the error using just keras, so maybe I should open an issue there? Or maybe it should be fixed in tensorflow? However, the documentation for tensorflow.python.data.experimental.ops.distribute.compute_batch_size() describes its limitations, so I am not sure it is technically a bug in tensorflow.

https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I#scrollTo=0Hf6qJOxXsqI

natbprice commented 5 months ago

Hi @SuryanarayanaY, the related ticket in keras was closed with the recommendation that this be fixed in keras-nlp. Per @hertschuh: "One should simply apply batch_size after the map and not in _convert_inputs_to_dataset".
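A minimal sketch of what "apply batch_size after the map" could look like at the tf.data level, assuming the preprocessing can run on single, unbatched examples. The preprocess function is a hypothetical stand-in, not a keras-nlp preprocessor. Because batch is now the outermost operation, compute_batch_size() can find it.

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute


def preprocess(x):
    # Hypothetical per-example preprocessing standing in for a keras-nlp preprocessor.
    return x * 2


# Map over single examples first, then batch, so batch is the outermost op.
ds = tf.data.Dataset.range(8)
ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(3)

batch_size = int(tf_data_distribute.compute_batch_size(ds).numpy())
print("compute_batch_size:", batch_size)
batches = [b.tolist() for b in ds.as_numpy_iterator()]
print("batches:", batches)
```

One trade-off of this ordering is that the preprocessor runs once per example instead of once per batch, so it loses the vectorization that mapping over whole batches provides.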

I can't quite figure out the best way to make this work with the keras-nlp API. In particular, there seem to be several combinations of (1) distribution strategy, (2) input type (e.g., tf.data.Dataset, NumPy arrays), and (3) batching (e.g., a pre-batched dataset vs. an explicit batch_size).

Currently, _convert_inputs_to_dataset raises an error if you pass a tf.data.Dataset together with an explicit batch_size argument. There also appears to be error handling meant to prevent you from passing unbatched inputs, but the string matching on the error message may be outdated and no longer working.
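The input-type and batching combinations above could be handled roughly as follows. This is a hypothetical helper (to_preprocessed_dataset is not keras-nlp API) and only a sketch: it assumes tf.data.Dataset inputs are already batched, rejects an explicit batch_size for them as described above, and for array inputs slices into examples, preprocesses, and batches last. The default of 32 is an assumption borrowed from Keras' usual fit() default. Note that the pre-batched dataset path still maps after batching, so it does not fix the original problem on its own.

```python
import numpy as np
import tensorflow as tf


def to_preprocessed_dataset(x, preprocess, batch_size=None):
    # Hypothetical helper, not the keras-nlp implementation.
    if isinstance(x, tf.data.Dataset):
        # Datasets are assumed to be pre-batched; an explicit batch_size is rejected.
        if batch_size is not None:
            raise ValueError("Do not pass batch_size together with a tf.data.Dataset.")
        return x.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    # Arrays: slice into single examples, preprocess, and only then batch,
    # so that batch is the outermost op for compute_batch_size().
    ds = tf.data.Dataset.from_tensor_slices(x)
    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size if batch_size is not None else 32)  # 32: assumed default


# Usage: NumPy inputs are preprocessed per example and batched afterwards.
features = np.arange(10)
train_ds = to_preprocessed_dataset(features, lambda v: v + 1, batch_size=4)
print([b.tolist() for b in train_ds.as_numpy_iterator()])
```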

hertschuh commented 5 months ago

@natbprice ,

Sorry for the delay, I'm still working on this. It turned out to be more complex to fix than I expected.

hertschuh commented 5 months ago

@natbprice ,

I experimented with a few things, but I could not find a fix in keras-nlp that would work in all cases.

However, I do have an easy workaround: ds.batch(8, drop_remainder=True). By doing this, the dataset knows that the first dimension, the batch size, is always 8. Then, it can infer the first dimension of the result of other operations like map.

If you don't set drop_remainder=True, the dataset assumes the last batch may be incomplete. And while you can still retrieve the batch size right after batching, it doesn't propagate through other operations like map.

If you're concerned about not using the last few examples, you can shuffle or repeat the dataset before batching.
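The workaround can be checked directly on the dataset's element_spec: with drop_remainder=True the first dimension is statically 8, survives a map that keeps the batch dimension, and compute_batch_size() can read it from the dataset structure. The dataset size of 20 and the map function are arbitrary choices for illustration.

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute as tf_data_distribute

ds = tf.data.Dataset.range(20)

# Without drop_remainder, the last batch may be smaller, so dim 0 is unknown.
ragged_last = ds.batch(8)
print(ragged_last.element_spec.shape)  # (None,)

# With drop_remainder=True, every batch has exactly 8 elements.
fixed = ds.batch(8, drop_remainder=True)
print(fixed.element_spec.shape)  # (8,)

# The static batch dimension survives a map that keeps the leading dimension...
mapped = fixed.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
print(mapped.element_spec.shape)  # (8,)

# ...so compute_batch_size() can infer it without inspecting the op graph.
inferred = int(tf_data_distribute.compute_batch_size(mapped).numpy())
print(inferred)  # 8
```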
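A small sketch of both options. Shuffling before batching (with reshuffling on each iteration, which is the tf.data default) means a different subset of examples lands in the dropped remainder each epoch. Repeating before batching lets batches span epoch boundaries; here 2 passes over 20 examples make exactly 5 full batches of 8. The sizes are arbitrary.

```python
import tensorflow as tf

NUM_EXAMPLES = 20
BATCH_SIZE = 8  # 20 % 8 == 4 examples are dropped per pass

# Option 1: shuffle before batching so the dropped examples vary per epoch.
shuffled = (
    tf.data.Dataset.range(NUM_EXAMPLES)
    .shuffle(NUM_EXAMPLES, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE, drop_remainder=True)
)
used_per_epoch = []
for epoch in range(2):
    used = [int(v) for batch in shuffled for v in batch.numpy()]
    used_per_epoch.append(len(used))
    print(f"epoch {epoch}: {len(used)} of {NUM_EXAMPLES} examples used")

# Option 2: repeat before batching so batches span epoch boundaries.
repeated = tf.data.Dataset.range(NUM_EXAMPLES).repeat(2).batch(BATCH_SIZE, drop_remainder=True)
num_repeated_batches = sum(1 for _ in repeated)
print("full batches over 2 passes:", num_repeated_batches)
```

With repeat() and no count, the dataset becomes infinite, so fit() would also need steps_per_epoch.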

natbprice commented 4 months ago

@hertschuh do you have a working example? This doesn't seem to work for me.

https://colab.research.google.com/drive/1aCJVUNfro68fek-o0i_7Iojl6Qtix6NK?usp=sharing#scrollTo=0Hf6qJOxXsqI

Edit:

It seems like the workaround works outside of keras-nlp. Maybe there is something specific to keras-nlp that still needs to be resolved?
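One possible explanation, which is an untested assumption and not confirmed in this thread: drop_remainder=True only helps while the static batch dimension survives every map. If a preprocessing step produces outputs whose shape tf.data cannot infer, the batch size of 8 is lost again and compute_batch_size() has to fall back to inspecting the op graph. The sketch below uses tf.numpy_function to simulate such an opaque step and restores the shape with set_shape; both preprocess functions are hypothetical.

```python
import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.range(16).batch(8, drop_remainder=True)


def opaque_preprocess(x):
    # Hypothetical preprocessor: tf.numpy_function hides the output shape from tf.data.
    return tf.numpy_function(lambda a: (a * 2).astype(np.int64), [x], tf.int64)


def shape_preserving_preprocess(x):
    y = opaque_preprocess(x)
    # Re-attach the static shape of the input, including the batch dimension.
    y.set_shape(x.shape)
    return y


opaque = ds.map(opaque_preprocess)
restored = ds.map(shape_preserving_preprocess)
print(opaque.element_spec.shape)    # static batch size of 8 is lost
print(restored.element_spec.shape)  # batch size of 8 is known again
```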

https://colab.research.google.com/drive/1IxVNDcNoIK4SiX2wuDQKfqR_6Or9P40I#scrollTo=0Hf6qJOxXsqI

natbprice commented 4 months ago

@SuryanarayanaY can we reopen this issue please?