pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

loading large model not finished after 16 hours #66787

Open mijosch opened 2 years ago

mijosch commented 2 years ago

🐛 Bug

posted here ->

After saving a checkpoint of a large pytorch_forecasting model (10,100 targets, forecast from the previous 10,000 values), loading the checkpoint never finishes (I killed the load after 16 hours).

With 5 targets forecast from the previous 5 values: works.
With 50 targets forecast from the previous 50 values: works.
With 500 targets forecast from the previous 500 values: works.
With 5000 targets forecast from the previous 5000 values: still being tested.

At the moment I am trying to find the size at which it stops working.

To Reproduce

The model has 10,100 target values. It predicts 24 time intervals based on the 48 past time intervals, and the predictions are the quantiles (0.25, 0.5, 0.75).
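To give a sense of scale, the per-sample sizes implied by this configuration can be worked out directly. This is a rough sketch that assumes one output of shape (prediction steps, quantiles) per target, and that every target is also fed back as a past input (as `time_varying_unknown_reals=targets` does in the dummy script further down); the variable names are illustrative only.

```python
# Rough per-sample sizes for the configuration described above.
# Assumption: one (prediction steps x quantiles) output per target, and
# every target is also an encoder input over the 48 past intervals.
n_targets = 10_100
encoder_length = 48      # past time intervals fed to the model
prediction_length = 24   # future time intervals predicted
quantiles = (0.25, 0.5, 0.75)

outputs_per_target = prediction_length * len(quantiles)
outputs_per_sample = n_targets * outputs_per_target
encoder_values_per_sample = n_targets * encoder_length

print(f"outputs per target:          {outputs_per_target}")
print(f"outputs per sample:          {outputs_per_sample:,}")
print(f"encoder values per sample:   {encoder_values_per_sample:,}")
```

Each sample therefore carries hundreds of thousands of predicted values, which is why the model's state_dict grows so large as targets are added.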

Expected behavior

The model can be loaded in a reasonable amount of time.

Environment
- PyTorch 1.4.9
- PyTorch 1.9.1
- Python version: 3.8.10
- OS: Ubuntu 20.04
- CUDA/cuDNN version: -
- GPU models and configuration: -
- How you installed PyTorch (conda, pip, source): pip

Additional context

Loading and saving of small models works well.

Tests

Model is a temporal fusion transformer

Targets / columns | torch.load | setting state_dict | state_dict load | tft.load_from_checkpoint
2 / 2             | 0.032 s    | 0.0006 s           | 0.017 s         | 0.076 s
10 / 10           | 0.08 s     | 0.00119 s          | 0.068 s         | 0.195 s
100 / 100         | 0.61 s     | 0.0095 s           | 3.015 s         | 3.74 s
1000 / 1000       | 6.4 s      | 0.122 s            | 352.154 s       | 351.1 s
2000 / 2000       | 12.3 s     | 0.25 s             | ----            | ----
3000 / 3000       | 16.7 s     | 0.39 s             | ----            | ----
2 / 10100         | 17.12 s    | 0.415 s            | (not reported)  | ----

where "setting state_dict" is: tft.state_dict = ckpnt['state_dict'] (this only rebinds an attribute on the module and copies no weights, which is why it takes almost no time)

and "state_dict load" is: tft.load_state_dict(ckpnt['state_dict'])
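The two operations timed above do very different things. A minimal sketch with a plain `nn.Linear` (a stand-in, not the forecasting model) shows that assigning to `state_dict` merely shadows the method with an instance attribute, while `load_state_dict` actually copies every tensor into the model's parameters.

```python
import torch
from torch import nn

# Stand-in "checkpoint": every tensor filled with 7.0 so a real copy is easy
# to detect (default nn.Linear(4, 2) init stays within [-0.5, 0.5]).
saved = {k: torch.full_like(v, 7.0) for k, v in nn.Linear(4, 2).state_dict().items()}

# "setting state_dict": binds an instance attribute named state_dict that
# shadows the method; the parameters keep their random initial values.
shadowed = nn.Linear(4, 2)
shadowed.state_dict = saved
shadow_copied = bool(torch.all(shadowed.weight == 7.0))

# "state_dict load": copies each tensor into the matching parameter/buffer.
loaded = nn.Linear(4, 2)
result = loaded.load_state_dict(saved)
load_copied = bool(torch.all(loaded.weight == 7.0))

print("setting state_dict copied weights:", shadow_copied)  # False
print("load_state_dict copied weights:", load_copied)       # True
```

So only the `load_state_dict` column measures real loading work; the "setting state_dict" timing is essentially the cost of an attribute assignment.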

Even the 2-target, 2-column model already has 292 keys in its state_dict. The 10-target model has 668 keys, the 100-target model 4,898 keys, and the 1000-target model 47,198 keys.

After many checks, it appears that load_state_dict slows down quadratically as the model grows, which makes big models like the one I need unusable with PyTorch.
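One way to check the "quadratic" reading against the measurements above: if load time grows as keys^p, then between two measurements p = log(t1/t0) / log(k1/k0). The sketch below uses only the reported key counts and `load_state_dict` timings (the 2000/3000-column runs have no `load_state_dict` time, so they are left out).

```python
import math

# (state_dict keys, load_state_dict seconds) as reported above
measurements = [
    (292, 0.017),      # 2 targets
    (668, 0.068),      # 10 targets
    (4898, 3.015),     # 100 targets
    (47198, 352.154),  # 1000 targets
]


def scaling_exponents(points):
    """Slope of log(time) vs. log(keys) between consecutive measurements.
    An exponent near 1 means linear growth; near 2 means quadratic."""
    return [
        math.log(t1 / t0) / math.log(k1 / k0)
        for (k0, t0), (k1, t1) in zip(points, points[1:])
    ]


for (k0, _), (k1, _), p in zip(measurements, measurements[1:], scaling_exponents(measurements)):
    print(f"{k0:>6} -> {k1:>6} keys: exponent ~ {p:.2f}")
```

With these numbers the exponents come out at roughly 1.7, 1.9 and 2.1, i.e. close to quadratic for the larger models, which is consistent with the observation above.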

Is there a workaround, or can this be fixed?

cc @VitalyFedyunin @ngimel @mruberry

zou3519 commented 2 years ago

Do you have a self-contained script which we could use to reproduce the problem?

mijosch commented 2 years ago

I will provide a dummy project

mijosch commented 2 years ago

Here is the dummy code. The data is random, and training is kept as short as possible, just enough to produce a model that has been processed by PyTorch Lightning.

Otherwise, this is what my program looks like. I also added timing at the end to measure how long saving, torch.load, and load_state_dict take. To try different sizes, change the `values` variable; it is preset to the 10,100 columns in my table.


import pandas as pd
from random import random
from pytorch_forecasting.data.encoders import TorchNormalizer, MultiNormalizer
from pytorch_forecasting.metrics import MultiLoss
import torch

intervals = 255
values = 10100

data = {}
data["time_idx"] = []
data["group"] = []

for item in range(intervals):
    data["time_idx"].append(item)
    data["group"].append(1)

for item2 in range(values):
    data[str(item2)] = []
    for item in range(intervals):
        data[str(item2)].append(random())

print(len(data.keys()))
df = pd.DataFrame(data)
print("dataframe build")

# imports for training
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
# import dataset, network to train and metric to optimize
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss

# load data: this is pandas dataframe with at least a column for
# * the target (what you want to predict)
# * the timeseries ID (which should be a unique string to identify each timeseries)
# * the time of the observation (which should be a monotonically increasing integer)
data = df

# define the dataset, i.e. add metadata to pandas dataframe for the model to understand it
max_encoder_length = 48
max_prediction_length = 24
training_cutoff = df["time_idx"].max() - max_prediction_length  # day for cutoff

targets = list(df.columns[2:])
tn = []
for _ in range(values):
    tn.append(TorchNormalizer())

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",  # column name of time of observation
    target=targets,  # column names of the targets to predict
    group_ids=["group"],  # column name(s) for timeseries IDs
    max_encoder_length=max_encoder_length,  # how much history to use
    max_prediction_length=max_prediction_length,  # how far to predict into future
    # covariates static for a timeseries ID
    static_categoricals=[],
    static_reals=[],
    # covariates known and unknown in the future to inform prediction
    time_varying_known_categoricals=[],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=targets,
    target_normalizer=MultiNormalizer(tn),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
print("Timeseries build")
# create validation dataset using the same normalization techniques as for the training dataset
validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True)

# convert datasets to dataloaders for training
batch_size = 2
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

# create PyTorch Lightning Trainer with early stopping
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateMonitor()
trainer = pl.Trainer(
    max_epochs=100,
    gpus=None,  # run on CPU, if on multiple GPUs, use accelerator="ddp"
    gradient_clip_val=0.1,
    limit_train_batches=30,  # 30 batches per epoch
    fast_dev_run=True,
    #callbacks=[lr_logger, early_stop_callback],
    logger=TensorBoardLogger("lightning_logs")
)

ln = []
o_s = []
for _ in range(values):
    ln.append(QuantileLoss([0.25,0.5,0.75]))
    o_s.append(3)

# define network to train - the architecture is mostly inferred from the dataset, so that only a few hyperparameters have to be set by the user
tft = TemporalFusionTransformer.from_dataset(
    # dataset
    training,
    # architecture hyperparameters
    hidden_size=32,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=16,
    # loss metric to optimize
    loss=MultiLoss(ln),
    # logging frequency
    #log_interval=2,
    # optimizer parameters
    learning_rate=0.03,
    reduce_on_plateau_patience=4,
    output_size=o_s
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# fit the model on the data - redefine the model with the correct learning rate if necessary
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    # val_dataloaders=val_dataloader,
)
import time

# time saving the state_dict to disk
start = time.perf_counter()
with open("model.mdl", "wb") as f:
    torch.save(tft.state_dict(), f)
print("Save time", time.perf_counter() - start)

# time reading the checkpoint back from disk
start = time.perf_counter()
with open("model.mdl", "rb") as f:
    st_dict = torch.load(f)
print("torch load time", time.perf_counter() - start)

# time copying the loaded tensors into the model
start = time.perf_counter()
tft.load_state_dict(st_dict)
print("load_state_dict time", time.perf_counter() - start)
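To see whether the slowdown reproduces without pytorch_forecasting or Lightning, a pure-PyTorch timing harness may help. This is a sketch under assumptions: the model is a hypothetical `ModuleList` of small Linear + LayerNorm blocks (not the TFT), and only `load_state_dict` is timed. If the microseconds per key stay roughly constant as the block count doubles, plain `load_state_dict` scales linearly for this structure and the cost likely lies elsewhere; if they grow with size, the superlinear behavior reproduces in PyTorch itself.

```python
import time
from typing import Tuple

import torch
from torch import nn


def time_load_state_dict(n_blocks: int, hidden: int = 8) -> Tuple[int, float]:
    """Build a hypothetical model of n_blocks small sub-networks, then time
    load_state_dict on its own state_dict. Returns (number of keys, seconds)."""
    model = nn.ModuleList(
        nn.Sequential(nn.Linear(hidden, hidden), nn.LayerNorm(hidden))
        for _ in range(n_blocks)
    )
    state = model.state_dict()  # 4 keys per block: 2 Linear + 2 LayerNorm
    start = time.perf_counter()
    model.load_state_dict(state)
    return len(state), time.perf_counter() - start


if __name__ == "__main__":
    for n in (250, 500, 1000, 2000):
        keys, seconds = time_load_state_dict(n)
        print(f"{n:5d} blocks  {keys:6d} keys  {seconds:.4f} s  "
              f"{1e6 * seconds / keys:.2f} us/key")
```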

sanikolov commented 1 year ago

Sorry for butting in; I was looking for examples of TemporalFusionTransformer training and came across this issue. When I ran the code above, I got:

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`

It seems the documentation for TemporalFusionTransformer is no longer maintained; for example, this tutorial https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html is no longer usable. This issue has also been inactive for the last 18 months.