awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

enable_checkpointing = False results in MisconfigurationException #3203

Open mraapshockwavemedical opened 4 months ago


Description

I would like to train a TemporalFusionTransformerEstimator with trainer_kwargs["enable_checkpointing"] = False. However, torch\model\estimator.py creates a ModelCheckpoint on line 196 regardless and adds it to the list of callbacks on line 204. This results in the following error: lightning.fabric.utilities.exceptions.MisconfigurationException: Trainer was configured with enable_checkpointing=False but found ModelCheckpoint in callbacks list.
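The conflict comes from a consistency check that Lightning runs while constructing the Trainer. Below is a simplified, stdlib-only model of that check, paraphrased from my reading of Lightning's callback connector and not its actual code. ModelCheckpoint and MisconfigurationException are hypothetical stand-in classes, and estimator_callbacks is a hypothetical helper that mirrors the behaviour described above, where the estimator always includes its own checkpoint.

```python
from typing import List


class Callback:
    """Hypothetical stand-in for a Lightning callback."""


class ModelCheckpoint(Callback):
    """Hypothetical stand-in for Lightning's ModelCheckpoint callback."""


class MisconfigurationException(Exception):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


def configure_checkpoint_callbacks(
    callbacks: List[Callback], enable_checkpointing: bool
) -> List[Callback]:
    """Simplified model of the checkpoint check done in Trainer.__init__."""
    has_checkpoint = any(isinstance(cb, ModelCheckpoint) for cb in callbacks)
    if has_checkpoint:
        if not enable_checkpointing:
            # A checkpoint callback was supplied although checkpointing is off.
            raise MisconfigurationException(
                "Trainer was configured with `enable_checkpointing=False` "
                "but found `ModelCheckpoint` in callbacks list."
            )
    elif enable_checkpointing:
        # No checkpoint callback supplied: a default one is appended.
        callbacks.append(ModelCheckpoint())
    return callbacks


def estimator_callbacks(extra: List[Callback]) -> List[Callback]:
    """Hypothetical helper: the estimator always includes a ModelCheckpoint."""
    return [ModelCheckpoint()] + list(extra)


# Reproduce the reported combination: estimator callbacks + checkpointing off.
try:
    configure_checkpoint_callbacks(estimator_callbacks([]), enable_checkpointing=False)
except MisconfigurationException as err:
    print(f"MisconfigurationException: {err}")
```

In this model the error can only be avoided by leaving enable_checkpointing at its default (True) or by keeping ModelCheckpoint out of the list, which is what the estimator would have to do whenever trainer_kwargs disables checkpointing.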

I need to disable checkpointing because I want to run this on Snowflake, which does not allow writing to the filesystem. Any workaround ideas in the meantime would be highly appreciated.

Edit: Snowflake does allow me to write to /tmp/checkpoints, but there seems to be no way to set the dirpath of the ModelCheckpoint created in estimator.py on lines 195-198:

        monitor = "train_loss" if validation_data is None else "val_loss"
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=monitor, mode="min", verbose=True
        )
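Because this ModelCheckpoint is created without a dirpath, Lightning derives its output directory from the Trainer's settings. The sketch below is a simplified model of that resolution as I understand it for Lightning 2.x. resolve_checkpoint_dir is a hypothetical helper, and it assumes either the default logger (named "lightning_logs", saving under default_root_dir) or no logger at all. If this reading is right, passing trainer_kwargs={"default_root_dir": "/tmp/checkpoints"} might send both checkpoints and logs under /tmp/checkpoints; this is untested on Snowflake.

```python
import os
from typing import Optional


def resolve_checkpoint_dir(
    dirpath: Optional[str],
    default_root_dir: str,
    logger_version: Optional[int] = None,
) -> str:
    """Simplified model of where ModelCheckpoint files end up.

    Assumption: a default logger named "lightning_logs" whose save_dir is
    default_root_dir; logger_version=None models running without a logger.
    """
    if dirpath is not None:
        # An explicit dirpath always takes precedence.
        return dirpath
    if logger_version is not None:
        # With a logger: <save_dir>/<name>/version_<n>/checkpoints
        return os.path.join(
            default_root_dir, "lightning_logs", f"version_{logger_version}", "checkpoints"
        )
    # Without a logger: <default_root_dir>/checkpoints
    return os.path.join(default_root_dir, "checkpoints")


print(resolve_checkpoint_dir(None, "/tmp/checkpoints", logger_version=0))
```

An explicit dirpath would override all of this, which is why being unable to pass one through the estimator matters.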

To Reproduce

import pandas as pd
from gluonts.torch import TemporalFusionTransformerEstimator
from gluonts.dataset.pandas import PandasDataset

data = {
    "item_id": [1, 1, 1],
    "ts": ['2024-01-01', '2024-02-01', '2024-03-01'],
    "target": [1, 2, 3]
}
ds = PandasDataset.from_long_dataframe(
    pd.DataFrame(data), target="target", item_id="item_id", timestamp="ts"
)
# Ask Lightning not to use any ModelCheckpoint callback
trainer_kwargs = {'enable_checkpointing': False}
estimator = TemporalFusionTransformerEstimator(freq='M', prediction_length=1, trainer_kwargs=trainer_kwargs)
# Raises MisconfigurationException while the pl.Trainer is being constructed
predictor = estimator.train(training_data=ds)

Error message or code output

Traceback (most recent call last):
  File "C:\projects\test\venv\test.py", line 15, in <module>
    predictor = estimator.train(training_data=ds)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\gluonts\torch\model\estimator.py", line 246, in train
    return self.train_model(
           ^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\gluonts\torch\model\estimator.py", line 201, in train_model
    trainer = pl.Trainer(
              ^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\utilities\argparse.py", line 70, in insert_env_defaults
    return fn(self, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\trainer.py", line 431, in __init__
    self._callback_connector.on_trainer_init(
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\callback_connector.py", line 66, in on_trainer_init
    self._configure_checkpoint_callbacks(enable_checkpointing)
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\callback_connector.py", line 88, in _configure_checkpoint_callbacks
    raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: Trainer was configured with `enable_checkpointing=False` but found `ModelCheckpoint` in callbacks list.

Environment