ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Tune] Allow use of tune.search.Repeater with BasicVariantGenerator to simplify K-fold CrossValidation #33677

Open adivekar-utexas opened 1 year ago

adivekar-utexas commented 1 year ago

Description

From this discussion: https://discuss.ray.io/t/raytune-use-repeater-with-basicvariantgenerator/9042

Currently, you can't use a BasicVariantGenerator with a Repeater.

Use case

This should be supported, so that something like this works:

Repeater(BasicVariantGenerator(), repeat=num_folds)

Repeater is one suggested way to implement K-fold cross-validation in Ray. K-fold CV is a common ask in Ray Tune.
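For readers unfamiliar with the idea, here is a minimal pure-Python sketch of the "repeat and average" pattern that makes Repeater attractive for K-fold CV: one hyperparameter config is scored once per fold, and only the mean is reported back. This is not Ray's implementation. `repeat_and_average` and `toy_evaluate` are hypothetical names invented for illustration, and the scoring function is a stand-in for real training.

```python
from statistics import mean
from typing import Callable, Dict, List


def repeat_and_average(
    config: Dict[str, float],
    evaluate_fold: Callable[[Dict[str, float], int], float],
    num_folds: int,
) -> float:
    """Score one config on every fold index and return the mean score."""
    scores: List[float] = [evaluate_fold(config, fold) for fold in range(num_folds)]
    return mean(scores)


def toy_evaluate(config: Dict[str, float], fold: int) -> float:
    # Hypothetical stand-in for "train on the other folds, validate on this one".
    return config["lr"] * 10 + fold


# The same config is evaluated on folds 0..4; only the average is kept.
avg_score = repeat_and_average({"lr": 0.1}, toy_evaluate, num_folds=5)
```

The key property is that the hyperparameters stay fixed while only the fold index varies, which is what the requested `Repeater(BasicVariantGenerator(), repeat=num_folds)` would express.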

adivekar-utexas commented 1 year ago

BTW, I have successfully used Repeater to do K-Fold as of yesterday. So it's definitely possible.

joshuasv commented 1 year ago

@adivekar-utexas Could you share the implementation?

CMGeldenhuys commented 1 year ago

I've managed to get K-fold cross validation working with TorchTrainer and the BasicVariantGenerator. The approach I took involves using the constant_grid_search=True parameter of BasicVariantGenerator.

...

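  from ray import tune
  from ray.tune.search import BasicVariantGenerator

  search_space = {
      # One trial per fold index; 5 folds in total
      'kfold': tune.grid_search([1, 2, 3, 4, 5]),
      # ... other hyperparameters
  }

  # `trainer` (a TorchTrainer instance) and `scheduler` are defined in the elided code
  tuner = tune.Tuner(
      trainer,
      param_space={'train_loop_config': search_space},
      tune_config=tune.TuneConfig(
          scheduler=scheduler,
          search_alg=BasicVariantGenerator(
              # Required for k-folds: keeps the sampled hyperparameters
              # constant across the grid-searched fold indices
              constant_grid_search=True,
          ),
      ),
  )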

...

In the train_loop_per_worker function of the trainer, I handle the actual cross-validation logic. This setup at least ensures that the hyperparameters remain constant across folds and that the folds can be executed in parallel. I'm not sure how much help this is to anyone else, as it is fairly implementation-specific.
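The fold-handling logic inside the training function was not shared, so the following is only a sketch of one way it could look, not the implementation used above. It assumes the config carries a 1-based `kfold` index, as in the search space above, and uses a hypothetical helper, `kfold_split`, that assigns items to folds round-robin.

```python
from typing import List, Sequence, Tuple


def kfold_split(
    items: Sequence[int], fold: int, num_folds: int
) -> Tuple[List[int], List[int]]:
    """Return (train, validation) items for a 1-based fold index."""
    if not 1 <= fold <= num_folds:
        raise ValueError(f"fold must be in 1..{num_folds}, got {fold}")
    # Item i belongs to validation fold (i % num_folds) + 1.
    val = [x for i, x in enumerate(items) if i % num_folds == fold - 1]
    train = [x for i, x in enumerate(items) if i % num_folds != fold - 1]
    return train, val


# Inside a training function, the fold index would come from the trial config.
config = {"kfold": 2}  # hypothetical config for one trial
train_items, val_items = kfold_split(list(range(10)), config["kfold"], num_folds=5)
```

With this approach each fold is a separate trial, so the per-fold metrics still have to be averaged per hyperparameter combination after tuning finishes.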

zcarrico-fn commented 12 months ago

> BTW, I have successfully used Repeater to do K-Fold as of yesterday. So it's definitely possible.

@adivekar-utexas, would you please share how you got Repeater working with BasicVariantGenerator? Thank you!