keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

Remove the `training` argument from `.call()` function of all preprocessing layers #1659

Closed haifeng-jin closed 1 year ago

haifeng-jin commented 1 year ago

According to the discussion with @LukeWood and @fchollet , we will remove the training argument from all .call() functions of all preprocessing layers in KerasCV.

The reason is to make it easy to build the preprocessing pipeline into a Sequential model and run it with tf.data.Dataset.map(). This preprocessing pipeline is separate from the main neural network model, so it does not receive the training argument when model.fit() is called on the main model. Therefore, managing this training argument in the preprocessing pipeline is an extra burden on users.

After removing the training argument, all preprocessing layers are expected to perform their training-time behavior all the time. Users should construct separate pipelines for training and inference.
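To make the "separate pipelines" idea concrete, here is a minimal, framework-free sketch of the pattern. The step names below are hypothetical stand-ins, not actual KerasCV layers; real code would compose keras_cv.layers in a keras.Sequential model and apply it via tf.data.Dataset.map().

```python
def resize(image):
    # Deterministic step shared by both pipelines (placeholder logic).
    return [row[:2] for row in image[:2]]

def random_flip(image):
    # Training-only augmentation: always augments, no `training` flag.
    return [list(reversed(row)) for row in image]

def compose(*steps):
    # Chain preprocessing steps into a single callable pipeline.
    def pipeline(image):
        for step in steps:
            image = step(image)
        return image
    return pipeline

# Two separate pipelines instead of one pipeline toggled by `training`.
train_pipeline = compose(resize, random_flip)
eval_pipeline = compose(resize)

image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(train_pipeline(image))  # [[2, 1], [5, 4]]
print(eval_pipeline(image))   # [[1, 2], [4, 5]]
```

The design choice is that augmentation never needs to be switched off; it is simply absent from the evaluation pipeline.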

The following directories contain the KerasCV preprocessing layers that override the call() function and need to be changed to address this issue:

keras_cv/layers/preprocessing/

keras_cv/layers/preprocessing_3d/

haifeng-jin commented 1 year ago

I only included the layers in keras_cv/layers/preprocessing/ and keras_cv/layers/preprocessing_3d/.

@LukeWood Would you help confirm the following?

  1. The description above is accurate.
  2. There are no preprocessing layers in other directories.
  3. These preprocessing layers extend BaseRandomLayer, which is not a subclass of KPL, but that doesn't matter.

Thanks.

LukeWood commented 1 year ago

Hey @haifeng-jin ! This looks correct to me - I think everything here is accurate.

> The reason is to easily build the preprocessing pipeline into a Sequential model and run it with tf.data.Dataset.map(). This preprocessing pipeline is separate from the main neural network model, thus do not receive the training argument when calling model.fit() in the main neural network model. Therefore, managing this training argument in the preprocessing pipeline is an extra burden on the users.

An extra piece of info here: trying to reason about what JitteredResize should do at inference time is somewhat nonsensical. The optimal action for classification (using Resizing(crop_to_aspect_ratio=True)) is not the same as the optimal action for object detection (Resizing(pad_to_aspect_ratio=True)). Because of this, and because what a KPL/augmentation layer "should" do at inference time is generally confusing, we are going to make each layer do just one thing.

This also minimizes the maintenance burden and bug surface by removing the if training: branch.
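As a hedged before/after sketch of the proposed change (MyAugment-style classes below are invented examples, not actual KerasCV layers, and the augmentation body is a placeholder):

```python
class BeforeAugment:
    # Old shape: behavior forks on a `training` flag inside call().
    def call(self, inputs, training=True):
        if training:
            return self.augment(inputs)  # training-time branch
        return inputs                    # inference passthrough

    def augment(self, inputs):
        return [x + 1 for x in inputs]   # placeholder augmentation

class AfterAugment:
    # New shape: no `training` argument. The layer always augments;
    # users keep it out of their inference pipeline instead of
    # toggling a flag.
    def call(self, inputs):
        return self.augment(inputs)

    def augment(self, inputs):
        return [x + 1 for x in inputs]   # placeholder augmentation

print(AfterAugment().call([1, 2, 3]))  # [2, 3, 4]
```

Deleting the branch means there is one code path to test and no risk of a layer silently behaving differently depending on how the flag is propagated.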

haifeng-jin commented 1 year ago

@soma2000-lang Yes, but I will first do one of the layers myself to rule out any possible caveats and set it up as an example PR. I will ping you again when it's ready for you to take over. Thank you!

haifeng-jin commented 1 year ago

@soma2000-lang I have merged my PR #1664. It seems the task has no special caveats.

Would you like to take over the issue from here? Thanks!

soma2000-lang commented 1 year ago

Sure @haifeng-jin

soma2000-lang commented 1 year ago

Mentioning these so that if another contributor wants to contribute, the same thing does not get done twice:

  - Resize.py done in #1672
  - randomly_zoomed_crop done in https://github.com/keras-team/keras-cv/pull/1673
  - Removing the training call from mosaic.py in #1674

james77777778 commented 1 year ago

Hi @haifeng-jin @soma2000-lang

I can work on base_image_augmentation_layer.py and vectorized_base_image_augmentation_layer.py.

However, I have a few questions:

  1. Should I only remove the training argument and its related logic, or would it be better to remove the entire call function?
  2. If I remove the entire call function, how can I retain the logic of _ensure_inputs_are_compute_dtype, _format_inputs, and _format_output, and ensure that the inputs are passed to the main preprocessing functions like _batch_augment and _augment?

Kindly ping @LukeWood @ianstenbit
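One possible answer to question 1 above, sketched framework-free: keep call(), drop only the `training` branch, and leave the dtype/format hooks wrapped around the augmentation dispatch. The hook bodies here are invented placeholders; only the method names mirror the ones discussed in the questions, and the real KerasCV base classes may differ.

```python
class VectorizedBaseSketch:
    def _ensure_inputs_are_compute_dtype(self, inputs):
        return [float(x) for x in inputs]   # placeholder dtype cast

    def _format_inputs(self, inputs):
        return inputs, {"was_list": True}   # placeholder metadata

    def _format_output(self, outputs, metadata):
        return outputs                      # placeholder unwrapping

    def _batch_augment(self, inputs):
        return [x * 2.0 for x in inputs]    # placeholder augmentation

    def call(self, inputs):
        # No `training` argument: always take the augmentation path,
        # but keep the pre/post formatting hooks intact.
        inputs = self._ensure_inputs_are_compute_dtype(inputs)
        inputs, metadata = self._format_inputs(inputs)
        outputs = self._batch_augment(inputs)
        return self._format_output(outputs, metadata)

print(VectorizedBaseSketch().call([1, 2]))  # [2.0, 4.0]
```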

soma2000-lang commented 1 year ago

@james77777778 I don't think the entire call function can be removed in these two cases, only the training argument. However, I am of course waiting for input from @LukeWood, @haifeng-jin, and @ianstenbit.

ID6109 commented 1 year ago

I'd like to contribute to the 3 files in preprocessing_3d if that's okay @haifeng-jin @soma2000-lang

soma2000-lang commented 1 year ago

Yes sure