Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`EarlyStopping` triggered before `min_epochs` corrupts WandB logs, dropping previously logged values. #18251

Closed thesofakillers closed 1 year ago

thesofakillers commented 1 year ago

Bug description

This morning I woke up to a very weird result.

I have a PL module/trainer/datamodule setup (details under "More info" below).

When I started training last night, everything was working fine: my training metrics were being logged every 50 steps and the loss was decreasing.

This morning I woke up to see that

  1. `EarlyStopping` had been triggered before `min_epochs`. OK, totally fine; training continued until `min_epochs` was reached.
  2. All the logs from before the `EarlyStopping` trigger appeared "corrupted": WandB no longer showed training metrics every 50 steps but every 2050 steps (!) instead. It seemed several logging steps had been dropped.

[W&B chart, 07_08_2023 10:50:45]

The high-frequency logging visible at the end of training should appear everywhere; it was there last night at the beginning of training, before early stopping was triggered.

What version are you seeing the problem on?

v1.9

How to reproduce the bug

I have not been able to reproduce this bug with a smaller dataset/setup. Advice on how to try to reproduce it is welcome.
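
For reference, a minimal sketch of the kind of small setup one could start from (toy module, random data; the wandb project name and `patience=1` are arbitrary, the latter just to make `EarlyStopping` trip before `min_epochs`):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class ToyModule(pl.LightningModule):
        """Tiny stand-in for the real module; logs under the same metric names."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self.log("visual/train_loss", loss, batch_size=x.size(0))
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self.log("textual/val_loss", loss, batch_size=x.size(0))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)

    def loader(n=2000):
        # random regression data: 250 batches per epoch at batch_size=8
        return DataLoader(TensorDataset(torch.randn(n, 32), torch.randn(n, 1)), batch_size=8)

    if __name__ == "__main__":
        logger = pl.loggers.WandbLogger(project="early-stopping-repro", log_model=False)
        # patience=1 on non-improving random data should trip the callback well before min_epochs
        early_stopping = pl.callbacks.EarlyStopping(
            monitor="textual/val_loss", mode="min", strict=False, patience=1
        )
        trainer = pl.Trainer(
            min_epochs=5,
            max_epochs=10,
            logger=logger,
            callbacks=[early_stopping],
            log_every_n_steps=50,
            val_check_interval=0.5,
        )
        trainer.fit(ToyModule(), loader(), loader())

The idea is that with `patience=1` and random data, `EarlyStopping` signals a stop within the first few validation checks, well before `min_epochs=5`, which is the conflict described above.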

Environment

    - black:             23.3.0
    - bleach:            6.0.0
    - blosc:             1.11.1
    - calvin:            0.0.2
    - calvin-env:        0.0.1
    - certifi:           2023.5.7
    - cffi:              1.15.1
    - charset-normalizer: 3.1.0
    - click:             8.1.3
    - cloudpickle:       2.2.1
    - cmake:             3.18.4
    - colorama:          0.4.6
    - colorlog:          6.7.0
    - comm:              0.1.3
    - contourpy:         1.1.0
    - cycler:            0.11.0
    - datasets:          2.13.1
    - debugpy:           1.6.7
    - decorator:         4.4.2
    - defusedxml:        0.7.1
    - diffusers:         0.14.0
    - dill:              0.3.6
    - docker-pycreds:    0.4.0
    - exceptiongroup:    1.1.1
    - executing:         1.2.0
    - farama-notifications: 0.0.4
    - fastjsonschema:    2.17.1
    - filelock:          3.12.2
    - fonttools:         4.40.0

    - pluggy:            1.2.0
    - procgen:           0.10.7
    - proglog:           0.1.10
    - prometheus-client: 0.17.0
    - prompt-toolkit:    3.0.38
    - protobuf:          4.23.3
    - psutil:            5.9.5
    - ptyprocess:        0.7.0
    - pure-eval:         0.2.2
    - pyarrow:           12.0.1
    - pybullet:          3.2.5
    - pycollada:         0.7.2
    - pycparser:         2.21
    - pygame:            2.5.0
    - pyglet:            2.0.8
    - pygments:          2.15.1
    - pyhash:            0.9.3
    - pyopengl:          3.1.0
    - pyparsing:         3.1.0
    - pyrender:          0.1.45
    - pyrsistent:        0.19.3
    - pytest:            7.4.0
    - pytest-lazy-fixture: 0.6.3
    - pytest-profiling:  1.7.0
    - python-dateutil:   2.8.2
    - python-json-logger: 2.0.7
    - pytoolconfig:      1.2.5
    - pytorch-lightning: 1.9.5
    - pytz:              2023.3
    - pyyaml:
    - triton:            2.0.0
    - typing-extensions: 4.7.0
    - tzdata:            2023.3
    - urdfpy:            0.0.22
    - uri-template:      1.3.0
    - urllib3:           2.0.3
    - wandb:             0.15.4
    - wcwidth:           0.2.6
    - webcolors:         1.13
    - webencodings:      0.5.1
    - websocket-client:  1.6.1
    - wheel:             0.40.0
    - widgetsnbextension: 4.0.7
    - xxhash:            3.2.0
    - y-py:              0.5.9
    - yarl:              1.9.2
    - ypy-websocket:     0.8.2
    - zipp:              3.15.0

More info

Here is my setup:

    import pytorch_lightning as pl

    # `args` is the run configuration (e.g. parsed CLI arguments) passed as the wandb config
    logger = pl.loggers.WandbLogger(
        config=args,
        log_model=False,
    )
    early_stopping = pl.callbacks.early_stopping.EarlyStopping(
        monitor="textual/val_loss", mode="min", strict=False
    )
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor='textual/val_loss', save_last=True
    )

    trainer = pl.Trainer(
        min_epochs=5,
        max_epochs=10,
        deterministic=True,
        logger=logger,
        callbacks=[early_stopping, checkpoint_callback],
        log_every_n_steps=50,
        val_check_interval=0.5,
        check_val_every_n_epoch=1,
        precision=16,
    )

In my PL `LightningModule`, I call `self.log('visual/train_loss', loss, batch_size=batch_size)` in the `training_step`, and log the validation loss the same way in the `validation_step`. Not really sure what I am doing wrong. I ran the exact same run again (same seed and everything) and verified that the logger does log every 50 steps at the beginning, which suggests those points get dropped once `EarlyStopping` is triggered.

thesofakillers commented 1 year ago

Note: zooming into the plot reveals that logging did occur every 50 steps during the first part of training; the points are simply not plotted when zoomed out, for some reason. This seems to be a wandb bug. Still, it's weird that it only shows up when `min_epochs` and early stopping conflict.

Before zooming in, on 2 identical runs (the brown one has not hit the conflict yet):

[screenshot: zoomed-out view of the two runs]

After zooming in:

[screenshot: zoomed-in view of the same region]

They are indeed identical.
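
One way to double-check from the API side that no points were actually dropped (the run path below is a placeholder): pull the unsampled history with `scan_history` and look at the spacing between logged steps, rather than trusting the downsampled zoomed-out plot.

    import wandb

    api = wandb.Api()
    run = api.run("my-entity/my-project/abc123")  # placeholder run path

    # collect the global step of every logged training-loss point
    steps = [
        row["_step"]
        for row in run.scan_history(keys=["visual/train_loss", "_step"])
        if row.get("visual/train_loss") is not None
    ]
    gaps = [b - a for a, b in zip(steps, steps[1:])]
    print(f"{len(steps)} points, max gap between logged steps: {max(gaps) if gaps else 'n/a'}")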

thesofakillers commented 1 year ago

Closing this and opening an issue in the wandb repo. Apologies.