unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] Unable to save TFTModel #1928

Closed: meteoDaniel closed this issue 7 months ago

meteoDaniel commented 1 year ago

Describe the bug: Dear darts team, I am trying to save a TFTModel with .save() and get the following error message:

In [1]: self.model.save(self.run.info.artifact_uri.replace("file://", "") + "/model.pt")
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 1
----> 1 self.model.save(self.run.info.artifact_uri.replace("file://", "") + "/model.pt")

File /usr/local/lib/python3.10/dist-packages/darts/models/forecasting/torch_forecasting_model.py:1518, in TorchForecastingModel.save(self, path)
   1516 # save the TorchForecastingModel (does not save the PyTorch LightningModule, and Trainer)
   1517 with open(path, "wb") as f_out:
-> 1518     torch.save(self, f_out)
   1520 # save the LightningModule checkpoint
   1521 path_ptl_ckpt = path + ".ckpt"

File /usr/local/lib/python3.10/dist-packages/torch/serialization.py:441, in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    439 if _use_new_zipfile_serialization:
    440     with _open_zipfile_writer(f) as opened_zipfile:
--> 441         _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    442         return
    443 else:

File /usr/local/lib/python3.10/dist-packages/torch/serialization.py:653, in _save(obj, zip_file, pickle_module, pickle_protocol)
    651 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    652 pickler.persistent_id = persistent_id
--> 653 pickler.dump(obj)
    654 data_value = data_buf.getvalue()
    655 zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'LayerSummary._register_hook.<locals>.hook'

Saving the inner module with torch.save(self.model.model) works fine, so I think the problem is related to the attributes stored in the TFTModel class.
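The error message points at pickling rather than at the TFT architecture itself: torch.save() serializes the whole object graph with pickle, and pickle cannot serialize a function defined inside another function (the `<locals>` part of the name). The stdlib sketch below reproduces that failure mode under a simplifying assumption, namely that the wrapper just holds a reference to such a nested hook function; `InnerModule`, `Wrapper` and `_register_hook` are hypothetical stand-ins, not darts or Lightning code. It also shows how pickling only the inner module can succeed while pickling the wrapper fails.

```python
import pickle


class InnerModule:
    """Hypothetical stand-in for the picklable network (model.model)."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]


def _register_hook(module):
    # A function defined inside another function is a "local object".
    # pickle stores functions by reference (module + qualified name),
    # and a nested function cannot be looked up that way.
    def hook(*args):
        return None

    return hook


class Wrapper:
    """Hypothetical stand-in for the forecasting model holding a stray hook."""

    def __init__(self):
        self.model = InnerModule()
        self.hook = _register_hook(self.model)


wrapper = Wrapper()

# Pickling only the inner module works.
inner_bytes = pickle.dumps(wrapper.model)

# Pickling the wrapper fails because of the attached local function.
# (Older Pythons raise AttributeError, newer ones PicklingError.)
try:
    pickle.dumps(wrapper)
    wrapper_error = None
except (AttributeError, pickle.PicklingError) as exc:
    wrapper_error = exc

print(type(wrapper_error).__name__, wrapper_error)
```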

To Reproduce: No stand-alone snippet to reproduce the issue is currently available.

Expected behavior: The model is saved without an error.

Additional context:

I removed the trainer from the TFTModel, but this did not fix the problem.

madtoinou commented 1 year ago

Hi @meteoDaniel,

Can you please share a code snippet (with a dummy dataset) so that we can reproduce your problem? Also, are you using custom hooks or lambda functions in the model constructor (in add_encoders, for example)?
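The question about lambda functions matters because the same pickle limitation applies to them: a lambda has no importable name, so if one ends up stored on the model, pickle-based saving would fail. The stdlib check below contrasts a lambda with an equivalent module-level function. The dictionary shape mirrors the custom-encoder example from the darts documentation, but here it is only plain data; `scale_year` and `is_picklable` are hypothetical helpers, and whether darts keeps such a function on the model is an assumption.

```python
import pickle


def scale_year(year):
    # Module-level function: pickle can store it by reference.
    return (year - 1950) / 50


encoder_with_lambda = {"custom": {"past": [lambda year: (year - 1950) / 50]}}
encoder_with_function = {"custom": {"past": [scale_year]}}


def is_picklable(obj):
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError, TypeError):
        return False


print(is_picklable(encoder_with_lambda))    # False
print(is_picklable(encoder_with_function))  # True
```

Replacing lambdas with named, module-level functions is the usual way to keep such configuration picklable.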

meteoDaniel commented 1 year ago

Dear @madtoinou, I will try my best to give you some insight into what we are doing. Unfortunately, it sounds like this is the first time you have heard of this issue.

Here are the pl_trainer_kwargs:

        _trainer_kwargs = dict(
            logger=MLFlowLogger(
                run_name=self.run.info.run_name,
                run_id=self.run.info.run_id,
                experiment_name=self.experiment_name,
                tracking_uri=self.tracking_uri,
                log_model=False,
                artifact_location=str(self.artifact_dir.parent),
            ),
            callbacks=[],
            max_epochs=5,
            accelerator="gpu",
            devices=-1,
            auto_select_gpus=True,
        )

Here is the .fit() call:

        _prepared_data = {
            "series": self.prepared_dataset.training_target,
            "past_covariates": self.prepared_dataset.training_past_features,
            "future_covariates": self.prepared_dataset.training_future_features,
            "val_series": self.prepared_dataset.validation_target,
            "val_past_covariates": self.prepared_dataset.validation_past_features,
            "val_future_covariates": self.prepared_dataset.validation_future_features,
        }

        model.fit(**_prepared_data)

These are all darts TimeSeries objects.

And here is the model initialization:

from darts.models import TFTModel
import torch
from torch import optim
from darts.utils.likelihood_models import QuantileRegression

def darts_tft() -> TFTModel:
    """Build the TFTModel used in our pipeline."""
    model = TFTModel(
        input_chunk_length=24,
        output_chunk_length=72,
        hidden_size=64,
        hidden_continuous_size=8,
        optimizer_cls=optim.AdamW,
        num_attention_heads=2,
        dropout=0.1,
        lstm_layers=4,
        batch_size=64,
        n_epochs=5,
        add_relative_index=False,
        add_encoders=None,
        likelihood=QuantileRegression(
            quantiles=parameter.quantiles  # `parameter` is our project config object (not shown)
        ), 
        loss_fn=torch.nn.MSELoss(),
        full_attention=True,
        categorical_embedding_sizes={'positional_index': 1},
    )
    return model

I really hope you have an idea of what is going wrong.

madtoinou commented 1 year ago

I did not manage to reproduce your problem: with the following snippet, everything worked as intended.

from darts.models import TFTModel
import torch
from torch import optim
from darts.utils.likelihood_models import QuantileRegression
from pytorch_lightning.loggers import MLFlowLogger

from darts.utils.timeseries_generation import sine_timeseries

_trainer_kwargs = dict(
            logger=MLFlowLogger(
                log_model=False,
            ),
            callbacks=[],
            max_epochs=5,
        )

training_target = sine_timeseries(length=100)
training_past_features = sine_timeseries(length=100)
training_future_features = sine_timeseries(length=200)

validation_target = sine_timeseries(length=100)
validation_past_features = sine_timeseries(length=100)
validation_future_features = sine_timeseries(length=200)

def darts_tft():
    """Build a TFTModel with the same settings as in the report."""
    model = TFTModel(
        input_chunk_length=24,
        output_chunk_length=72,
        hidden_size=64,
        hidden_continuous_size=8,
        optimizer_cls=optim.AdamW,
        num_attention_heads=2,
        dropout=0.1,
        lstm_layers=4,
        batch_size=64,
        n_epochs=5,
        add_relative_index=False,
        add_encoders=None,
        likelihood=QuantileRegression(
            quantiles=[0.25, 0.5, 0.75]
        ), 
        loss_fn=torch.nn.MSELoss(),
        full_attention=True,
        categorical_embedding_sizes={'positional_index': 1},
        pl_trainer_kwargs=_trainer_kwargs,
        save_checkpoints=True
    )
    return model

model = darts_tft()

_prepared_data = {
            "series": training_target,
            "past_covariates": training_past_features,
            "future_covariates": training_future_features,
            "val_series": validation_target,
            "val_past_covariates": validation_past_features,
            "val_future_covariates": validation_future_features,
        }

model.fit(**_prepared_data)
model.save("tft_model.pt")  # the step that failed in the original report
meteoDaniel commented 1 year ago

Dear @madtoinou, I have handed this issue over to my colleague. He will take a closer look and get back to you as soon as we have more insights.

FalcoWeich commented 1 year ago

Hi, "the colleague" here. It took me a day to find it. The solution was to convert the input data to float32 precision (it was float64 before). My guess is that there was a problem with the example_input_array, which probably got stuck in the ModelSummary hook. That is where the error comes from, because the hook is left hanging somewhere.

Below is a code snippet that analyzes the LayerSummary. Note that the out_sizes are all unknown.

import lightning.pytorch.utilities.model_summary as model_summary
ms = model_summary.ModelSummary(model.model)
print('?' in ms.out_sizes)
# True: the output sizes for all layers are unknown ('?'); what about the example input array?
# detach the hook to avoid the pickling error (this still did not fix it)
model_summary.LayerSummary(model.model).detach_hook()
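A simplified stand-in for the hook lifecycle may help explain why the detach attempt above did not help. The sketch assumes that Lightning's LayerSummary registers a nested hook function on the layer when it is constructed and removes that same hook in detach_hook(); `FakeModule` and `FakeLayerSummary` are hypothetical. Because each summary only tracks its own hook, creating a fresh summary and detaching it leaves a hook registered earlier still attached to the module, and that leftover local function is what pickle trips over.

```python
import pickle


class FakeModule:
    """Minimal module that stores forward hooks in a dict, like torch.nn.Module."""

    def __init__(self):
        self._forward_hooks = {}
        self._next_id = 0

    def register_forward_hook(self, fn):
        hook_id = self._next_id
        self._next_id += 1
        self._forward_hooks[hook_id] = fn
        return hook_id  # stands in for a removable handle

    def remove_hook(self, hook_id):
        self._forward_hooks.pop(hook_id, None)


class FakeLayerSummary:
    """Registers a nested hook on creation; detach_hook() removes only that one."""

    def __init__(self, module):
        self._module = module
        self._hook_id = self._register_hook()

    def _register_hook(self):
        def hook(_mod, _inp, _out):  # local object: not picklable
            return None

        return self._module.register_forward_hook(hook)

    def detach_hook(self):
        self._module.remove_hook(self._hook_id)


def can_pickle(obj):
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


module = FakeModule()
stale_summary = FakeLayerSummary(module)  # hook left behind, never detached

FakeLayerSummary(module).detach_hook()    # a new summary removes only its own hook
print(can_pickle(module))                 # False: the stale hook remains

stale_summary.detach_hook()               # detach via the summary that owns it
print(can_pickle(module))                 # True
```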

We create our TimeSeries objects with TimeSeries.from_times_and_values():

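import numpy as np
from darts import TimeSeries

# values are cast to float32 (previously float64); this fixed the saving error
TimeSeries.from_times_and_values(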
            values=np.float32(some_df_data.data[
                :, ~selected_cols
            ]),
            static_covariates=static_covariates,
            times=some_df_data.index.get_level_values(COLUMN_LEVEL1).rename(COLUMN_LEVEL0),
            columns=some_df_data.columns[
                ~selected_cols
            ],
            freq=target_freq,
            fill_missing_dates=True,
        )
madtoinou commented 1 year ago

Thank you @FalcoAlitiq for sharing your solution; glad that you managed to solve it.

Darts models store an input example in the train_sample attribute, but it is a tuple containing the target and the various covariates (hence not usable by ModelSummary). The method responsible for converting this tuple into an actual array/tensor is _get_batch_prediction().

Since this example_input_array is also required to export models to the ONNX format, I have a branch where the array concatenation happens in a dedicated method, which could be exposed as an example_input_array property. However, that branch has gone stale because of higher-priority tasks and time constraints. I'll try to work on it in the upcoming weeks; it might solve both problems.
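The conversion described above, a tuple of target and covariate arrays turned into a single model input and exposed as an example_input_array property, can be sketched in plain Python. This assumes, purely for illustration, that the input is built by concatenating the past target and the past covariates feature-wise at each time step and that missing covariates are None. `SketchModel`, `concat_features` and the tuple layout are hypothetical, not darts' actual _get_batch_prediction() code; example_input_array is the attribute name Lightning looks for in ModelSummary and ONNX export.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]  # rows = time steps, columns = features


def concat_features(parts: Sequence[Optional[Matrix]]) -> Matrix:
    """Concatenate matrices column-wise per time step, skipping None parts."""
    present = [p for p in parts if p is not None]
    if not present:
        raise ValueError("no inputs to concatenate")
    n_steps = len(present[0])
    if any(len(p) != n_steps for p in present):
        raise ValueError("all parts must have the same number of time steps")
    return [[v for p in present for v in p[t]] for t in range(n_steps)]


class SketchModel:
    def __init__(self, train_sample: Tuple[Optional[Matrix], ...]):
        # Hypothetical layout: (past_target, past_covariates, future_covariates)
        self.train_sample = train_sample

    @property
    def example_input_array(self) -> Matrix:
        past_target, past_covariates, _future = self.train_sample
        return concat_features([past_target, past_covariates])


model = SketchModel(
    train_sample=(
        [[1.0], [2.0], [3.0]],                 # past target: 3 steps, 1 component
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # past covariates: 2 components
        None,                                  # no future covariates
    )
)
print(model.example_input_array)
# [[1.0, 0.1, 0.2], [2.0, 0.3, 0.4], [3.0, 0.5, 0.6]]
```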

FalcoWeich commented 1 year ago

Since the problem appeared to be connected to precision, it might be worth investigating this further. I do not have the time, but I suspect that some combinations of different precisions in the covariates (maybe also mixed across TimeSeries objects?) could be interesting. Maybe a new model parameter "dtype"/"data_precision" could be added that controls the precision of all covariates and of the input example accordingly. That would be very useful (IMHO)!

madtoinou commented 1 year ago

As a way to automatically convert the series passed to a model? I am not sure I understand what you mean by "combination of different precisions".

In general, the dtype of the model is set based on the dtype of the series used during training. Did you move the model from GPU to CPU, or did you just use different datasets at training and inference time?

FalcoWeich commented 1 year ago

I am still not entirely sure, but as far as I remember, our TimeSeries contained only float64 data and we trained our model as usual. Only saving failed, because some hook got stuck somewhere, as the error message suggests.

My suggestion is a parameter, passed when initializing the TFTModel (or other forecasting models), that sets the floating-point precision and checks that every internal tensor, as well as all covariates, uses the same precision.
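The proposed check could look roughly like the hypothetical sketch below. It uses the stdlib array module as a stand-in for series or tensor dtypes: typecode 'f' is a C float (32-bit on typical platforms) and 'd' a C double (64-bit). `DTYPE_NAMES` and `check_precision` are illustrative names, not an existing darts parameter or API.

```python
from array import array
from typing import Dict

# Typecode -> readable name ('f' is a C float, 'd' a C double).
DTYPE_NAMES = {"f": "float32", "d": "float64"}


def check_precision(inputs: Dict[str, array], dtype: str) -> None:
    """Raise if any named input does not use the requested precision."""
    expected = {name: code for code, name in DTYPE_NAMES.items()}[dtype]
    mismatched = {
        name: DTYPE_NAMES.get(values.typecode, values.typecode)
        for name, values in inputs.items()
        if values.typecode != expected
    }
    if mismatched:
        raise TypeError(f"expected {dtype} for all inputs, got {mismatched}")


inputs = {
    "series": array("f", [1.0, 2.0, 3.0]),
    "past_covariates": array("d", [0.5, 0.6, 0.7]),  # still float64
}

try:
    check_precision(inputs, dtype="float32")
    mismatch_message = None
except TypeError as exc:
    mismatch_message = str(exc)
print(mismatch_message)
# expected float32 for all inputs, got {'past_covariates': 'float64'}

# Casting the offending input makes the check pass silently.
inputs["past_covariates"] = array("f", inputs["past_covariates"].tolist())
check_precision(inputs, dtype="float32")
```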

dennisbader commented 7 months ago

The dtype is fixed once the model has been created/trained for the first time. Dtypes must match before and after saving; the models raise an error if they do not.
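The behaviour described above, with precision locked in on first training and checked afterwards, can be pictured with a small hypothetical class. `DtypeLockedModel`, its `fit`/`predict` methods and the error text are illustrative only; the real check lives inside darts. Input precision is again represented by stdlib array typecodes, and pickle stands in for save/load.

```python
import pickle
from array import array
from typing import Optional


class DtypeLockedModel:
    """Records the input typecode on first fit and rejects mismatches afterwards."""

    def __init__(self):
        self.dtype: Optional[str] = None

    def _check(self, series: array) -> None:
        if self.dtype is None:
            self.dtype = series.typecode  # fixed on first training
        elif series.typecode != self.dtype:
            raise ValueError(
                f"model was trained with typecode {self.dtype!r}, "
                f"got {series.typecode!r}"
            )

    def fit(self, series: array) -> "DtypeLockedModel":
        self._check(series)
        return self

    def predict(self, series: array) -> None:
        self._check(series)


model = DtypeLockedModel().fit(array("f", [1.0, 2.0, 3.0]))

# The recorded dtype survives a save/load round trip.
restored = pickle.loads(pickle.dumps(model))
restored.predict(array("f", [4.0, 5.0]))  # same precision: accepted

try:
    restored.predict(array("d", [4.0, 5.0]))  # float64 after float32 training
    mismatch_message = None
except ValueError as exc:
    mismatch_message = str(exc)
print(mismatch_message)
# model was trained with typecode 'f', got 'd'
```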