I am trying to run inference on our products, one product at a time, so that I can merge in the manual predictions stored in a database. For this I am using the code below, which already works in batches of 5000 products; after each batch the model should be destroyed, which was my attempt to solve the memory problem, but to no avail.
The memory usage nevertheless keeps increasing over the lifetime of the program, either until the total number of products is reached (sql_utils.load_top_n_artnums(29000) specifies the number of products; I cannot go much higher than that right now, because otherwise the following case hits) or until the 50 GB of RAM I have assigned to WSL runs out and the process is killed for using too much memory.
I unfortunately cannot publish the model or the source data, but maybe someone can spot where I am going terribly wrong with this.
The model does not seem to release its allocated RAM, even after predict_batch() returns and everything created inside it should have gone out of scope.
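To make the expectation concrete: since predict_batch() returns only a plain DataFrame and keeps the model and the full CSV local, I would expect something like the following around each batch call inside the loop in main() further down (before, after and batch_df are just placeholder names of mine; the RSS readout is the same psutil call used in main()):

process = psutil.Process(os.getpid())
before = process.memory_info().rss
batch_df = predict_batch(pbar, artnums)   # model + full CSV live only inside this call
gc.collect()
after = process.memory_info().rss
# expectation: "after" is roughly "before" plus the size of batch_df,
# but in practice RSS keeps climbing batch after batch and never comes back down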
Code to reproduce the problem
from datetime import datetime, timedelta
from pytorch_forecasting.models import TemporalFusionTransformer
import pandas as pd
import warnings
import sqlalchemy
from tqdm.auto import tqdm
import sql_utils
import gc
import numpy as np
import os
import psutil
from memory_profiler import profile
warnings.filterwarnings("ignore") # avoid printing out absolute paths
# pretty much just tweaked https://github.com/jdb78/pytorch-forecasting/blob/master/docs/source/tutorials/stallion.ipynb


def load_nn():
    best_tft_path = "lightning_logs/lightning_logs/version_28/checkpoints/epoch=50-step=25500.ckpt"
    model = TemporalFusionTransformer.load_from_checkpoint(best_tft_path)
    return model


def load_all_data():
    data = pd.read_csv("nn_all_data.csv", dtype={
        "main_category": str,
        "parent_category": str,
        "artnum": str,
        "month": str,
    }, parse_dates=True)
    data["main_category"] = data["main_category"].astype("category")
    data["parent_category"] = data["parent_category"].astype("category")
    data["month"] = data["month"].astype("category")
    data["avg_amount_by_sku"] = data.groupby(["time_idx", "artnum"], observed=True)[
        "amount"].transform("mean")
    data["avg_amount"] = data.groupby(["time_idx"], observed=True)[
        "amount"].transform("mean")
    data["date"] = pd.to_datetime(data["date"])
    data["amount_log"] = np.log(data["amount"].clip(lower=0) + 1)
    data.dropna(inplace=True)
    return data


def create_prediction_data(data, artnum, FIRST_DAY_LAST_MONTH, max_encoder_length=36, min_encoder_length=4, max_prediction_length=9):
    encoder_data = data[data["artnum"] == artnum].tail(
        max_encoder_length).copy()
    # check that encoder_data contains at least min_encoder_length rows
    if encoder_data.shape[0] < min_encoder_length:
        raise ValueError(
            f"Not enough data to create encoder data. Need at least {min_encoder_length} months of data.")
    # check that encoder_data["date"] contains FIRST_DAY_LAST_MONTH
    if pd.to_datetime(FIRST_DAY_LAST_MONTH) not in encoder_data["date"].values:
        raise ValueError(
            f"Data for article does not contain last month {FIRST_DAY_LAST_MONTH}.")
    # get last row of encoder_data
    last_row = encoder_data.tail(1)
    # create decoder_data by repeating last_row and incrementing the month
    decoder_data = pd.concat(
        [last_row.assign(date=lambda x: x.date + pd.offsets.MonthBegin(i))
         for i in range(1, max_prediction_length + 1)],
        ignore_index=True,
    )
    # add a time index consistent with "data"
    decoder_data["time_idx"] = (
        decoder_data["date"].dt.year - 2005) * 12 + decoder_data["date"].dt.month
    # adjust additional time feature(s)
    decoder_data["month"] = decoder_data.date.dt.month.astype(
        str).astype("category")  # categories have to be strings
    # combine encoder and decoder data
    new_prediction_data = pd.concat(
        [encoder_data, decoder_data], ignore_index=True)
    return new_prediction_data


def predict_for_artnum(artnum, data, best_tft, manual_predictions, FIRST_DAY_THIS_MONTH, FIRST_DAY_LAST_MONTH):
    try:
        filtered_prediction_data = create_prediction_data(
            data, artnum, FIRST_DAY_LAST_MONTH)
    except ValueError:
        return (False, None)
    try:
        new_raw_predictions = best_tft.predict(
            filtered_prediction_data, mode="quantiles")
        df = pd.DataFrame(new_raw_predictions[0].numpy().copy(), columns=[
            "amount_2%", "amount_10%", "amount_25%", "amount_50%", "amount_75%", "amount_90%", "amount_98%"])
        df["artnum"] = artnum
        df["date"] = pd.date_range(FIRST_DAY_THIS_MONTH, periods=9, freq="MS")
        df.set_index(["artnum", "date"], inplace=True)
        try:
            manual = sql_utils.get_manual_predictions(
                artnum, manual_predictions).rename_axis("date")
            df = df.join(manual)
        except KeyError:
            pass
        # all_article_predictions.append(df.copy())
        return (True, df)
    except AssertionError as e:
        tqdm.write(f"filtered data shape: {filtered_prediction_data.shape}")
        tqdm.write(f"Error predicting for artnum {artnum}: {e}")
        return (False, None)


def predict_batch(pbar, artnums):
    # the checkpoint and the full CSV are (re)loaded for every batch of 5000,
    # so that everything allocated here can be freed once the function returns
    best_tft = load_nn()
    data = load_all_data()
    manual_predictions = sql_utils.load_manual_predictions()
    FIRST_DAY_THIS_MONTH = datetime.today().replace(
        day=1, hour=0, minute=0, second=0, microsecond=0)
    FIRST_DAY_LAST_MONTH = (FIRST_DAY_THIS_MONTH -
                            timedelta(days=1)).replace(day=1)
    predictions = []
    for artnum in artnums:
        (success, df) = predict_for_artnum(artnum, data, best_tft,
                                           manual_predictions, FIRST_DAY_THIS_MONTH, FIRST_DAY_LAST_MONTH)
        if success:
            predictions.append(df.copy())
        pbar.update(1)
    return pd.concat(predictions)


def main():
    # all_artnums = data["artnum"].unique()
    all_artnums = sql_utils.load_top_n_artnums(29000)
    batch_size = 5000
    all_article_predictions = []
    with tqdm(total=len(all_artnums)) as pbar:
        for i in range(0, len(all_artnums), batch_size):
            pbar.set_description(f"Predicting batch {round(i / batch_size)}")
            artnums = all_artnums[i:i + batch_size]
            all_article_predictions.append(predict_batch(pbar, artnums))
            all_predictions = pd.concat(all_article_predictions)
            gc.collect()
            process = psutil.Process(os.getpid())
            print(process.memory_info().rss / 1024 / 1024, "MB used")
    # save predictions as csv with current date. E.g. nn_predictions_2021-01-01.csv
    all_predictions.to_csv(
        f"nn_predictions_{datetime.today().strftime('%Y-%m-%d')}.csv")


if __name__ == "__main__":
    main()
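Since I cannot share the checkpoint or the CSV, here is the reduced pattern that I believe should expose the behaviour with any TemporalFusionTransformer checkpoint and a matching prediction frame. checkpoint_path and sample_frame are placeholders for my real checkpoint and a small frame built like in create_prediction_data() above, and leak_check / rss_mb are just names I made up for this sketch; the RSS measurement is the same one used in main().

import gc
import os

import psutil
from pytorch_forecasting.models import TemporalFusionTransformer


def rss_mb():
    # resident set size of the current process in MB
    return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024


def leak_check(checkpoint_path, sample_frame, iterations=50):
    for i in range(iterations):
        model = TemporalFusionTransformer.load_from_checkpoint(checkpoint_path)
        model.predict(sample_frame, mode="quantiles")
        del model     # the model is unreachable from here on ...
        gc.collect()  # ... and should be collectable, yet RSS keeps growing for me
        print(f"iteration {i}: {rss_mb():.0f} MB")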