shrave opened this issue 1 year ago
Hello @shrave,
Thank you for your kind words. I am happy to hear that this repo is useful for your research.
As to the issue, I have never really used ray myself, but from what I understand from the provided tutorials, ray-tune can be integrated pretty straightforwardly using a callback, as you suggest.
```python
from ray import tune

from pythae.trainers import BaseTrainerConfig
from pythae.trainers.training_callbacks import TrainingCallback


class RayCallback(TrainingCallback):
    def __init__(self) -> None:
        super().__init__()

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        metrics = kwargs.pop("metrics")  # get the metrics computed during training
        tune.report(eval_epoch_loss=metrics["eval_epoch_loss"])  # report the metric to monitor
```
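For intuition, here is a minimal, self-contained sketch (not pythae's actual trainer loop) of how such a callback gets called: at the end of each epoch, the trainer invokes `on_epoch_end` with that epoch's metrics in `kwargs`, and the callback can forward the monitored value to the tuner. The `CollectCallback` class and the fake training loop below are hypothetical stand-ins.

```python
class TrainingCallback:
    """Simplified stand-in for pythae's TrainingCallback base class."""
    def on_epoch_end(self, training_config, **kwargs):
        pass


class CollectCallback(TrainingCallback):
    """Hypothetical callback that records the monitored metric each epoch
    (where RayCallback would instead call tune.report)."""
    def __init__(self):
        self.reported = []

    def on_epoch_end(self, training_config, **kwargs):
        metrics = kwargs.pop("metrics")
        self.reported.append(metrics["eval_epoch_loss"])


# Fake training loop: the trainer calls the callback once per epoch.
cb = CollectCallback()
for fake_loss in [0.9, 0.7, 0.6]:
    cb.on_epoch_end(training_config=None, metrics={"eval_epoch_loss": fake_loss})

print(cb.reported)  # [0.9, 0.7, 0.6]
```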
You will need to wrap the training part of your script in a function that will then be called by the ray `Tuner`. The input `config` is expected to be the `search_space` dictionary defining the ranges of the hyper-parameters considered. Everything else is similar to a classic training configuration and launch with pythae.
```python
import torch
from torchvision import datasets

from pythae.data.datasets import BaseDataset
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainer, BaseTrainerConfig


def train_ray(config):
    mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

    train_dataset = BaseDataset(mnist_trainset.data[:1000].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))
    eval_dataset = BaseDataset(mnist_trainset.data[-1000:].reshape(-1, 1, 28, 28) / 255., torch.ones(1000))

    my_training_config = BaseTrainerConfig(
        output_dir='my_model',
        num_epochs=50,
        learning_rate=config["lr"],  # pass the lr for the hyper-parameter search
        per_device_train_batch_size=200,
        per_device_eval_batch_size=200,
        steps_saving=None,
        optimizer_cls="AdamW",
        optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
        scheduler_cls="ReduceLROnPlateau",
        scheduler_params={"patience": 5, "factor": 0.5}
    )

    my_vae_config = VAEConfig(
        input_dim=(1, 28, 28),
        latent_dim=10
    )

    my_vae_model = VAE(model_config=my_vae_config)

    # Add the ray callback to the callback list
    callbacks = [RayCallback()]

    trainer = BaseTrainer(
        my_vae_model,
        train_dataset,
        eval_dataset,
        my_training_config,
        callbacks=callbacks  # pass the callbacks to the trainer
    )

    trainer.train()  # launch the training
```
```python
import numpy as np
from ray import tune
from ray.tune.schedulers import ASHAScheduler

search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
}

tuner = tune.Tuner(
    train_ray,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="eval_epoch_loss", mode="min"),
    ),
    param_space=search_space,
)
results = tuner.fit()
```
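On the early-stopping side of the question: the `ASHAScheduler` above is what terminates unpromising trials partway through training. As a rough illustration only (this is a toy, synchronous successive-halving sketch, not ray's asynchronous implementation, and every name and loss value in it is made up), the idea is to periodically rank the live trials on the monitored metric and stop the worst fraction:

```python
def successive_halving(curves, eta=2, rung=1):
    """Toy sketch of successive halving (the idea behind ASHA).

    curves: {trial_name: [loss after epoch 1, epoch 2, ...]}, lower is better.
    Every `rung` epochs, only the best 1/eta fraction of live trials keeps
    training; the rest are stopped early, like a bad hyper-parameter
    configuration being killed midway through training.
    """
    survivors = list(curves)
    epoch = 0
    stopped_at = {}
    while len(survivors) > 1:
        epoch += rung
        # rank live trials by their loss at the current epoch
        survivors.sort(key=lambda t: curves[t][min(epoch, len(curves[t])) - 1])
        keep = max(1, len(survivors) // eta)
        for trial in survivors[keep:]:
            stopped_at[trial] = epoch  # terminated early at this epoch
        survivors = survivors[:keep]
    return survivors[0], stopped_at


# Four fake trials with made-up per-epoch eval losses:
curves = {
    "a": [0.9, 0.8, 0.7],
    "b": [0.5, 0.4, 0.3],
    "c": [0.8, 0.85, 0.9],
    "d": [0.6, 0.65, 0.7],
}
winner, stopped = successive_halving(curves)
print(winner)   # b
print(stopped)  # "a" and "c" stopped after epoch 1, "d" after epoch 2
```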
I have opened #87 since some minor changes need to be added to the current implementation of the `BaseTrainer` so that the metrics can be read at the end of each epoch. Let me know if this is the behavior you are expecting :) In particular, you can have a look at this [script example](https://github.com/clementchadebec/benchmark_VAE/blob/ray_integration/examples/scripts/hp_tuning_with_ray.py).
Do not hesitate to reach out if you have any questions.
Best,
Clément
Hi,
I was wondering if I could include the ray tune (hyper-parameter search) library, either as a callback or in the base trainer class, to search for the right hyper-parameters for a model and even stop training early.
Can you please tell me how it would be possible to integrate it, and thereby stop the training midway in case a particular hyper-parameter configuration does not give good performance?
Even if you could just suggest a way to return the logger at every epoch when a training pipeline instance is called, my job would be done.
This library has been extremely useful in my research. Thank you very much!