Open mijosch opened 2 years ago
Do you have a self-contained script which we could use to reproduce the problem?
I will provide a dummy project
This is the dummy code. The data is random and the training is as short as possible, just enough to produce a model that has been processed by PyTorch Lightning.
Otherwise this is what my program looks like. I also added time measurements at the end to observe how long saving, torch.load and load_state_dict take. If you want to try different sizes, just change the values variable; it is preset to the 10100 columns I have in my table.
import pandas as pd
from random import random
from pytorch_forecasting.data.encoders import TorchNormalizer, MultiNormalizer
from pytorch_forecasting.metrics import MultiLoss
import torch
intervals = 255
values = 10100
data = {}
data["time_idx"] = []
data["group"] = []
for item in range(intervals):
    data["time_idx"].append(item)
    data["group"].append(1)
for item2 in range(values):
    data[str(item2)] = []
    for item in range(intervals):
        data[str(item2)].append(random())
print(len(data.keys()))
df = pd.DataFrame(data)
print("dataframe build")
# imports for training
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# import dataset, network to train and metric to optimize
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss
# load data: this is pandas dataframe with at least a column for
# * the target (what you want to predict)
# * the timeseries ID (which should be a unique string to identify each timeseries)
# * the time of the observation (which should be a monotonically increasing integer)
data = df
# define the dataset, i.e. add metadata to pandas dataframe for the model to understand it
max_encoder_length = 48
max_prediction_length = 24
training_cutoff = df["time_idx"].max() - max_prediction_length # day for cutoff
targets = list(df.columns[2:])
tn = []
for _ in range(values):
    tn.append(TorchNormalizer())
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",  # column name of time of observation
    target=targets,  # column name of target to predict
    group_ids=["group"],  # column name(s) for timeseries IDs
    max_encoder_length=max_encoder_length,  # how much history to use
    max_prediction_length=max_prediction_length,  # how far to predict into future
    # covariates static for a timeseries ID
    static_categoricals=[],
    static_reals=[],
    # covariates known and unknown in the future to inform prediction
    time_varying_known_categoricals=[],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=targets,
    target_normalizer=MultiNormalizer(tn),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
print("Timeseries build")
# create validation dataset using the same normalization techniques as for the training dataset
validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True)
# convert datasets to dataloaders for training
batch_size = 2
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
# create PyTorch Lightning Trainer with early stopping
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateMonitor()
trainer = pl.Trainer(
    max_epochs=100,
    gpus=None,  # run on CPU, if on multiple GPUs, use accelerator="ddp"
    gradient_clip_val=0.1,
    limit_train_batches=30,  # 30 batches per epoch
    fast_dev_run=True,
    # callbacks=[lr_logger, early_stop_callback],
    logger=TensorBoardLogger("lightning_logs"),
)
ln = []
o_s = []
for _ in range(values):
    ln.append(QuantileLoss([0.25, 0.5, 0.75]))
    o_s.append(3)
# define network to train - the architecture is mostly inferred from the dataset, so that only a few hyperparameters have to be set by the user
tft = TemporalFusionTransformer.from_dataset(
    # dataset
    training,
    # architecture hyperparameters
    hidden_size=32,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=16,
    # loss metric to optimize
    loss=MultiLoss(ln),
    # logging frequency
    # log_interval=2,
    # optimizer parameters
    learning_rate=0.03,
    reduce_on_plateau_patience=4,
    output_size=o_s,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
# fit the model on the data - redefine the model with the correct learning rate if necessary
trainer.fit(
    tft, train_dataloader=train_dataloader  # , val_dataloaders=val_dataloader,
)
import time
# time saving the state_dict to disk
start = time.time()
with open("model.mdl", "wb") as f:
    torch.save(tft.state_dict(), f)
stop = time.time()
print("Save time", stop - start)
# time reading the state_dict back from disk
start = time.time()
with open("model.mdl", "rb") as f:
    st_dict = torch.load(f)
stop = time.time()
print("torch load time", stop - start)
# time copying the loaded state_dict into the model
start = time.time()
tft.load_state_dict(st_dict)
stop = time.time()
print("load_state_dict time", stop - start)
Sorry for butting in, I was just looking for examples of TemporalFusionTransformer model training and came across this issue. When I ran the code above I got
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`
It seems like the documentation for TemporalFusionTransformer is not maintained any more. For example, this tutorial https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html is no longer usable. This issue has not been active for the last 18 months either.
🐛 Bug
posted here ->
After saving a checkpoint of a large pytorch_forecasting model (10'100 targets, forecast from the previous 10'000 values), loading the checkpoint takes forever (I killed the load after 16 hours).
With 5 targets forecast from the previous 5 values -> works
With 50 targets forecast from the previous 50 values -> works
With 500 targets forecast from the previous 500 values -> works
With 5000 targets forecast from the previous 5000 values -> test in progress
I am currently trying to find the size at which it stops working.
To Reproduce
The model has 10'100 target values; it predicts 24 time intervals from the 48 preceding intervals. The predictions are quantiles (0.25, 0.5, 0.75).
Expected behavior
The model can be loaded in a reasonable timeframe.
Environment
PyTorch Lightning version: 1.4.9
PyTorch version: 1.9.1
Python version: 3.8.10
OS (e.g., Linux): Ubuntu 20.04
CUDA/cuDNN version: -
GPU models and configuration: -
How you installed PyTorch (conda, pip, source): pip
Additional context
Loading and saving of small models works well.
Tests
Model is a temporal fusion transformer
2 targets from 2 columns: torch.load -> 0.032s, setting state_dict -> 0.0006s, state_dict load -> 0.017s, tft.load_from_checkpoint -> 0.076s
10 targets from 10 columns: torch.load -> 0.08s, setting state_dict -> 0.00119s, state_dict load -> 0.068s, tft.load_from_checkpoint -> 0.195s
100 targets from 100 columns: torch.load -> 0.61s, setting state_dict -> 0.0095s, state_dict load -> 3.015s, tft.load_from_checkpoint -> 3.74s
1000 targets from 1000 columns: torch.load -> 6.4s, setting state_dict -> 0.122s, state_dict load -> 352.154s, tft.load_from_checkpoint -> 351.1s
2000 targets from 2000 columns: torch.load -> 12.3s, setting state_dict -> 0.25s, state_dict load -> ----, tft.load_from_checkpoint -> ----
3000 targets from 3000 columns: torch.load -> 16.7s, setting state_dict -> 0.39s, state_dict load -> ----, tft.load_from_checkpoint -> ----
2 targets from 10100 columns: torch.load -> 17.12s, setting state_dict -> 0.415s, tft.load_from_checkpoint -> ----
where "setting state_dict" means: tft.state_dict = ckpnt['state_dict']
and "state_dict load" means: tft.load_state_dict(ckpnt['state_dict'])
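Note that plain attribute assignment on an `nn.Module` does not copy any weights: assigning a dict to `state_dict` only shadows the method on that instance, which would explain why the "setting state_dict" timings are so small. A minimal sketch, assuming a tiny `nn.Linear` as a stand-in for the TFT (the `model` and `donor` names are illustrative and not part of the original script):

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the TFT
donor = nn.Linear(4, 2)  # hypothetical stand-in for the checkpointed weights
ckpnt = {"state_dict": donor.state_dict()}

weights_before = model.weight.detach().clone()

# "setting state_dict": this only rebinds an instance attribute.
model.state_dict = ckpnt["state_dict"]
assignment_changed_weights = not torch.equal(model.weight, weights_before)

# Remove the shadowing attribute so the real method is reachable again.
del model.state_dict

# "state_dict load": this actually copies the tensors into the module.
model.load_state_dict(ckpnt["state_dict"])
weights_match_after_load = torch.equal(model.weight, donor.weight)

print("weights changed by assignment:", assignment_changed_weights)
print("weights match checkpoint after load_state_dict:", weights_match_after_load)
```

Only the `load_state_dict` call is therefore a meaningful measure of restoring the model.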
Number of keys in the state_dict:
2 targets from 2 columns: 292 keys
10 targets from 10 columns: 668 keys
100 targets from 100 columns: 4898 keys
1000 targets from 1000 columns: 47198 keys
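The key counts themselves grow linearly with the number of targets. A quick check using only the four counts reported here; the linear fit and the extrapolation to 10'100 targets are assumptions, not measurements:

```python
# Reported state_dict sizes: number of targets -> number of keys.
reported_keys = {2: 292, 10: 668, 100: 4898, 1000: 47198}

targets = sorted(reported_keys)
# Keys added per extra target between consecutive measurements.
slopes = [
    (reported_keys[b] - reported_keys[a]) / (b - a)
    for a, b in zip(targets, targets[1:])
]
keys_per_target = slopes[0]
base_keys = reported_keys[targets[0]] - keys_per_target * targets[0]


def estimated_keys(n_targets):
    """Linear extrapolation from the reported counts (an assumption)."""
    return int(base_keys + keys_per_target * n_targets)


print("keys added per target:", slopes)
print("estimated keys for 10100 targets:", estimated_keys(10100))
```

Every interval gives the same slope of 47 keys per target, so the size of the state_dict alone does not account for the slowdown.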
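After a lot of checks, it seems that load_state_dict slows down roughly quadratically with the number of keys: going from 100 to 1000 targets (about 10x the keys) increased its run time from about 3s to about 352s (over 100x). This makes big models like the one I need unusable with PyTorch.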
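The growth rate can be estimated from the timings listed under Tests. A rough sketch that computes the slope between consecutive measurements on a log-log scale, assuming the run time behaves like keys ** exponent (with only four data points this is an indication, not a proof):

```python
import math

# (state_dict keys, load_state_dict seconds) as reported in the tests above.
measurements = [(292, 0.017), (668, 0.068), (4898, 3.015), (47198, 352.154)]


def local_exponent(p, q):
    """Slope between two points on a log-log plot: time ~ keys ** exponent."""
    (k1, t1), (k2, t2) = p, q
    return math.log(t2 / t1) / math.log(k2 / k1)


pairs = list(zip(measurements, measurements[1:]))
exponents = [local_exponent(p, q) for p, q in pairs]
for ((k1, _), (k2, _)), e in zip(pairs, exponents):
    print(f"{k1} -> {k2} keys: exponent ~ {e:.2f}")
```

For the two largest models the estimated exponent is close to 2, which is consistent with quadratic behaviour; the smaller models give a lower value, plausibly because fixed overhead dominates there.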
Is there a workaround, or can this be fixed?
cc @VitalyFedyunin @ngimel @mruberry