Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

The learning rate adjustment problem when self.automatic_optimization=False #17738

Closed zxyl1003 closed 1 year ago

zxyl1003 commented 1 year ago

Bug description

When training an SRGAN network, I want to train both the CNN and the GAN, which requires multiple optimizers. I wrote the code following the example in the documentation and set the lr_scheduler interval to "epoch" in configure_optimizers, but this does not take effect: the learning rate is still updated at every batch step. In addition, when I call opt.zero_grad(), I get the warning "Reference to 'zero_grad' not found in 'LightningOptimizer | list'". How do I get the lr_scheduler to update the learning rate at epoch intervals and eliminate the warning? My full code is shown below:

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import importlib
import inspect
import os

import cv2
import numpy as np
import rasterio
import torch
import torch.nn.functional as F
from lightning.pytorch import LightningModule
from torchmetrics.functional import peak_signal_noise_ratio, spectral_angle_mapper, \
    structural_similarity_index_measure, multiscale_structural_similarity_index_measure

from .base_interface import define_optimizers, define_lr_scheduler, init_weights, optimized_linear
from losses import contentloss

class SrInterface(LightningModule):
    def __init__(self, model_name, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.kwargs = kwargs
        self.model = self.load_model(model_name)
        init_weights(self.model, self.hparams.init_type, self.hparams.init_gain)
        if self.hparams.training_frame == 'gan' and self.hparams.dis_name:
            self.discriminator = self.load_model(self.hparams.dis_name)
            init_weights(self.discriminator, self.hparams.init_type, self.hparams.init_gain)
        # Update the networks manually; disable automatic optimization
        self.automatic_optimization = False

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        lr, hr = batch['lr_data'], batch['hr_data']
        sr = self(lr)

        if self.hparams.training_frame == 'gan':
            # Fetch the optimizers and learning-rate schedulers
            opt_g, opt_d = self.optimizers()
            scheduler_g, scheduler_d = self.lr_schedulers()
            # ----------------------------------------- Train the generator -----------------------------------------
            # Freeze the discriminator so no gradients are computed for it
            for d_param in self.discriminator.parameters():
                d_param.requires_grad = False
            # L1 loss and VGG19 perceptual loss
            pixel_loss = F.l1_loss(sr, hr) * self.hparams.pixel_loss_weight
            content_loss = contentloss(sr, hr, self.hparams.precision) * self.hparams.content_loss_weight
            # Adversarial loss
            d_real = self.discriminator(hr.detach().clone())
            d_fake = self.discriminator(sr)
            fake_loss = F.binary_cross_entropy_with_logits(d_fake - torch.mean(d_real), torch.ones_like(d_fake)) * 0.5
            real_loss = F.binary_cross_entropy_with_logits(d_real - torch.mean(d_fake), torch.zeros_like(d_real)) * 0.5
            adversarial_loss = (fake_loss + real_loss) * self.hparams.adversarial_loss_weight
            g_loss = pixel_loss + content_loss + adversarial_loss
            # Generator gradient update
            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()
            scheduler_g.step()
            # ---------------------------------------- Train the discriminator ----------------------------------------
            for d_param in self.discriminator.parameters():
                d_param.requires_grad = True
            d_real = self.discriminator(hr)
            d_fake = self.discriminator(sr.detach().clone())
            real_loss = F.binary_cross_entropy_with_logits(d_real - torch.mean(d_fake), torch.ones_like(d_real)) * 0.5
            fake_loss = F.binary_cross_entropy_with_logits(d_fake - torch.mean(d_real), torch.zeros_like(d_fake)) * 0.5
            d_loss = real_loss + fake_loss
            # Discriminator gradient update
            opt_d.zero_grad()
            self.manual_backward(d_loss)
            opt_d.step()
            scheduler_d.step()
            # Compute image-quality metrics
            quality_index = self.quality_assessment(sr.detach(), hr.detach())
            self.log_dict({"g_loss": g_loss,
                           "d_loss": d_loss,
                           "train_psnr": quality_index["psnr"],
                           "train_ssim": quality_index["ssim"],
                           "train_mssim": quality_index["mssim"],
                           "train_sam": quality_index["sam"]},
                          on_step=True,
                          on_epoch=True,
                          prog_bar=True,
                          batch_size=self.hparams.batch_size)
        else:
            opt = self.optimizers()  # Warning: Reference to 'zero_grad' not found in 'LightningOptimizer | list'
            scheduler = self.lr_schedulers()
            pixel_loss = F.l1_loss(sr, hr) * self.hparams.pixel_loss_weight
            content_loss = contentloss(sr, hr, self.hparams.precision) * self.hparams.content_loss_weight
            train_loss = pixel_loss + content_loss
            # Gradient update
            opt.zero_grad()
            self.manual_backward(train_loss)
            opt.step()
            scheduler.step()
            # Compute image-quality metrics
            quality_index = self.quality_assessment(sr.detach(), hr.detach())
            self.log_dict({"train_loss": train_loss,
                           "train_psnr": quality_index["psnr"],
                           "train_ssim": quality_index["ssim"],
                           "train_mssim": quality_index["mssim"],
                           "train_sam": quality_index["sam"]},
                          on_step=True,
                          on_epoch=True,
                          prog_bar=True,
                          batch_size=self.hparams.batch_size)

    def validation_step(self, batch, batch_idx, flag: str = 'val'):
        lr, hr, fn = batch['lr_data'], batch['hr_data'], batch['fn']
        sr = self(lr)
        quality_index = self.quality_assessment(sr.detach().clone(), hr.detach().clone())
        self.log_dict({f'{flag}_psnr': quality_index['psnr'],
                       f'{flag}_ssim': quality_index['ssim'],
                       f'{flag}_mssim': quality_index['mssim'],
                       f'{flag}_sam': quality_index['sam']},
                      on_step=True,
                      on_epoch=True,
                      prog_bar=True,
                      batch_size=self.hparams.batch_size)
        if flag == 'val' and (batch_idx + 1) % self.hparams.flush_samples_every_n_steps == 0:
            self.flush_sample(sr.detach().cpu().numpy(), batch_idx, fn)

    def test_step(self, batch, batch_idx, flag: str = 'test'):
        self.validation_step(batch, batch_idx, flag)

    def predict_step(self, batch, batch_idx, dataloader_idx: int = 0):
        lr, hr, fn = batch['lr_data'], batch['hr_data'], batch['fn']
        sr = self(lr)
        # write predict tif image to disk
        write_path = self.hparams.sr_predict_dir + '/' + self.hparams.model_name
        if not os.path.exists(write_path):
            os.makedirs(write_path)
        for b in range(sr.shape[0]):
            with rasterio.open(f'{self.hparams.dataset_dir}/{self.hparams.sensor_type}/lr/{fn[b]}', 'r') as ls:
                meta = ls.meta.copy()
                meta['height'] = meta['height'] * self.hparams.upscale
                meta['width'] = meta['width'] * self.hparams.upscale
                transform = ls.transform * ls.transform.scale(
                    (ls.width / sr.shape[-1]),
                    (ls.height / sr.shape[-2])
                )
                meta['transform'] = transform

            with rasterio.open(write_path + '/' + fn[b], 'w', **meta) as ss:
                ss.write(sr[b, ...].detach().cpu().numpy())

            bgr_data = sr[b, 0:3, ...].detach().cpu().numpy() * 255.
            linear_data = optimized_linear(bgr_data.astype('uint8'))
            cv2.imwrite(write_path + '/' + fn[b].replace('.tif', '.png'), np.transpose(linear_data, [1, 2, 0]))

    def configure_optimizers(self):
        optimizer = define_optimizers(parameters=self.model.parameters(), args=self.hparams)
        scheduler = define_lr_scheduler(optimizer=optimizer, args=self.hparams)
        if self.hparams.training_frame == 'gan':
            dis_optimizer = define_optimizers(parameters=self.discriminator.parameters(), args=self.hparams)
            dis_scheduler = define_lr_scheduler(optimizer=dis_optimizer, args=self.hparams)
            return ({"optimizer": optimizer,
                     "lr_scheduler": {"scheduler": scheduler,
                                      "interval": "epoch"}},
                    {"optimizer": dis_optimizer,
                     "lr_scheduler": {"scheduler": dis_scheduler,
                                      "interval": "epoch",
                                      "frequency": 1}})
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": scheduler,
                                 "interval": "epoch",
                                 "frequency": 1}}

    def load_model(self, model_name):
        if model_name.find('_') == -1:
            cls_name = model_name.upper()
        else:
            cls_name = ''.join([i.capitalize() for i in model_name.split('_')])
        try:
            cls = getattr(importlib.import_module('.' + model_name, package=__package__),
                          cls_name)
            return self.initialize(cls)
        except ValueError:
            print(f'Invalid Model File Name or Invalid Class Name: '
                  f'model.{model_name}.{cls_name}')

    def initialize(self, cls):
        # First arg is self.
        cls_args = inspect.getfullargspec(cls.__init__).args[1:]
        new_cls_args = {}
        for arg in cls_args:
            if arg in self.kwargs.keys():
                new_cls_args[arg] = self.kwargs[arg]
        return cls(**new_cls_args)

    def flush_sample(self, img, batch_idx, batch_fn):
        save_path = f'{self.hparams.flush_samples_dir}/{self.hparams.model_name}/batch_{batch_idx}'
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # Apply a linear stretch to the images
        for i in range(img.shape[0]):
            bgr_data = img[i, 0:3, ...] * 255.
            linear_data = optimized_linear(bgr_data.astype('uint8'))
            file_name = batch_fn[i].replace('.tif', '.png')
            cv2.imwrite(f'{save_path}/{file_name}', np.transpose(linear_data, [1, 2, 0]))

    @staticmethod
    def quality_assessment(pred, target):
        if pred.dtype != target.dtype:
            pred = pred.to(target.dtype)
        psnr_value = peak_signal_noise_ratio(pred, target)
        ssim_value = structural_similarity_index_measure(pred, target)
        mssim_value = multiscale_structural_similarity_index_measure(pred, target)
        sam_value = spectral_angle_mapper(pred, target)

        quality_values = {'psnr': psnr_value,
                          'ssim': ssim_value,
                          'mssim': mssim_value,
                          'sam': sam_value}

        return quality_values

Error messages and logs

# Error messages and logs here please

Environment

Current environment

```
#- Lightning Component: Trainer, LightningModule
#- PyTorch Lightning Version: 2.0.0
#- PyTorch Version: 2.0.0+cu117
#- Python version: 3.10
#- OS: Windows
#- CUDA/cuDNN version: cu117
#- GPU models and configuration: rtx 3060 6g
#- How you installed Lightning (`conda`, `pip`, source): pip
```

More info

PS: I found that flush_logs_every_n_steps in CSVLogger() doesn't work; it is actually the log_every_n_steps parameter in Trainer() that controls how often logs are written.
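
For comparison, a minimal sketch of the two settings side by side (save_dir and the interval values below are only placeholders):

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

# Placeholder values for illustration.
logger = CSVLogger(save_dir="logs", flush_logs_every_n_steps=100)

# As noted above, it is the Trainer's log_every_n_steps that controls how often
# metrics are logged; flush_logs_every_n_steps only controls how often the
# buffered rows are flushed to disk.
trainer = Trainer(logger=logger, log_every_n_steps=50)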

cc @borda

awaelchli commented 1 year ago

How do I get the lr_scheduler to update the learning rate at epoch intervals and eliminate the warning? My full code is shown below:

To step the scheduler once per epoch in manual optimization, do it in the on_train_epoch_end hook:

    def on_train_epoch_end(self) -> None:
        scheduler_g, scheduler_d = self.lr_schedulers()
        scheduler_d.step()
        scheduler_g.step()

Is that what you were looking for? Note that the interval entry in the configuration dict returned from configure_optimizers only applies in automatic optimization.
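
As for the IDE warning: it presumably comes from the declared return type of self.optimizers() (LightningOptimizer | list), not from a runtime problem. A minimal sketch of one way to narrow that union for the type checker, assuming a single configured optimizer:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # With a single configured optimizer this is not a list at runtime;
        # the assert narrows the `LightningOptimizer | list` union so the IDE
        # can resolve `zero_grad()` and `step()`.
        assert not isinstance(opt, list)
        opt.zero_grad()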

zxyl1003 commented 1 year ago

Thank you!!! @awaelchli