clementchadebec / benchmark_VAE

Unifying Variational Autoencoder (VAE) implementations in Pytorch (NeurIPS 2022)
Apache License 2.0

Integrate Ray tune with the base trainer class #86

Open · shrave opened this issue 1 year ago

shrave commented 1 year ago

Hi,

I was wondering whether I could integrate the Ray Tune hyper-parameter search library, either as a callback or in the base trainer class, to search for the right hyper-parameters for a model and even stop training early.

Could you please tell me how to integrate it so that training can be stopped midway when a particular hyper-parameter configuration does not perform well?

Even if you could just suggest a way to return the logged metrics at every epoch when a training pipeline instance is called, that would be enough for my purposes.

This library has been extremely useful in my research. Thank you very much!

clementchadebec commented 1 year ago

Hello @shrave,

Thank you for your kind words. I am happy to hear that this repo is useful for your research.

As for the issue: I have never really used Ray, but from what I understand of the tutorials it provides, Ray Tune can be integrated pretty straightforwardly using a callback, as you suggest.

1. Create the callback. It should read the metrics at the end of each epoch and store the one you want to monitor in the Tune report:
```python
from ray import tune

from pythae.trainers import BaseTrainerConfig
from pythae.trainers.training_callbacks import TrainingCallback


class RayCallback(TrainingCallback):

    def __init__(self) -> None:
        super().__init__()

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        metrics = kwargs.pop("metrics")  # get the metrics computed during the epoch
        tune.report(eval_epoch_loss=metrics["eval_epoch_loss"])  # report the metric to monitor
```
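With this callback in place, the `ASHAScheduler` configured in step 3 below can use the reported `eval_epoch_loss` to terminate unpromising trials early, which covers the "stop the training midway" part of your question. Note that the `metrics` keyword argument is only passed to `on_epoch_end` once the changes from #87 are merged.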
2. Wrap the training part of your script in a function that will then be called by the Ray `Tuner`. The input `config` is expected to be a sample drawn from the `search_space` dictionary defining the ranges of the hyper-parameters considered. All the rest is identical to a classic training configuration and launch with pythae.

```python
import torch
from torchvision import datasets

from pythae.data.datasets import BaseDataset
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainer, BaseTrainerConfig


def train_ray(config):

    mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

    train_dataset = BaseDataset(mnist_trainset.data[:1000].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))
    eval_dataset = BaseDataset(mnist_trainset.data[-1000:].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))

    my_training_config = BaseTrainerConfig(
        output_dir='my_model',
        num_epochs=50,
        learning_rate=config["lr"],  # pass the sampled lr for the hp search
        per_device_train_batch_size=200,
        per_device_eval_batch_size=200,
        steps_saving=None,
        optimizer_cls="AdamW",
        optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
        scheduler_cls="ReduceLROnPlateau",
        scheduler_params={"patience": 5, "factor": 0.5}
    )

    my_vae_config = VAEConfig(
        input_dim=(1, 28, 28),
        latent_dim=10
    )

    my_vae_model = VAE(
        model_config=my_vae_config
    )

    # Add the ray callback to the callback list
    callbacks = [RayCallback()]

    trainer = BaseTrainer(
        my_vae_model,
        train_dataset,
        eval_dataset,
        my_training_config,
        callbacks=callbacks  # pass the callbacks to the trainer
    )

    trainer.train()  # launch the training
```
3. You can then launch the Ray tuning:
    
```python
import numpy as np
from ray import tune
from ray.tune.schedulers import ASHAScheduler

search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
}

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="eval_epoch_loss", mode="min"),
    ),
    param_space=search_space,
)

results = tuner.fit()
```
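Once `tuner.fit()` returns, the best trial can be inspected from the returned `ResultGrid`. A minimal sketch, assuming Ray 2.x's `Tuner` API as used above:

```python
# Retrieve the best trial according to the monitored metric (Ray 2.x ResultGrid API).
best_result = results.get_best_result(metric="eval_epoch_loss", mode="min")
print("Best hyper-parameters:", best_result.config)  # e.g. the sampled "lr"
print("Best eval loss:", best_result.metrics["eval_epoch_loss"])
```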



I have opened #87 since some minor changes need to be made to the current implementation of the `BaseTrainer` so that the metrics can be read at the end of each epoch. Let me know if this is the behavior you are expecting :) In particular, you can have a look at this [script example](https://github.com/clementchadebec/benchmark_VAE/blob/ray_integration/examples/scripts/hp_tuning_with_ray.py).

Do not hesitate to reach out if you have any questions.

Best,

Clément