Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`LightningModule.log(on_epoch, on_step)`: Hard to get same behavior for train and val? #6770

Closed EricCousineau-TRI closed 3 years ago

EricCousineau-TRI commented 3 years ago

🐛 Bug

I see a table of different default behaviors for log(), which leads me to believe that if I want train/eval to have the same behavior (e.g. reduction, frequency, etc.), I could just set on_step / on_epoch explicitly.

However, that doesn't seem to be the case.

I was seeing noisy validation values logged, whereas training values were smoother. I wanted to see if I could make validation values be smoothed (reduced) the way training values are, but couldn't figure out how to do that from the available docs.

Reviewed:

Please reproduce using the BoringModel

To Reproduce

Expected behavior

When I explicitly set the arguments to .log(), I'd expect training and validation metrics to be logged the same way they are when I call the custom logger directly.
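
For concreteness, a sketch of the kind of symmetric calls I have in mind (`_compute_loss` is a placeholder, not from the notebook; the actual repro is in the notebook):

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # placeholder for the real loss computation
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # placeholder for the real loss computation
        # Same flags as in training_step; I'd expect the same reduction and
        # logging cadence for validation as for training.
        self.log("val/loss", loss, on_step=True, on_epoch=True)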

Environment

See Notebook.

Additional context

N/A

EricCousineau-TRI commented 3 years ago

Main output from notebook:

mode=Mode.CustomLog
train: N=2, batch_idx=0, batch=[0.0, 1.0], p=1.0
  log[step=0]: {'train/loss': -1.0}
train: N=2, batch_idx=1, batch=[2.0, 3.0], p=2.0
  log[step=1]: {'train/loss': -2.0}
val: N=2, batch_idx=0, batch=[0.0, 1.0], p=3.0
  log[step=1]: {'val/loss': -3.0}
val: N=2, batch_idx=1, batch=[2.0, 3.0], p=3.0
  log[step=1]: {'val/loss': -3.0}

mode=Mode.LogDefault
train: N=2, batch_idx=0, batch=[0.0, 1.0], p=1.0
train: N=2, batch_idx=1, batch=[2.0, 3.0], p=2.0
val: N=2, batch_idx=0, batch=[0.0, 1.0], p=3.0
val: N=2, batch_idx=1, batch=[2.0, 3.0], p=3.0
  log[step=1]: {'val/loss': -3.0, 'epoch': 0}

mode=Mode.LogOnStep
train: N=2, batch_idx=0, batch=[0.0, 1.0], p=1.0
train: N=2, batch_idx=1, batch=[2.0, 3.0], p=2.0
val: N=2, batch_idx=0, batch=[0.0, 1.0], p=3.0
  log[step=0]: {'val/loss/epoch_0': -3.0}
val: N=2, batch_idx=1, batch=[2.0, 3.0], p=3.0
  log[step=1]: {'val/loss/epoch_0': -3.0}

mode=Mode.LogOnEpoch
train: N=2, batch_idx=0, batch=[0.0, 1.0], p=1.0
train: N=2, batch_idx=1, batch=[2.0, 3.0], p=2.0
  log[step=1]: {'train/loss': -1.5, 'epoch': 0}
val: N=2, batch_idx=0, batch=[0.0, 1.0], p=3.0
val: N=2, batch_idx=1, batch=[2.0, 3.0], p=3.0
  log[step=1]: {'val/loss': -3.0, 'epoch': 0}

mode=Mode.LogOnBoth
train: N=2, batch_idx=0, batch=[0.0, 1.0], p=1.0
train: N=2, batch_idx=1, batch=[2.0, 3.0], p=2.0
  log[step=1]: {'train/loss_epoch': -1.5, 'epoch': 0}
val: N=2, batch_idx=0, batch=[0.0, 1.0], p=3.0
  log[step=0]: {'val/loss_step/epoch_0': -3.0}
val: N=2, batch_idx=1, batch=[2.0, 3.0], p=3.0
  log[step=1]: {'val/loss_step/epoch_0': -3.0}
  log[step=1]: {'val/loss_epoch': -3.0, 'epoch': 0}

EricCousineau-TRI commented 3 years ago

Simplest workaround: stop logging in validation_step and do the reduction manually in validation_epoch_end (a colleague's working solution).
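
Roughly, the workaround looks like this (a sketch of the LightningModule methods only; `_compute_loss` is a placeholder, not the colleague's exact code):

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # placeholder for the real loss computation
        # No self.log() here; just hand the per-batch loss to the epoch-end hook.
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        # Manual reduction: average the per-batch losses, then log once per epoch.
        mean_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
        self.log("val/loss", mean_loss)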

I'm still confused about the frequency of training-step logging, but meh, that isn't the issue for now.

tchaton commented 3 years ago

Dear @EricCousineau-TRI,

I am not sure I fully understand your problem.

You could do:

    def validation_step(...):
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)

EricCousineau-TRI commented 3 years ago

Howdy @tchaton! Sorry it was indirect, but I had done this using Mode.LogOnEpoch (see example notebook).

This is actually the only behavior (I think?) that is consistent between train and val when on_step and on_epoch are explicitly set; my bug is about the fact that explicitly setting the other configurations isn't consistent. See the repro code for ConstantMultiply._step and self.mode, and how the logger output differs between train and val for Mode.LogOnStep and Mode.LogOnBoth.

(I don't really care about prog_bar=True, as I'd hope it's just an indicator, and doesn't affect when the logger instance is called?)

tchaton commented 3 years ago

Hey @EricCousineau-TRI,

You are using the parameter p to compute the loss, and p is being optimised while you are logging. Using your code, I have added tests at the bottom to validate that everything is properly reduced.

Also, we apply a running mean over 20 batches to the training loss, which might explain why that loss seems more stable.
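
(Conceptually, that smoothing is just a windowed running mean over the most recent training losses. A rough, hypothetical illustration of the idea, not Lightning's actual implementation:)

    from collections import deque

    class RunningMean:
        # Windowed running mean; the window size of 20 mirrors the smoothing
        # described above, but this class is only an illustration.
        def __init__(self, window=20):
            self.values = deque(maxlen=window)

        def update(self, value):
            self.values.append(float(value))
            return sum(self.values) / len(self.values)

The full repro, with the tests I added at the bottom: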

    from enum import Enum

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import LightningLoggerBase
    from torch.utils.data import DataLoader


    class Mode(Enum):
        CustomLog = 1
        LogDefault = 2
        LogOnStep = 3
        LogOnEpoch = 4
        LogOnBoth = 5

    class ConstantMultiply(pl.LightningModule):

        def __init__(self, mode):
            super().__init__()
            self.p = torch.nn.Parameter(torch.tensor([1.0]))
            assert mode in Mode
            self.mode = mode

            self.stages = {"train": [], "val": []}

        def forward(self, x):
            return -self.p

        def _step(self, phase, batch, batch_idx):
            N = len(batch)
            p, = self.p.detach().cpu().numpy()
            batch_print = batch.flatten().cpu().numpy().tolist()
            print(f"{phase}: N={N}, batch_idx={batch_idx}, batch={batch_print}, p={p}")
            # Goal here is to get parameter to increment once per batch.
            loss = self(batch).mean() + 0.1 * batch_idx

            key = f"{phase}/loss"
            if self.mode == Mode.CustomLog:
                self.logger.log_metrics(
                    {key: loss.detach().cpu().item()},
                    step=self.global_step,
                )
            elif self.mode == Mode.LogDefault:
                self.log(key, loss)
            elif self.mode == Mode.LogOnStep:
                self.log(key, loss, on_step=True, on_epoch=False, logger=True, prog_bar=True)
            elif self.mode == Mode.LogOnEpoch:
                self.log(key, loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)
            elif self.mode == Mode.LogOnBoth:
                self.log(key, loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
            self.stages[phase].append(loss)
            return loss

        def training_step(self, batch, batch_idx):
            return self._step("train", batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            return self._step("val", batch, batch_idx)

        def configure_optimizers(self):
            # Using naive SGD so that we can see parameter increment, hehehhehehehe.
            optimizer = torch.optim.SGD(self.parameters(), lr=1.0)
            return optimizer

    class PrintLogger(LightningLoggerBase):

        def __init__(self):
            super().__init__()

        @property
        def experiment(self):
            return "lightning_logs"

        def log_metrics(self, metrics, step):
            print(f"  log[step={step}]: {metrics}")

        def log_hyperparams(self, params):
            pass

        @property
        def name(self):
            return self.experiment

        @property
        def version(self):
            return 0

    @torch.no_grad()
    def create_dataset(count):
        # These values don't actually matter.
        xs = [torch.tensor([float(i)]) for i in range(count)]
        return xs

    def main():
        count = 4
        dataset = create_dataset(count)
        N = 2
        dataloader = DataLoader(dataset, batch_size=N, shuffle=False)
        num_batches = len(dataloader)

        for mode in Mode:
            logger = PrintLogger()
            # Recreate trainer each time.
            trainer = pl.Trainer(
                max_epochs=1,
                progress_bar_refresh_rate=0,
                logger=logger,
                flush_logs_every_n_steps=num_batches,
                weights_summary=None,
                num_sanity_val_steps=0,
            )
            print(f"mode={mode}")
            model = ConstantMultiply(mode)
            trainer.fit(
                model,
                train_dataloader=dataloader,
                val_dataloaders=dataloader,
            )

            assert model.stages["train"][0] == -1
            assert model.stages["train"][1] == -1.900
            assert model.stages["val"][0] == -3
            assert model.stages["val"][1] == -2.9000

            def mean(values):
                return torch.mean(torch.stack(values))

            # val_loss on step is not present in callback metrics as it doesn't make sense
            if mode == Mode.LogDefault:
                assert trainer.callback_metrics["train/loss"] == model.stages["train"][1]
                assert trainer.callback_metrics["val/loss"] == mean(model.stages["val"])

            elif mode == Mode.LogOnStep:
                assert trainer.callback_metrics["train/loss"] == model.stages["train"][1]

            elif mode == Mode.LogOnEpoch:
                assert trainer.callback_metrics["train/loss"] == mean(model.stages["train"])
                assert trainer.callback_metrics["val/loss"] == mean(model.stages["val"])

            elif mode == Mode.LogOnBoth:
                assert trainer.callback_metrics["train/loss"] == mean(model.stages["train"])
                assert trainer.callback_metrics["val/loss"] == mean(model.stages["val"])

                assert trainer.callback_metrics["val/loss_epoch"] == mean(model.stages["val"])
                assert trainer.callback_metrics["train/loss_epoch"] == mean(model.stages["train"])

                assert trainer.callback_metrics["train/loss_step"] == model.stages["train"][1]

    main()

tchaton commented 3 years ago

Closing this issue as it is working as expected; feel free to re-open it.