ltatzel / PyTorchHessianFree

PyTorch implementation of the Hessian-free optimizer
BSD 3-Clause "New" or "Revised" License
30 stars 6 forks source link

Can't save model via the PyTorch Lightning ModelCheckpoint() #7

Closed youli-jlu closed 3 months ago

youli-jlu commented 3 months ago

Hi ltatzel,

Your Hessian-free LM optimizer performs very well, and I want to use it to replace PyTorch's L-BFGS optimizer. However, the model is not saved when I use ModelCheckpoint() from PyTorch Lightning. My test script is included below. Could you suggest how to handle this problem?

Thanks.

import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt

import lightning as L
from lightning.pytorch import LightningModule
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from hessianfree.optimizer import HessianFree

class LitModel(LightningModule):
    """Small tanh MLP regressor (1 -> 20 -> 20 -> 1) trained with HessianFree.

    Optimization is driven manually in ``training_step``
    (``automatic_optimization = False``): a ``forward`` callable returning
    ``(loss, predictions)`` is handed to ``HessianFree.step``.

    Args:
        loss: Loss module used for training, validation and testing.
            Defaults to ``nn.MSELoss()`` when ``None``.
    """

    def __init__(self, loss=None):
        super().__init__()
        self.tanh_linear = nn.Sequential(
            nn.Linear(1, 20),
            nn.Tanh(),
            nn.Linear(20, 20),
            nn.Tanh(),
            nn.Linear(20, 1),
        )
        # Use the caller-supplied loss; the previous version ignored it.
        self.loss_fn = loss if loss is not None else nn.MSELoss()
        # training_step calls the optimizer explicitly.
        self.automatic_optimization = False

    def forward(self, x):
        """Return the network prediction for input ``x``."""
        return self.tanh_linear(x)

    def configure_optimizers(self):
        """Build the HessianFree optimizer over all model parameters."""
        optimizer = HessianFree(
            self.parameters(),
            cg_tol=1e-6,
            cg_max_iter=1000,
            lr=1e0,
            LS_max_iter=1000,
            LS_c=1e-3,
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        """Take one Hessian-free step on ``batch`` and log the post-step loss."""
        x, y = batch
        opt = self.optimizers()

        def forward_fn():
            # Returns (loss, predictions), the pair passed to HessianFree.step.
            y_pred = self(x)
            loss = self.loss_fn(y_pred, y)
            return loss, y_pred

        # NOTE(review): this calls the raw HessianFree optimizer, bypassing the
        # Lightning wrapper returned by self.optimizers(). Lightning presumably
        # counts manual optimizer steps in that wrapper, so trainer.global_step
        # may never advance -- a likely reason ModelCheckpoint writes nothing.
        # Routing through opt.step(...) would require HessianFree.step to
        # accept Lightning's extra ``closure`` keyword -- verify before changing.
        opt.optimizer.step(forward=forward_fn)

        # Re-evaluate after the update for logging only; no graph is needed.
        with torch.no_grad():
            loss, _ = forward_fn()
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` for ``batch``."""
        x, y = batch
        y_hat = self(x)
        val_loss = self.loss_fn(y_hat, y)
        # Logged as "val_loss" so callbacks with monitor="val_loss"
        # (ModelCheckpoint / EarlyStopping) can track it.
        self.log("val_loss", val_loss, on_epoch=True, on_step=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Return the loss on ``batch`` (not logged)."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

def main():
    """Fit LitModel on a 1-D ``sinc(10 x)`` regression task using HessianFree.

    Samples 20000 points on ``[-1, 1]`` in random order and splits them 90/10
    into train/validation sets. It then trains for up to 10 epochs on CPU with
    a CSV logger and a ``ModelCheckpoint`` monitoring ``val_loss`` in
    ``./lm_HF_t20``. Finally, it prints the wall-clock training time.
    """
    input_size = 20000
    train_size = int(input_size * 0.9)
    test_size = input_size - train_size
    batch_size = 1000

    x_total = np.linspace(-1.0, 1.0, input_size, dtype=np.float32)
    # Drawing every point without replacement shuffles the grid.
    x_total = np.random.choice(x_total, size=input_size, replace=False)
    x_train = x_total[:train_size].reshape((train_size, 1))
    x_test = x_total[train_size:].reshape((test_size, 1))

    # Compute targets in NumPy, then convert inputs and targets to tensors.
    y_train = torch.from_numpy(np.sinc(10.0 * x_train))
    y_test = torch.from_numpy(np.sinc(10.0 * x_test))
    x_train = torch.from_numpy(x_train)
    x_test = torch.from_numpy(x_test)

    training_data = TensorDataset(x_train, y_train)
    test_data = TensorDataset(x_test, y_test)

    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    # Sanity-print the shapes of the first batch of each loader.
    for loader in (test_dataloader, train_dataloader):
        X, y = next(iter(loader))
        print("Shape of X: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)

    loss_fn = nn.MSELoss()
    model = LitModel(loss_fn)

    opt_label = "lm_HF_t20"
    logger = CSVLogger(
        f"./{opt_label}", name=f"test-{opt_label}", flush_logs_every_n_steps=1
    )
    max_epochs = 10
    print(f"test for {opt_label}")

    # Built but not passed to the Trainer; add it to ``callbacks`` to enable.
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=1e-9,
        patience=10,
        verbose=False,
        mode="min",
        stopping_threshold=1e-8,  # stop once this accuracy is reached
    )
    # NOTE(review): ModelCheckpoint appears to skip saving while
    # trainer.global_step has not advanced. LitModel.training_step steps the
    # raw optimizer, which may keep global_step at 0 -- verify before blaming
    # the callback.
    modelck = ModelCheckpoint(
        dirpath=f"./{opt_label}",
        monitor="val_loss",
        save_last=True,
    )

    trainer = Trainer(
        accelerator="cpu",
        max_epochs=max_epochs,
        enable_progress_bar=True,
        callbacks=[modelck],  # early stopping disabled
        logger=logger,
    )

    t_start = time.perf_counter()
    trainer.fit(
        model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader
    )
    elapsed = time.perf_counter() - t_start

    print("total time")
    print(elapsed)

    # Manual fallback that is reported to work:
    # torch.save(model.state_dict(), f"model{opt_label}.pth")
    # No exit() here: terminating the interpreter would also kill any
    # caller that imports and runs main().

# Run the training script only when executed directly, not on import.
if __name__ == "__main__":
    main()
ltatzel commented 3 months ago

Hi youli-jlu, If I understand correctly, the issue is rather with ModelCheckpoint() than with the optimizer. Since I don't have any experience with the lightning.pytorch package, I don't think I can help you. Sorry!

youli-jlu commented 3 months ago

Thank you for your quick response. I will open an issue in the Lightning repository.