tensorflow / tfx

TFX is an end-to-end platform for deploying production ML pipelines
https://tensorflow.org/tfx
Apache License 2.0

Tuning the batch size in the Tuner component #3591

Closed jimzer closed 3 years ago

jimzer commented 3 years ago

Hello,

I'm currently using TFX to build a pipeline on the Google AI platform with the Kubeflow engine. I have a model where the batch size is an important hyper-parameter to tune.

I would like to search this hyper-parameter in the Tuner component.

Is it even possible?

I am following the TFX example with the Penguin dataset, specifically its tuner component implementation (found here).

The _get_hyperparameters function returns the sample space for the model hyperparameters (see line 139). However, the batch size to train the model is fixed and specified at the end of the tuner_fn (see line 246).

Is there a way to dynamically change the batch size based on a sample from the hyper-parameter space?

Thanks for your help!

1025KB commented 3 years ago

Hi,

The Tuner currently only supports tuning the model's hparams for the best metrics (e.g., accuracy).

batch_size is outside of the model tuning loop, and it affects performance rather than accuracy. Thus the existing tuning logic (which depends on the kerastuner library) won't be able to handle it.

(Tuning batch_size is more like a TFX pipeline configuration tuning loop than a model hparam tuning loop.)

1025KB commented 3 years ago

Customizing kerastuner.BaseTuner should work for batch size tuning.

You'd pass an unbatched dataset to the Tuner, and in the run_trial method of the CustomTuner, you'd batch the dataset with a variable size drawn from hp:

  1. Create a CustomTuner that overrides run_trial.
  2. In run_trial, get the batch_size from the trial.
  3. Batch the dataset with that batch_size.
  4. Train the model with the batched dataset.
  5. If you want to include execution time as one of the evaluation criteria for a trial, you might need to update how the metrics are calculated.
  6. In your tuner_fn, change the fit_args to match the run_trial params (pass in the unbatched dataset).
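
Steps 1 to 4 can be sketched roughly as below. This is a minimal illustration, not the TFX or Keras Tuner reference implementation. It assumes the unbatched datasets reach run_trial as the keyword arguments x and validation_data (as with fit_kwargs in a tuner_fn), that each dataset exposes a tf.data-style batch() method, and that the base tuner follows the Keras Tuner run_trial(trial, *args, **kwargs) convention. The helper name with_batch_size_search and the default batch sizes are hypothetical.

```python
def with_batch_size_search(base_tuner_cls, batch_sizes=(16, 32, 64)):
    """Return a subclass of `base_tuner_cls` that samples a batch size per trial.

    `base_tuner_cls` is assumed to be a Keras Tuner class such as
    kerastuner.RandomSearch; this helper and its name are hypothetical.
    """

    class BatchSizeTuner(base_tuner_cls):
        def run_trial(self, trial, *args, **kwargs):
            # Step 2: draw the batch size from the trial's hyperparameters.
            hp = trial.hyperparameters
            batch_size = hp.Choice("batch_size", list(batch_sizes))

            # Step 3: batch the (assumed unbatched) datasets for this trial.
            for key in ("x", "validation_data"):
                if kwargs.get(key) is not None:
                    kwargs[key] = kwargs[key].batch(batch_size)

            # Step 4: let the base tuner build and train the model as usual.
            return super().run_trial(trial, *args, **kwargs)

    return BatchSizeTuner


# Hypothetical usage (requires the kerastuner package; build_model is a
# placeholder for your own model-building function):
#
#   import kerastuner
#   BatchSizeRandomSearch = with_batch_size_search(kerastuner.RandomSearch)
#   tuner = BatchSizeRandomSearch(
#       build_model, objective="val_accuracy", max_trials=5)
#   tuner.search(x=unbatched_train_ds, validation_data=unbatched_eval_ds)
```

Note that if steps_per_epoch and validation_steps stay fixed in the fit arguments, the number of examples seen per epoch will vary with the sampled batch size, so those values may need adjusting as well.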

jimzer commented 3 years ago

Thanks for your help! I will give the Keras BaseTuner customization a try.

arghyaganguly commented 3 years ago

@jimzer, please close this if you are satisfied with the solution provided by @1025KB. Thanks.
