jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
3.75k stars 599 forks source link

Memory leak during inference with TemporalFusionTransformer #1268

Open Taxel opened 1 year ago

Taxel commented 1 year ago

My issue

I am trying to run inference on our products one product at a time (so that the results can be merged with the manual predictions stored in a database). For this I am using the code below, which already processes the products in batches of 5000; after each batch the model should be destroyed (this was my attempt to solve the memory problem, but to no avail). However, the memory usage keeps increasing over the lifetime of the program until either the total number of products is reached (`sql_utils.load_top_n_artnums(29000)` specifies the number of products; I can't really go higher than that right now, because otherwise the second case happens) or the 50 GB of RAM I have assigned to WSL runs out and the process is killed for using too much memory.

Unfortunately, I cannot publish the model or the source data, but maybe someone can see where I am going terribly wrong. The model does not seem to release the allocated RAM, even after `predict_batch()` has returned and the model should have gone out of scope.

Code to reproduce the problem

from datetime import datetime, timedelta
from pytorch_forecasting.models import TemporalFusionTransformer
import pandas as pd
import warnings
import sqlalchemy
from tqdm.auto import tqdm
import sql_utils
import gc
import numpy as np
import os
import psutil
from memory_profiler import profile

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

# Adapted, with minor changes, from the pytorch-forecasting "Stallion" tutorial: https://github.com/jdb78/pytorch-forecasting/blob/master/docs/source/tutorials/stallion.ipynb

def load_nn(checkpoint_path=None):
    """Load the trained TemporalFusionTransformer from a Lightning checkpoint.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file. When ``None`` (the
            default), the checkpoint this script was written against is used,
            so existing zero-argument callers behave as before.

    Returns:
        The model returned by ``TemporalFusionTransformer.load_from_checkpoint``.
    """
    if checkpoint_path is None:
        checkpoint_path = (
            "lightning_logs/lightning_logs/version_28/checkpoints/"
            "epoch=50-step=25500.ckpt"
        )
    return TemporalFusionTransformer.load_from_checkpoint(checkpoint_path)

def load_all_data(path="nn_all_data.csv"):
    """Load the inference feature table and add the derived feature columns.

    Reads ``path`` as CSV, casts ``main_category``, ``parent_category`` and
    ``month`` to categoricals, and adds:

    * ``avg_amount_by_sku`` - mean ``amount`` per (``time_idx``, ``artnum``),
    * ``avg_amount`` - mean ``amount`` per ``time_idx``,
    * ``amount_log`` - ``log(max(amount, 0) + 1)``.

    ``date`` is converted to datetime, and any row containing a missing value
    is dropped at the end.

    Args:
        path: CSV file to read. Defaults to ``nn_all_data.csv`` in the current
            working directory (the original hard-coded location).

    Returns:
        The prepared ``pandas.DataFrame``.
    """
    # Identifier-like columns are read as strings rather than numbers.
    # ``parse_dates=True`` was dropped: without ``index_col`` it only targets
    # the (default) index and has no effect; ``date`` is parsed explicitly below.
    data = pd.read_csv(
        path,
        dtype={
            "main_category": str,
            "parent_category": str,
            "artnum": str,
            "month": str,
        },
    )
    for column in ("main_category", "parent_category", "month"):
        data[column] = data[column].astype("category")

    data["avg_amount_by_sku"] = data.groupby(
        ["time_idx", "artnum"], observed=True
    )["amount"].transform("mean")
    data["avg_amount"] = data.groupby(["time_idx"], observed=True)[
        "amount"
    ].transform("mean")
    data["date"] = pd.to_datetime(data["date"])
    # Negative amounts are clipped to 0 before the log so the result is finite.
    data["amount_log"] = np.log(data["amount"].clip(lower=0) + 1)
    data.dropna(inplace=True)
    return data

def create_prediction_data(data, artnum, FIRST_DAY_LAST_MONTH, max_encoder_length=36, min_encoder_length=4, max_prediction_length=9):
    """Build the encoder + decoder frame used to predict a single article.

    The encoder part is the last ``max_encoder_length`` rows of ``data`` whose
    ``artnum`` equals ``artnum``. The decoder part repeats the final encoder
    row ``max_prediction_length`` times. Each repeat has its ``date`` advanced
    by 1..``max_prediction_length`` month starts, and ``time_idx`` and
    ``month`` are recomputed from that date.

    Raises:
        ValueError: if fewer than ``min_encoder_length`` rows exist for the
            article, or none of the encoder rows is dated ``FIRST_DAY_LAST_MONTH``.
    """
    # Encoder: most recent rows for this article. ``tail`` relies on row
    # order, so this assumes ``data`` is sorted by date per article (confirm).
    history = data.loc[data["artnum"] == artnum].tail(max_encoder_length).copy()
    if len(history) < min_encoder_length:
        raise ValueError(
            f"Not enough data to create encoder data. Need at least {min_encoder_length} months of data.")

    last_month = pd.to_datetime(FIRST_DAY_LAST_MONTH)
    if last_month not in history["date"].values:
        raise ValueError(
            f"Data for article does not contain last month {FIRST_DAY_LAST_MONTH}.")

    # Decoder: clone the newest encoder row once per forecast step, shifting
    # its date forward by ``step`` month starts.
    template = history.tail(1)
    future_rows = []
    for step in range(1, max_prediction_length + 1):
        shifted_date = template["date"] + pd.offsets.MonthBegin(step)
        future_rows.append(template.assign(date=shifted_date))
    future = pd.concat(future_rows, ignore_index=True)

    future_dates = future["date"].dt
    # Month counter anchored at 2005 (January 2005 -> 1); presumably the same
    # scheme that produced ``time_idx`` in ``data`` - confirm.
    future["time_idx"] = (future_dates.year - 2005) * 12 + future_dates.month
    # Categorical labels must be strings.
    future["month"] = future_dates.month.astype(str).astype("category")

    return pd.concat([history, future], ignore_index=True)

def predict_for_artnum(artnum, data, best_tft, manual_predictions, FIRST_DAY_THIS_MONTH, FIRST_DAY_LAST_MONTH):
    """Predict quantiles for one article and attach its manual predictions.

    Args:
        artnum: Article number to predict.
        data: Full feature table (as produced by ``load_all_data``).
        best_tft: Model exposing ``predict(df, mode="quantiles")``.
        manual_predictions: Passed through to ``sql_utils.get_manual_predictions``.
        FIRST_DAY_THIS_MONTH: First forecast date; subsequent rows are
            consecutive month starts.
        FIRST_DAY_LAST_MONTH: Month that must be present in the article's
            history (checked by ``create_prediction_data``).

    Returns:
        ``(True, df)`` on success. ``df`` is indexed by ``(artnum, date)`` and
        has one column per predicted quantile plus any joined manual columns.
        ``(False, None)`` when the article lacks usable history, or when the
        model's ``predict`` raised an ``AssertionError`` (which is reported via
        ``tqdm.write``).
    """
    try:
        filtered_prediction_data = create_prediction_data(
            data, artnum, FIRST_DAY_LAST_MONTH)
    except ValueError:
        # Too little history or the last month is missing: skip this article.
        return (False, None)

    try:
        new_raw_predictions = best_tft.predict(
            filtered_prediction_data, mode="quantiles")
    except AssertionError as e:
        tqdm.write(f"filtered data shape: {filtered_prediction_data.shape}")
        tqdm.write(f"Error predicting for artnum {artnum}: {e}")
        return (False, None)

    # detach()/cpu() make .numpy() safe for GPU or grad-tracking tensors;
    # copy() decouples the frame from the tensor's storage.
    quantiles = new_raw_predictions[0].detach().cpu().numpy().copy()
    df = pd.DataFrame(quantiles, columns=[
        "amount_2%", "amount_10%", "amount_25%", "amount_50%",
        "amount_75%", "amount_90%", "amount_98%"])
    df["artnum"] = artnum
    # One month-start date per predicted step. Deriving the count from the
    # prediction (instead of a hard-coded 9) keeps it in sync with the
    # model's prediction length.
    df["date"] = pd.date_range(FIRST_DAY_THIS_MONTH, periods=len(df), freq="MS")
    df.set_index(["artnum", "date"], inplace=True)
    try:
        manual = sql_utils.get_manual_predictions(
            artnum, manual_predictions).rename_axis("date")
        df = df.join(manual)
    except KeyError:
        # No manual predictions for this article: keep the model output only.
        pass
    return (True, df)

def predict_batch(pbar, artnums):
    """Load model and data, then predict every article in ``artnums``.

    The model, feature table and manual predictions are loaded anew on every
    call and are bound only to local names here. Whether their memory is
    actually released after return depends on what else (e.g. library
    internals) still references them.

    Args:
        pbar: tqdm progress bar; advanced by one per article, successful or not.
        artnums: Iterable of article numbers to predict.

    Returns:
        All successful per-article predictions concatenated. If none
        succeeded, an empty frame with an ``(artnum, date)`` MultiIndex is
        returned instead, because ``pd.concat`` raises on an empty list.
    """
    best_tft = load_nn()
    data = load_all_data()
    manual_predictions = sql_utils.load_manual_predictions()
    FIRST_DAY_THIS_MONTH = datetime.today().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0)
    FIRST_DAY_LAST_MONTH = (FIRST_DAY_THIS_MONTH -
                            timedelta(days=1)).replace(day=1)

    predictions = []
    for artnum in artnums:
        success, df = predict_for_artnum(
            artnum, data, best_tft, manual_predictions,
            FIRST_DAY_THIS_MONTH, FIRST_DAY_LAST_MONTH)
        if success:
            # df is built fresh for each article, so a defensive copy would
            # only double the memory held per prediction.
            predictions.append(df)
        pbar.update(1)

    if not predictions:
        # Same index layout as a successful batch, so callers can still
        # concatenate this frame with the others.
        return pd.DataFrame(index=pd.MultiIndex.from_arrays(
            [[], []], names=["artnum", "date"]))
    return pd.concat(predictions)

