ray-project / ray_lightning

Pytorch Lightning Distributed Accelerators using Ray
Apache License 2.0

Cannot checkpoint and log #228

Open lcaquot94 opened 1 year ago

lcaquot94 commented 1 year ago

The documentation says that when using Ray Client, you must disable checkpointing and logging for your Trainer by setting checkpoint_callback and logger to False. So how can we log and save the model during training?
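
For reference, that restriction corresponds to a Trainer constructed roughly like this (a sketch; checkpoint_callback is the argument name on older PyTorch Lightning versions, enable_checkpointing on newer ones):

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        checkpoint_callback=False,  # built-in checkpointing disabled
        logger=False,               # built-in logging disabled
    )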

bparaj commented 1 year ago

I have been doing this:

  1. Import TuneReportCheckpointCallback from ray_lightning.tune:

    from ray_lightning.tune import TuneReportCheckpointCallback
  2. Disable the Trainer's built-in checkpointing by setting "enable_checkpointing": False in the pl Trainer's configuration, for example:
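
    # A sketch of such a configuration; "max_epochs" is an illustrative
    # placeholder, only "enable_checkpointing" matters for this step.
    trainer_config = {
        "max_epochs": 10,
        "enable_checkpointing": False,  # the Tune callback will checkpoint instead
    }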

  3. Initialize logger:

    from pytorch_lightning import loggers as pl_loggers
    tb_logger = pl_loggers.TensorBoardLogger(save_dir="/tmp/some-dir")
  4. Initialize the Ray training strategy:

    from ray_lightning import RayStrategy
    strategy = RayStrategy(num_workers=1, num_cpus_per_worker=1, use_gpu=True)
  5. Initialize trainer:

    trainer = pl.Trainer(
        **trainer_config,
        # Report the logged "accuracy" metric to Tune and save a checkpoint each epoch.
        callbacks=[TuneReportCheckpointCallback({"accuracy": "accuracy"}, on="epoch_end")],
        strategy=strategy,
        logger=tb_logger
    )
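
Putting the steps together, a minimal end-to-end sketch might look like the following. MyModel, the constant "accuracy" value, and the tune.run arguments are illustrative assumptions, not from this thread; TuneReportCheckpointCallback and get_tune_resources both come from ray_lightning.tune:

    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import loggers as pl_loggers
    from ray import tune
    from ray_lightning import RayStrategy
    from ray_lightning.tune import TuneReportCheckpointCallback, get_tune_resources

    class MyModel(pl.LightningModule):
        # A tiny placeholder LightningModule that logs an "accuracy" metric.
        def __init__(self, config):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # Placeholder metric for the callback to report at epoch end.
            self.log("accuracy", 0.5, on_step=False, on_epoch=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        def train_dataloader(self):
            return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

    def train_fn(config):
        trainer = pl.Trainer(
            max_epochs=config["max_epochs"],
            enable_checkpointing=False,  # the Tune callback handles checkpoints
            logger=pl_loggers.TensorBoardLogger(save_dir="/tmp/some-dir"),
            callbacks=[
                TuneReportCheckpointCallback({"accuracy": "accuracy"}, on="epoch_end")
            ],
            strategy=RayStrategy(num_workers=1, num_cpus_per_worker=1, use_gpu=True),
        )
        trainer.fit(MyModel(config))

    analysis = tune.run(
        train_fn,
        config={"max_epochs": 2},
        num_samples=1,
        metric="accuracy",
        mode="max",
        # Reserve resources for the Ray workers each trial will launch.
        resources_per_trial=get_tune_resources(num_workers=1, use_gpu=True),
    )
    print(analysis.best_checkpoint)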