sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Huge Memory Consumption for TFT & Small Dataset #1322

Open vilorel opened 1 year ago

vilorel commented 1 year ago

Expected behavior

I followed this guide, which is mostly the same as yours except for a few changes in the trainer:

I then experimented with my own dataset and faced similar issues.

Actual behavior

To run the example below, I again need a server with 512 GB of RAM. RAM consumption rises to about 74.5% and stays there throughout training. As you can see, the dataset is not that large. What if I wanted to train on 90M records or more? The model is not that large either, IMHO. Am I missing something?

Code to reproduce the problem

I then tried my own example and test dataset to give you more concrete numbers:

[172801 rows x 9 columns]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 172801 entries, 0 to 172800
Data columns (total 9 columns):
 #   Column         Non-Null Count   Dtype   
---  ------         --------------   -----   
 0   time_idx       172801 non-null  int32   
 1   dow            172801 non-null  int8    
 2   hod            172801 non-null  int8    
 3   item           172801 non-null  category
 4   m0             172801 non-null  float32 
 5   m1             172801 non-null  float32 
 6   m2             172801 non-null  float32 
 7   m3             172801 non-null  float32 
 8   y              172801 non-null  float32 
dtypes: category(1), float32(5), int32(1), int8(2)
memory usage: 4.4 MB
ML Data Size: 172801
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/user/.local/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:197: UserWarning: Attribute 'loss' is an instance of nn.Module and is already saved during checkpointing. It is recommended to ignore them using self.save_hyperparameters(ignore=['loss']).
  rank_zero_warn(

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 1     
3  | prescalers                         | ModuleDict                      | 1.5 K 
4  | static_variable_selection          | VariableSelectionNetwork        | 104 K 
5  | encoder_variable_selection         | VariableSelectionNetwork        | 211 K 
6  | decoder_variable_selection         | VariableSelectionNetwork        | 104 K 
7  | static_context_variable_selection  | GatedResidualNetwork            | 66.3 K
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 66.3 K
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 66.3 K
10 | static_context_enrichment          | GatedResidualNetwork            | 66.3 K
11 | lstm_encoder                       | LSTM                            | 132 K 
12 | lstm_decoder                       | LSTM                            | 132 K 
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 33.0 K
14 | post_lstm_add_norm_encoder         | AddNorm                         | 256   
15 | static_enrichment                  | GatedResidualNetwork            | 82.7 K
16 | multihead_attn                     | InterpretableMultiHeadAttention | 41.2 K
17 | post_attn_gate_norm                | GateAddNorm                     | 33.3 K
18 | pos_wise_ff                        | GatedResidualNetwork            | 66.3 K
19 | pre_output_gate_norm               | GateAddNorm                     | 33.3 K
20 | output_layer                       | Linear                          | 903   
----------------------------------------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.961     Total estimated model params size (MB)
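To put the numbers above in perspective, here is a rough back-of-the-envelope comparison of the raw data and parameter footprint with the observed RAM. The byte counts come from the `df.info()` dtypes and the Lightning model summary. Two assumptions are labeled in the code: the `item` category has few enough levels that pandas stores its codes as int8, and "512GB" means 512 × 10^9 bytes. This is only arithmetic on the reported figures; it does not show where the memory actually goes.

```python
# Back-of-the-envelope check of the figures reported in this issue.
# Dtype sizes in bytes, read off the df.info() output above.
BYTES_PER_COLUMN = {
    "time_idx": 4,                         # int32
    "dow": 1,                              # int8
    "hod": 1,                              # int8
    "item": 1,                             # category codes, assumed int8 (few levels)
    "m0": 4, "m1": 4, "m2": 4, "m3": 4,    # float32
    "y": 4,                                # float32
}
N_ROWS = 172_801

bytes_per_row = sum(BYTES_PER_COLUMN.values())
data_bytes = N_ROWS * bytes_per_row
data_mib = data_bytes / 2**20  # pandas reports memory usage in 1024-based units

# Lightning's "estimated model params size" assumes 4 bytes (float32) per parameter.
n_params = 4.961e6 / 4

# Observed: about 74.5% of a "512GB" server, assumed to mean 512 * 10**9 bytes.
ram_used_bytes = 0.745 * 512e9

print(f"raw data:         {data_mib:.1f} MiB ({bytes_per_row} bytes/row)")
print(f"parameters:       ~{n_params / 1e6:.2f} M")
print(f"RAM observed:     {ram_used_bytes / 1e9:.1f} GB")
print(f"RAM / raw data:   ~{ram_used_bytes / data_bytes:,.0f}x")
print(f"RAM per data row: ~{ram_used_bytes / N_ROWS / 1e6:.1f} MB")
```

Under these assumptions the raw frame comes out at about 4.4 MiB (matching `df.info()`) and the model at roughly 1.24 M parameters, while the observed ~381 GB is roughly 82,000 times the raw data, or about 2.2 MB of RAM per input row.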

The configuration I used in this example is the following:

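import lightning.pytorch as pl
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import MAE, MAPE, RMSE, QuantileLoss
from torchmetrics import MeanSquaredError  # assumed source of MeanSquaredError

# training, model_path, early_stop_callback, train_dataloader and
# val_dataloader are defined elsewhere (not shown here).
lr_logger = LearningRateMonitor()
logger = TensorBoardLogger(model_path)

trainer = pl.Trainer(
    max_epochs=45,
    accelerator='cpu',
    devices=1,
    enable_model_summary=True,
    gradient_clip_val=0.1,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,  # 0.001
    hidden_size=128,
    hidden_continuous_size=64,
    attention_head_size=4,
    dropout=0.1,
    output_size=7,
    loss=QuantileLoss(),
    logging_metrics=[MAE(), MeanSquaredError(), RMSE(), MAPE()],
    log_interval=10,
    reduce_on_plateau_patience=4)

trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader)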

If it makes any difference, the number of workers for both dataloaders is set to 0 via the num_workers parameter.

JanuszJakubiec commented 1 year ago

I have a similar problem and I'm looking for a solution.

sayanb-7c6 commented 1 year ago

Are you using DDP mode? I think in that case log_interval=10 needs to be changed to log_interval=-1. If I recall correctly, there's a memory leak. See issue #486.

furkanbr commented 8 months ago

I am having the same problem. I tried log_interval=-1, but it did not make any difference. I am training on CPU with 51 GB of RAM on Google Colab, and memory usage rises with every batch iteration until it crashes. Is there any other solution to this problem?
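Memory that rises with every batch until a crash is a different symptom from the plateau in the original report (about 74.5% and then flat). A minimal, stdlib-only sketch of how one might tell the two apart: record memory after each step and look at the average growth per step. `leaky_step` and `steady_step` are hypothetical stand-ins, not pytorch-forecasting code; `leaky_step` keeps a reference to every output, which is one way per-batch logging can leak. Note that `tracemalloc` only traces allocations made through Python's own allocator, so tensor storage allocated by PyTorch's native code may not show up; in a real training run, process RSS (for example `psutil.Process().memory_info().rss`) is the more reliable signal.

```python
import tracemalloc
from typing import Callable, List


def traced_memory_per_step(step: Callable[[int], None], n_steps: int) -> List[int]:
    """Call step(i) n_steps times and record Python-traced memory (bytes) after each call."""
    tracemalloc.start()
    try:
        samples: List[int] = []
        for i in range(n_steps):
            step(i)
            current, _peak = tracemalloc.get_traced_memory()
            samples.append(current)
        return samples
    finally:
        tracemalloc.stop()


def growth_per_step(samples: List[int]) -> float:
    """Average change in traced memory between consecutive steps."""
    if len(samples) < 2:
        return 0.0
    return (samples[-1] - samples[0]) / (len(samples) - 1)


# Hypothetical stand-ins for a training step (not pytorch-forecasting code).
_kept = []


def leaky_step(i: int) -> None:
    # Keeps a reference to every "batch output", like an ever-growing log buffer.
    _kept.append(bytearray(100_000))


def steady_step(i: int) -> None:
    # Allocates a buffer of the same size but releases it before returning.
    buf = bytearray(100_000)
    del buf


if __name__ == "__main__":
    leaky = growth_per_step(traced_memory_per_step(leaky_step, 50))
    steady = growth_per_step(traced_memory_per_step(steady_step, 50))
    print(f"leaky:  {leaky:,.0f} bytes/step")
    print(f"steady: {steady:,.0f} bytes/step")
```

A clearly positive growth per step that persists over many steps suggests references are being retained somewhere; a flat profile after warm-up matches the plateau described in the original post.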

el-analista commented 6 months ago

My dataset has 7.4M rows and 8 columns, none of them categorical. With the following model, I am at 210 GB of RAM after a single epoch @_@

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=4,
    enable_model_summary=True,
    strategy="auto",
    gradient_clip_val=0.1,
    # limit_train_batches=50,  # comment in for training, running validation every 30 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    default_root_dir="lightning_checkpoints/",
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03, # 0.03, 0.64
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,  # uncomment for learning rate finder and otherwise, e.g. to 10 for logging every 10 batches
    optimizer="Ranger",
    reduce_on_plateau_patience=3,
)
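The same kind of arithmetic for the 7.4M-row dataset above shows how little of the reported 210 GB the raw data could explain. Assumptions, labeled in the code: the 8 columns are stored as float64, and the 4-GPU run uses one process per device, each holding a full copy of the data (the comment does not confirm either).

```python
# Back-of-the-envelope size of the raw data in the comment above.
N_ROWS = 7_400_000
N_COLS = 8
BYTES_PER_VALUE = 8   # float64, assumed
N_PROCESSES = 4       # devices=4, assumed one process (and one data copy) per device

one_copy_gb = N_ROWS * N_COLS * BYTES_PER_VALUE / 1e9
all_copies_gb = one_copy_gb * N_PROCESSES
observed_gb = 210     # reported after one epoch

print(f"one copy of raw data: {one_copy_gb:.2f} GB")
print(f"{N_PROCESSES} copies:             {all_copies_gb:.2f} GB")
print(f"observed / all copies: ~{observed_gb / all_copies_gb:.0f}x")
```

Even with four full copies, the raw values come to under 2 GB, so on these assumptions roughly 99% of the observed RAM would have to come from something other than the input table itself.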
IlIlllIIllIIlll commented 1 month ago

This looks like it comes down to how the TimeSeriesDataSet was made:

https://github.com/jdb78/pytorch-forecasting/issues/648#issuecomment-1999090208