def main():
    """Predict the top articles batch by batch and save them to a dated CSV.

    Articles come from ``sql_utils.load_top_n_artnums(29000)`` and are handed
    to ``predict_batch`` in chunks of ``batch_size``. After each batch,
    ``gc.collect()`` runs and the process RSS is reported. The collection
    reclaims objects from the finished batch that only reference cycles keep
    alive, before the next batch loads its own model and data. The RSS report
    helps locate where memory grows.

    The result is written to ``nn_predictions_<YYYY-MM-DD>.csv``. If there are
    no articles at all, nothing is written.
    """
    all_artnums = sql_utils.load_top_n_artnums(29000)
    batch_size = 5000
    process = psutil.Process(os.getpid())

    all_article_predictions = []
    with tqdm(total=len(all_artnums)) as pbar:
        for batch_number, start in enumerate(range(0, len(all_artnums), batch_size)):
            pbar.set_description(f"Predicting batch {batch_number}")
            artnums = all_artnums[start:start + batch_size]
            all_article_predictions.append(predict_batch(pbar, artnums))
            gc.collect()
            tqdm.write(
                f"RSS after batch {batch_number}: "
                f"{process.memory_info().rss / 1024 / 1024:.1f} MB")

    if not all_article_predictions:
        # pd.concat([]) would raise; there is simply nothing to save.
        tqdm.write("No articles to predict; nothing written.")
        return

    all_predictions = pd.concat(all_article_predictions)
    gc.collect()
    print(process.memory_info().rss / 1024 / 1024, "MB used")
    # Save predictions as CSV named with the current date,
    # e.g. nn_predictions_2021-01-01.csv
    all_predictions.to_csv(
        f"nn_predictions_{datetime.today().strftime('%Y-%m-%d')}.csv")

if __name__ == "__main__":
    # Run the batch prediction only when executed as a script, not on import.
    main()
whoishu commented 1 year ago

same issues

EthanReid commented 7 months ago

Bump

el-analista commented 3 months ago

same issue

oddcard2 commented 2 months ago

same issue (

shellrazer commented 1 month ago

same issue

WinstonPrivacy commented 3 weeks ago

same issue.