Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`save_hyperparameters` does not save the selected `learning_rate` when `auto_lr_find` is used #15928

Open iceboundflame opened 1 year ago

iceboundflame commented 1 year ago

Bug description

I am running my trainer with the auto_lr_find option, but the model's saved learning_rate reflects the original learning rate, not the one selected by trainer.tune(). This is because save_hyperparameters() must be called at model construction, so it cannot capture the learning rate that the tuner selects later. This feels like a shortcoming - is there a recommended workaround? I didn't see one in the docs. Thanks!
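For reference, a minimal sketch of the setup being described (hypothetical module and random data; assumes the 1.x-style API where `Trainer(auto_lr_find=True)` and `trainer.tune()` are available):

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class LitModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # save_hyperparameters() has to run here, at construction time,
        # so it records the *initial* learning rate.
        self.save_hyperparameters()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=16
)

model = LitModel(learning_rate=1e-3)
trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)
trainer.tune(model, train_dataloaders=train_loader)

# Per the report above: tune() applies its suggested learning rate to the model,
# but the hyperparameters captured by save_hyperparameters() (and written to
# checkpoints) still show the original 1e-3.
```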

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

No response

cc @awaelchli @borda @Blaizzy

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

Palzer commented 10 months ago

Was there a solution for this?

harryseely commented 10 months ago

> Was there a solution for this?

Hi @Palzer, I am using wandb (but I think you could apply similar code with other loggers). Here is the solution I came up with:

```python
import matplotlib.pyplot as plt
import torch
import wandb
from lightning.pytorch.tuner import Tuner


def run_lr_finder(cfg, trainer, model, data_module, wandb_logger):
    """Run the automatic LR finder and record the suggested learning rate."""
    tuner = Tuner(trainer)
    model.hparams['lr'] = cfg['lr']  # Ensure model has an lr hyperparameter so it can be updated
    lr_finder = tuner.lr_find(model, datamodule=data_module)

    # Only log the new lr on the main device (0)
    if torch.cuda.current_device() == 0:
        new_lr = lr_finder.suggestion()
        print(f"LR finder identified:\n{new_lr}\nas candidate learning rate")

        fig = lr_finder.plot(suggest=True)
        plot_fpath = "temp/lr_finder_plot.jpg"
        plt.savefig(plot_fpath)
        fig.show()

        # Save the plot as an artifact in wandb and record the new lr in the run config
        if wandb_logger:
            wandb_logger.log_image(key="lr_finder_plot", images=[plot_fpath])
            wandb.config.update({"lr": new_lr}, allow_val_change=True)
```

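For context, a hypothetical call site might look like the snippet below (the `cfg` dict and the `fit` call are assumptions, not part of the helper above):

```python
# Hypothetical usage, assuming cfg carries the initial learning rate under "lr"
run_lr_finder(cfg, trainer, model, data_module, wandb_logger)
trainer.fit(model, datamodule=data_module)  # configure_optimizers can now read the updated model.hparams['lr']
```
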
Ritesh313 commented 1 week ago

This still doesn't update automatically when using the comet logger. I have a learning_rate parameter in the module definition; it gets updated on the object, but comet still shows the initial learning rate that was passed when creating that object.

```python
# Class definition:
class RGB_Main_Module(L.LightningModule):
    def __init__(self, model, learning_rate):
        ...
        self.save_hyperparameters(ignore=['model'])
    ...

# Creating an object
m = RGB_Main_Module(model=model, learning_rate=0.0001)
tuner = Tuner(trainer=trainer)
lr_finder = tuner.lr_find(model=m, datamodule=data_module, num_training=1000)
```

Should I call save_hyperparameters somewhere else, or can I use some other command to update the learning_rate value on comet? Currently I'm just logging the new learning_rate explicitly under another name.
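
For reference, a minimal sketch of that explicit-logging workaround (the key name is hypothetical, `log_hyperparams` is the generic Lightning logger call, and whether re-logging overwrites the value shown in the Comet UI is not confirmed here):

```python
# Hypothetical sketch, not a confirmed fix: push the tuned value to the logger explicitly.
new_lr = lr_finder.suggestion()
m.hparams.learning_rate = new_lr  # keep the in-memory hparams in sync
trainer.logger.log_hyperparams({"tuned_learning_rate": new_lr})
```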