Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Callback not invoked for the validation set with DDP #15028

Closed. athn-nik closed this issue 1 year ago.

athn-nik commented 2 years ago

Bug description

I have a callback that is supposed to run for both the training and the validation set. However, the validation part of it is never invoked. The callback renders videos and logs them to wandb, and it is supposed to run every n epochs after the train/val epoch ends. I know the callback is never called because the videos it is supposed to create are never saved to disk, so this is a PL issue rather than a wandb one. This happens when using more than one GPU with DDP.

There are no error messages. The videos are not created and none of the print statements indicate that this part of the callback is ever accessed.

How to reproduce the bug

The code of the callback:

from pathlib import Path

from pytorch_lightning import Callback, LightningModule, Trainer

# Project-specific helpers (read_json, log_to_wandb, log_to_tensorboard, log_to_none,
# render_and_save, logger) are assumed to be imported from the user's own codebase.

class RenderCallback(Callback):
    def __init__(self, bm_path: str = None,
                 path: str = "visuals",
                 logger_type: str = "wandb",
                 save_last: bool = True,
                 vid_format: str = "mp4",
                 every_n_epochs: int = 20,
                 num_workers: int = 0,
                 nvids_to_save: int = 5,
                 fps: float = 30.0,
                 modelname = 'space') -> None:

        if logger_type == "wandb":
            self.log_to_logger = log_to_wandb
        elif logger_type == "tensorboard":
            self.log_to_logger = log_to_tensorboard
        elif logger_type == "none":
            self.log_to_logger = log_to_none
        else:
            raise NotImplementedError("This logger is unknown, please use tensorboard or wandb.")

        self.logger_type = logger_type
        self.path = Path(path)
        self.path.mkdir(exist_ok=True)

        self.fps = fps
        self.nvids = nvids_to_save
        self.every_n_epochs = every_n_epochs
        self.num_workers = num_workers
        self.vid_format = vid_format
        self.save_last = save_last
        self.model = modelname

        from hydra.utils import get_original_cwd
        # self.labels_dict = read_json(f'{get_original_cwd()}/deps/inference/labels.json')
        self.labels_train = read_json(f'{get_original_cwd()}/deps/inference/labels_train_spatial.json')
        self.labels_val = read_json(f'{get_original_cwd()}/deps/inference/labels_val_spatial.json')

        if bm_path is not None:
            self.body_model_path = Path(bm_path) / 'smpl_models'

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,
                           **kwargs) -> None:
        if trainer.is_global_zero:
            return self.call_renderer("train", trainer, pl_module)

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule) -> None:
        if trainer.is_global_zero:
            return self.call_renderer("val", trainer, pl_module)

    def on_test_epoch_end(self, trainer: Trainer,
                          pl_module: LightningModule) -> None:
        return self.call_renderer("test", trainer, pl_module)

    def call_renderer(self, split: str, trainer: Trainer,
                      pl_module: LightningModule) -> None:
        if trainer.sanity_checking:
            return

        if self.nvids is None or self.nvids == 0:
            return

        # mid-epoch starting for finetuning
        # if pl_module.store_examples[split] is None:
        #     return

        logger.debug(f"Render {split} samples and log to {self.logger_type}")

        # Don't log epoch 0
        if trainer.current_epoch == 0 or trainer.current_epoch % self.every_n_epochs != 0:
            # Log last one (return = don't log, if it is not the last one)
            if trainer.current_epoch != (trainer.max_epochs - 1):
                return
            # Don't log last one if we don't want it
            elif not self.save_last:
                return
        # Prepare the folder
        folder = "epoch_" + str(trainer.current_epoch).zfill(3)
        folder = self.path / folder
        folder.mkdir(exist_ok=True)

        # Extract the stored data
        store_examples = pl_module.store_examples[split]
        ref_joints_or_verts = store_examples['ref']
        ref_motion_features = store_examples['ref_features']
        keyids_to_render = store_examples['keyids']
        ref_motion_features = ref_motion_features.features
        ref_motion_features = ref_motion_features[:self.nvids]

        # Render + log
        # nvids = min(self.nvids, len(ref_joints_or_verts))

        pl_module.eval()
        if split == 'train': 
            render_list = [self.labels_train[key] for key in keyids_to_render]
            # render_list = [lens_acts for key, lens_acts in self.labels_train.items() if key in keyids_to_render]

            texts = [[t0,t1] for t0, t1, _ in render_list]
            lens = [l for _, _, l in render_list]

            jts_T = pl_module.forward_seq(texts, lens, return_type='joints')
            jts_T = [mot.detach().cpu().numpy() for mot in jts_T]

            jts_M = pl_module.forward_motion(ref_motion_features, lens, inference=True, return_type='joints')
            jts_M = [mot.detach().cpu().numpy() for mot in jts_M]

            texts = [f'{t0},{t1} | {keyids_to_render[i]}' for i, (t0, t1) in enumerate(texts)] 

        elif split == 'val':
            render_list = [self.labels_val[key] for key in keyids_to_render]

            # render_list = [lens_acts for key, lens_acts in self.labels_val.items() if key in keyids_to_render]
            texts = [[t0,t1] for t0, t1, _ in render_list]

            lens = [l for _, _, l in render_list]

            jts_T = pl_module.forward_seq(texts, lens, return_type='joints')
            jts_T = [mot.detach().cpu().numpy() for mot in jts_T]

            jts_M = pl_module.forward_motion(ref_motion_features, lens, inference=True, return_type='joints')
            jts_M = [mot.detach().cpu().numpy() for mot in jts_M]

            texts = [f'{t0},{t1} | {keyids_to_render[i]}' for i, (t0, t1) in enumerate(texts)] 

        # self.labels_val 
        # render_list_train = list(self.labels_train.values())[:self.nvids]
        # render_list_val =  list(self.labels_val.values())[:self.nvids]

        # for set_name, keyids in zip(['train', 'val'], [render_list_train, render_list_val]):
        #     for keyid in keyids:

        import multiprocessing
        list_of_logs = []
        num_workers = min(self.num_workers, 3 * self.nvids)
        with multiprocessing.Pool(num_workers) as pool:
            iterable = ((joints[index], name, index, split,
                         folder, self.fps, description, trainer.current_epoch)
                        for joints, name in zip([jts_T, jts_M, ref_joints_or_verts],
                                                ['text', 'motion', 'ref'])
                        for index, description in zip(range(self.nvids), texts))
            for output, fig_number, name, desc in pool.imap_unordered(render_and_save, iterable):
                split_set = desc.split('_')[-1]
                log_name = f"visuals/{name}/{split_set}_{fig_number}"
                list_of_logs.append((output, log_name, desc))

                train_logger = pl_module.logger.experiment

                # self.log_to_logger(path=output, log_name=log_name, caption=desc,
                #                    fps=self.fps, global_step=trainer.current_epoch,
                #                    train_logger=train_logger, vid_format=self.vid_format)
        import operator
        list_of_logs.sort(key=operator.itemgetter(2))
        for vid_path, panel_name, text_desc in list_of_logs: 
            log_name_start, branch_or_gt, _ = panel_name.split('/')
            vid_id = vid_path.split('/')[-1].split('_')[-1].split('.')[0]
            log_name = f'{log_name_start}_{split}/sample-{vid_id}/{branch_or_gt}'
            self.log_to_logger(path=vid_path, log_name=log_name, caption=text_desc,
                                fps=self.fps, global_step=trainer.current_epoch,
                                train_logger=train_logger, vid_format=self.vid_format)

Error messages and logs

There are no error messages; the code is simply never invoked.

Environment


#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

awaelchli commented 2 years ago

@athn-nik Is the validation loop running for you at all? There are a few common reasons why validation may not run (see the sketch below).

Can you check or share your Trainer settings, ideally the full script?
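
For reference, a minimal sketch of standard Trainer settings and conditions that would skip the validation loop entirely (illustrative examples only, not taken from the report):

from pytorch_lightning import Trainer

# Any one of these is enough to keep validation_step / on_validation_epoch_end from running:
trainer = Trainer(
    limit_val_batches=0,            # disables the validation loop outright
    # check_val_every_n_epoch=2000, # larger than max_epochs -> validation is never reached
)
# Validation also never runs if the LightningModule defines no validation_step,
# or if no val dataloader / datamodule is passed to trainer.fit().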

athn-nik commented 2 years ago

@awaelchli thanks for grabbing this. limit_val_batches is always > 0. Here are the args:

auto_select_gpus: true
strategy: null # 'ddp' for multi gpu 
benchmark: False
max_epochs: 1001
accelerator: gpu
devices: 1
log_every_n_steps: 1
deterministic: False
detect_anomaly: False
enable_progress_bar: True
check_val_every_n_epoch: 25
limit_train_batches: 1.0
limit_val_batches: 1.0
num_sanity_val_steps: 2
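
For the failing multi-GPU run, these YAML values map roughly onto a Trainer like the one below (a sketch only: strategy="ddp" and devices=2 are substituted for the commented multi-GPU case, everything else is copied from the config as listed):

from pytorch_lightning import Trainer

trainer = Trainer(
    # auto_select_gpus=True,
    strategy="ddp",              # 'ddp' for multi GPU; None for the single-GPU run
    accelerator="gpu",
    devices=2,                   # >1 GPU is where the callback reportedly stops firing
    benchmark=False,
    max_epochs=1001,
    log_every_n_steps=1,
    deterministic=False,
    detect_anomaly=False,
    enable_progress_bar=True,
    check_val_every_n_epoch=25,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    num_sanity_val_steps=2,
    # callbacks=[RenderCallback(...)],  # the callback from the report would be registered here
)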

I have a dataloader, and all the other validation-related metrics and losses are calculated and logged to wandb and stdout. The LightningModule code:

import numpy as np
import torch
from hydra.utils import instantiate
from pytorch_lightning import LightningModule

class BaseModel(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters(logger=False)

        # Save visuals, one validation step per validation epoch
        self.store_examples = {"train": None,
                               "val": None}
        # Need to define:
        # forward
        # allsplit_step()
        # metrics()
        # losses()

    def __post_init__(self):
        trainable, nontrainable = 0, 0
        for p in self.parameters():
            if p.requires_grad:
                trainable += np.prod(p.size())
            else:
                nontrainable += np.prod(p.size())
        self.hparams.n_params_trainable = trainable
        self.hparams.n_params_nontrainable = nontrainable

    def training_step(self, batch, batch_idx):
        return self.allsplit_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.allsplit_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.allsplit_step("test", batch, batch_idx)

    def allsplit_epoch_end(self, split: str, outputs):
        loss_tracker = self.tracker[split]
        loss_dict = loss_tracker.compute()
        loss_tracker.reset()

        dico = {loss_tracker.loss2logname(loss, split): value.item()
                for loss, value in loss_dict.items()}
        # workaround for LR, assuming 1 optimizer, 1 scheduler, very weak
        curr_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        dico.update({'Learning Rate': curr_lr})

        dico.update({"epoch": float(self.trainer.current_epoch),
                     "step": float(self.trainer.current_epoch)})
        if split == "val":
            metrics_dict = self.metrics.compute()

            dico.update({f"Metrics/{metric}": value for metric, value in metrics_dict.items() if '_mean_' in metric})
        self.log_dict(dico)

    def training_epoch_end(self, outputs):
        return self.allsplit_epoch_end("train", outputs)

    def validation_epoch_end(self, outputs):
        return self.allsplit_epoch_end("val", outputs)

    def test_epoch_end(self, outputs):
        return self.allsplit_epoch_end("test", outputs)

    def configure_optimizers(self):
        optim_dict = {}
        optimizer = instantiate(self.hparams.optim, params=self.parameters())
        optim_dict['optimizer'] = optimizer

        if self.hparams.lr_scheduler == 'reduceonplateau':
            optim_dict['lr_scheduler'] = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=1e-3)
            optim_dict['monitor'] = 'losses/total/train'
        elif self.hparams.lr_scheduler == 'steplr':
            optim_dict['lr_scheduler'] = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

        return optim_dict 

The relevant allsplit_step, which is called from the steps above, is model-dependent; an example could be:

    def allsplit_step(self, split: str, batch, batch_idx):
        # Prepare the generated motion features
        length = batch["length"]
        input_motion_feats = batch["datastruct"]

        total_loss, loss_dict = self.losses[split](...)
        if split == 'val':
            self.metrics(input_motion_feats.detach().joints,
                         output_features_T.detach().joints,
                         length)

        if batch_idx == 0:
            nvids = self.hparams.nvids_to_save
            if nvids is not None and nvids != 0:
                del self.store_examples[split]
                lengths = batch['length'][:nvids]
                keyids = batch['keyid'][:nvids]
                motion_features = batch['datastruct']
                def prepare_pos(x):
                    x = x.detach().joints[:nvids]
                    x = x.cpu().numpy()
                    return remove_padding(x, lengths)
                def prepare_verts(x):
                    x = x.detach().vertices[:nvids]
                    x = x.cpu().numpy()
                    return remove_padding(x, lengths)

                self.store_examples[split] = { "text": batch["text"][:nvids] }
                self.store_examples[split].update({
                    'ref': prepare_pos(input_motion_feats),
                    'ref_features': motion_features.detach(),
                    'keyids': keyids
                })

        self.tracker[split].update(loss_dict)
        return total_loss

Here, store_examples is what the callback reads when the epoch ends.

awaelchli commented 2 years ago

Are you passing the dataloader correctly? Like

trainer.fit(model, train_dataloader, val_dataloader)

athn-nik commented 2 years ago

Yes, I do this, and the dataloaders are implemented via a LightningDataModule. The callback is not called: it never enters the on_validation_epoch_end method of the callback. It does get called when I use a single GPU, and the corresponding hook on the base model attached above is reached normally.
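
For completeness, the datamodule path would look roughly like this (a sketch with a hypothetical stand-in datamodule, not the user's actual code):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule

class RandomDataModule(LightningDataModule):
    """Minimal stand-in datamodule for illustration only."""

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)

    def val_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)

# trainer.fit(model, datamodule=RandomDataModule())
# Lightning then pulls train_dataloader() / val_dataloader() from the datamodule,
# and validation hooks (including callback hooks) fire on the schedule set by the Trainer.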

awaelchli commented 2 years ago

@athn-nik I cannot reproduce this. Here is a runnable example based on your configuration (but I removed all code that was incomplete for me to use).

import os

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, Callback

class RenderCallback(Callback):

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,
                           **kwargs) -> None:
        if trainer.is_global_zero:
            print("on train epoch end in callback")

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule) -> None:
        if trainer.is_global_zero:
            # return self.call_renderer("val", trainer, pl_module)
            print("on val epoch end in callback")

    def on_test_epoch_end(self, trainer: Trainer,
                          pl_module: LightningModule) -> None:
        # return self.call_renderer("test", trainer, pl_module)
        print("on test epoch end in callback")

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return Adam(self.parameters())

    def training_epoch_end(self, outputs):
        print("train epoch end on epoch", self.current_epoch)

    def validation_epoch_end(self, outputs):
        print("val epoch end on epoch", self.current_epoch)

    def test_epoch_end(self, outputs):
        print("test epoch end")

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        # auto_select_gpus=True
        strategy="ddp",
        benchmark=False,
        max_epochs=1001,
        accelerator="cpu",
        devices=2,
        log_every_n_steps=1,
        deterministic=False,
        detect_anomaly=False,
        enable_progress_bar=False,
        check_val_every_n_epoch=25,
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        num_sanity_val_steps=2,
        callbacks=RenderCallback(),
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

if __name__ == "__main__":
    run()

As you can see from the logs I included, the validation epoch end hooks are called every 25 epochs as specified in the trainer:

....
train epoch end on epoch 23
train epoch end on epoch 23
on train epoch end in callback
val epoch end on epoch 24  <--- here
val epoch end on epoch 24  <--- here
on val epoch end in callback  <--- here
train epoch end on epoch 24
train epoch end on epoch 24
on train epoch end in callback
train epoch end on epoch 25
train epoch end on epoch 25
on train epoch end in callback
....

Please note that in your render method you have early returns based on several conditions. Can you please check again that your observations are correct and that you were not simply misled by missing logs?

athn-nik commented 1 year ago

Thanks a lot for your help. It seems there was a mismatch in my epoch index checks.
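
For readers who land here with the same symptom: the interaction between the two epoch checks is easy to verify. With the values shown earlier in this thread (check_val_every_n_epoch=25 in the Trainer, every_n_epochs=20 in the callback), no epoch on which validation runs ever passes the callback's modulo check, so call_renderer always returns early for the val split. A minimal sanity check, assuming exactly those values:

# Values taken from the Trainer config and the RenderCallback shown above.
check_val_every_n_epoch = 25
every_n_epochs = 20
max_epochs = 1001

# on_validation_epoch_end fires on 0-indexed epochs 24, 49, 74, ... (see the log output above).
val_epochs = [e for e in range(max_epochs) if (e + 1) % check_val_every_n_epoch == 0]

# call_renderer only proceeds when epoch != 0 and epoch % every_n_epochs == 0,
# or on the very last epoch when save_last is set.
rendered = [e for e in val_epochs
            if (e != 0 and e % every_n_epochs == 0) or e == max_epochs - 1]
print(rendered)  # [] -> no validation epoch ever survives the callback's early returns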