ISI-MIP / attrici

Produce counterfactual climates for ISIMIP.
GNU General Public License v3.0

Out of memory error --> memory leak #99

Open SimonTreu opened 9 months ago

SimonTreu commented 9 months ago

There is a memory leak in the run_estimation function. It appears in the for loop over the grid cells. Deleting the objects created inside this loop did not free the memory: each iteration retains up to 100 MB more. This is a bug, because nothing needs to persist from one iteration to the next.

Further research indicates that the problem might be rooted in the Theano package [1, 2]. That might also explain why it did not come up with the updated PyMC version.

An open question is why this appears only now and not in earlier runs of the ATTRICI-PYMC3 code. I believe the reason is that we never ran enough cells on a single node to fill the memory.

A quick fix is to limit the number of cells run on one node. However, this is not a permanent solution, because others who want to use ATTRICI are quite likely to run into the same problem. They probably do not have a similar setup with a "standby queue" that allows parallelization over many independent jobs.

A proper solution might be to set up the PyMC model only once and reuse it for the subsequent cells [see discussion in 1]. This might also improve performance considerably, because the model would not need to be recompiled. Maybe this is also why the code is faster with the updated PyMC, because the compilation step is reduced. But this is speculative for now.
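As an interim workaround, each cell could be run in its own short-lived Python process, so that whatever the process leaks is returned to the operating system when it exits. This is a different technique from the quick fix and the model reuse discussed above, and it does not fix the leak itself. Here is a minimal sketch using only the standard library. CELL_SCRIPT and run_cell_in_subprocess are hypothetical stand-ins, not part of ATTRICI, and the child only allocates a buffer instead of fitting a model.

```python
import subprocess
import sys

# Hypothetical stand-in for the per-cell work. A real child process would
# load the data for one grid cell and run the estimation for it.
CELL_SCRIPT = """
import sys
run_number = int(sys.argv[1])
scratch = bytearray(10 * 1024 * 1024)  # memory that lives only in the child
print(run_number)
"""


def run_cell_in_subprocess(run_number, timeout=None):
    """Run one cell in a fresh interpreter and return what it printed."""
    result = subprocess.run(
        [sys.executable, "-c", CELL_SCRIPT, str(run_number)],
        capture_output=True,
        text=True,
        timeout=timeout,
        check=True,  # raise CalledProcessError if the child fails
    )
    return result.stdout.strip()


if __name__ == "__main__":
    for n in range(3):
        print("finished cell", run_cell_in_subprocess(n))
```

The cost is one interpreter start-up (and, in the real case, one model compilation) per cell, so this trades run time for bounded memory in the parent process.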

The memory_profiler output is shown below:

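Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   111   1125.6 MiB      0.0 MiB          11       for n in run_numbers[:]: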
   112                                                 # todo only for debugging                                                                                                                            
   113   1125.6 MiB      0.0 MiB          11           if (n - start_num) >= 10:                                                                                                                            
   114   1125.6 MiB      0.0 MiB           1               break                                                                                                                                            
   115                                                                                                                                                                                                      
   116   1036.7 MiB      0.0 MiB          10           estimator = est.estimator(s)                                                                                                                         
   117   1036.7 MiB      4.5 MiB          10           sp = df_specs.loc[n, :]                                                                                                                              
   118                                                                                                                                                                                                      
   119                                                 # if lat >20: continue                                                                                                                               
   120   1036.7 MiB      0.0 MiB          10           print(                                                                                                                                               
   121   1036.7 MiB      0.0 MiB          10               "This is SLURM task",                                                                                                                            
   122   1036.7 MiB      0.0 MiB          10               task_id,                                                                                                                                         
   123   1036.7 MiB      0.0 MiB          10               "run number",                                                                                                                                    
   124   1036.7 MiB      0.0 MiB          10               n,                                                                                                                                               
   125   1036.7 MiB      0.0 MiB          10               "lat,lon",                                                                                                                                       
   126   1036.7 MiB      0.0 MiB          10               sp["lat"],                                                                                                                                       
   127   1036.7 MiB      0.0 MiB          10               sp["lon"],                                                                                                                                       
   128                                                 )                                                                                                                                                    
   129   1036.7 MiB      0.0 MiB          10           print(f"Memory usage at start: {memory_usage(-1, interval=0.1, timeout=1)} MiB")                                                                     
   130   1036.7 MiB      0.0 MiB          10           outdir_for_cell = dh.make_cell_output_dir(                                                                                                           
   131   1036.7 MiB      0.0 MiB          10               s.output_dir, "timeseries", sp["lat"], sp["lon"], s.variable                                                                                     
   132                                                 )                                                                                                                                                    
   133   1036.7 MiB      0.0 MiB          10           fname_cell = dh.get_cell_filename(outdir_for_cell, sp["lat"], sp["lon"], s)                                                                          
   134                                                                                                                                                                                                      
   135   1036.7 MiB      0.0 MiB          10           if s.skip_if_data_exists:                                                                                                                            
   136                                                     try:                                                                                                                                             
   137                                                         dh.test_if_data_valid_exists(fname_cell)                                                                                                     
   138                                                         print(f"Existing valid data in {fname_cell} . Skip calculation.")                                                                            
   139                                                         continue                                                                                                                                     
   140                                                     except Exception as e:                                                                                                                           
   141                                                         print(e)                                                                                                                                     
   142                                                         print("No valid data found. Run calculation.")                                                                                               
   143                                                                                                                                                                                                      
   144   1036.7 MiB      0.0 MiB          10           data = obs_data.variables[s.variable][:, sp["index_lat"], sp["index_lon"]]                                                                           
   145   1036.7 MiB      0.0 MiB          10           df, datamin, scale = dh.create_dataframe(                                                                                                            
   146   1036.7 MiB      8.2 MiB          10               nct[:], nct.units, data, gmt, s.variable                                                                                                         
   147                                                 )
   148      
   149   1036.7 MiB      0.0 MiB          10           try:
   150   1036.7 MiB      0.0 MiB          10               print(
   151   1036.7 MiB      0.0 MiB          10                   f"took {(datetime.now() - TIME0).total_seconds()} seconds till estimator.estimate_parameters is started"
   152                                                     )
   153   1036.7 MiB      0.0 MiB          10               trace, dff = func_timeout(
   154   1036.7 MiB      0.0 MiB          10                   s.timeout,
   155   1036.7 MiB      0.0 MiB          10                   estimator.estimate_parameters,
   156   1114.9 MiB    960.6 MiB          10                   args=(df, sp["lat"], sp["lon"], s.map_estimate, TIME0),
   157                                                     )
   158                                                 except (FunctionTimedOut, ValueError) as error:
   159                                                     raise error
   160                                                     # if str(error) == "Modes larger 1 are not allowed for the censored model.":
   161                                                     #     raise error
   162                                                     # else:
   163                                                     #     print("Sampling at", sp["lat"], sp["lon"], " timed out or failed.")
   164                                                     #     print(error)
   165                                                     #     logger.error(
   166                                                     #         str(
   167                                                     #             "lat,lon: "
   168                                                     #             + str(sp["lat"])
   169                                                     #             + " "
   170                                                     #             + str(sp["lon"])
   171                                                     #             + " : "
   172                                                     #             + str(error)
   173                                                     #         )
   174                                                     #     )
   175                                                     # continue
   176                                         
   177   1114.9 MiB      0.0 MiB          10           df_with_cfact = estimator.estimate_timeseries(
   178   1225.6 MiB    894.7 MiB          10               dff, trace, datamin, scale, s.map_estimate
   179                                                 )
   180   1225.6 MiB      0.0 MiB          10           dh.save_to_disk(
   181   1225.6 MiB      2.1 MiB          10               df_with_cfact, fname_cell, sp["lat"], sp["lon"], s.storage_format
   182                                                 )
   183   1225.6 MiB      0.0 MiB          10           df = None
   184   1125.6 MiB   -930.7 MiB          10           dff = None
   185                                         
   186   1125.6 MiB      0.0 MiB          10           print(f"Memory usage at end: {memory_usage(-1, interval=0.1, timeout=1)} MiB")
   187                                         
   188   1125.6 MiB      0.0 MiB           1       obs_data.close()
   189   1125.6 MiB      0.0 MiB           1       nc_lsmask.close()
   190   1125.6 MiB      0.0 MiB           1       print(
   191   1125.6 MiB      0.0 MiB           1           "Estimation completed for all cells. It took {0:.1f} minutes.".format(
   192   1125.6 MiB      0.0 MiB           1               (datetime.now() - TIME0).total_seconds() / 60
   193                                                 )
   194                                             )
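To cross-check the per-iteration growth seen in the profile, one option is to record how much traced memory each loop iteration leaves behind. The sketch below uses the standard-library tracemalloc module rather than memory_profiler. Note that tracemalloc only sees allocations made through Python's allocators, so a leak inside compiled Theano code may not show up here. The functions leaky_step and clean_step are hypothetical stand-ins for one loop iteration.

```python
import tracemalloc

_retained = []  # simulates objects that survive from one iteration to the next


def leaky_step(i):
    """Hypothetical iteration that keeps a 1 MiB buffer alive."""
    _retained.append(bytearray(1024 * 1024))


def clean_step(i):
    """Hypothetical iteration whose buffer is freed when it returns."""
    scratch = bytearray(1024 * 1024)
    del scratch


def growth_per_iteration(step, n_iterations):
    """Return the change in traced memory (bytes) after each iteration."""
    tracemalloc.start()
    growth = []
    previous, _peak = tracemalloc.get_traced_memory()
    for i in range(n_iterations):
        step(i)
        current, _peak = tracemalloc.get_traced_memory()
        growth.append(current - previous)
        previous = current
    tracemalloc.stop()
    return growth


if __name__ == "__main__":
    print("leaky:", growth_per_iteration(leaky_step, 5))
    print("clean:", growth_per_iteration(clean_step, 5))
```

A roughly constant positive growth per iteration, as in the leaky case, matches the pattern in the profile above. Growth near zero here while the process memory still rises would point to allocations outside Python's allocators.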

Below is the run_estimation.py code I used for profiling:

import logging
import os
from datetime import datetime
from pathlib import Path

import attrici
import attrici.datahandler as dh
import attrici.estimator as est
import netCDF4 as nc
import numpy as np
import pandas as pd
from func_timeout import FunctionTimedOut, func_timeout
from memory_profiler import memory_usage

import settings as s

print(f"memory usage at the beginning: {memory_usage(-1, interval=0.1, timeout=1)} MiB")
s.output_dir.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    filename=s.output_dir / "failing_cells.log",
    level=logging.ERROR,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logger = logging.getLogger(__name__)
# needed to silence verbose pymc3
pmlogger = logging.getLogger("pymc3")
pmlogger.propagate = False

print("Version", attrici.__version__)

try:
    submitted = os.environ["SUBMITTED"] == "1"
    task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
    njobarray = int(os.environ["SLURM_ARRAY_TASK_COUNT"])
    s.ncores_per_job = 1
    s.progressbar = False
except KeyError:
    submitted = False
    njobarray = 1
    task_id = 0
    s.progressbar = True

dh.create_output_dirs(s.output_dir)

gmt_file = s.input_dir / s.gmt_file
ncg = nc.Dataset(gmt_file, "r")
gmt = np.squeeze(ncg.variables["tas"][:])
ncg.close()

input_file = s.input_dir / s.source_file
landsea_mask_file = s.input_dir / s.landsea_file

obs_data = nc.Dataset(input_file, "r")
nc_lsmask = nc.Dataset(landsea_mask_file, "r")
nct = obs_data.variables["time"]
lats = obs_data.variables["lat"][:]
lons = obs_data.variables["lon"][:]
longrid, latgrid = np.meshgrid(lons, lats)
jgrid, igrid = np.meshgrid(np.arange(len(lons)), np.arange(len(lats)))

ls_mask = nc_lsmask.variables["mask"][:, :]
df_specs = pd.DataFrame()
df_specs["lat"] = latgrid[ls_mask == 1]
df_specs["lon"] = longrid[ls_mask == 1]
df_specs["index_lat"] = igrid[ls_mask == 1]
df_specs["index_lon"] = jgrid[ls_mask == 1]

print("A total of", len(df_specs), "grid cells to estimate.")

if len(df_specs) % (njobarray) == 0:
    print("Grid cells can be equally distributed to Slurm tasks")
    calls_per_arrayjob = np.ones(njobarray) * len(df_specs) // (njobarray)
else:
    print("Slurm tasks not a divisor of number of grid cells, discard some cores.")
    calls_per_arrayjob = np.ones(njobarray) * len(df_specs) // (njobarray) + 1
    discarded_jobs = np.where(np.cumsum(calls_per_arrayjob) > len(df_specs))
    calls_per_arrayjob[discarded_jobs] = 0
    calls_per_arrayjob[discarded_jobs[0][0]] = len(df_specs) - calls_per_arrayjob.sum()

assert calls_per_arrayjob.sum() == len(df_specs)
# print(calls_per_arrayjob)

# Calculate the starting and ending values for this task based
# on the SLURM task and the number of runs per task.
cum_calls_per_arrayjob = calls_per_arrayjob.cumsum(dtype=int)
start_num = 0 if task_id == 0 else cum_calls_per_arrayjob[task_id - 1]
end_num = cum_calls_per_arrayjob[task_id] - 1
run_numbers = np.arange(start_num, end_num + 1, 1, dtype=int)
if len(run_numbers) == 0:
    print("No runs assigned for this SLURM task.")
else:
    print("This is SLURM task", task_id, "which will do runs", start_num, "to", end_num)

estimator = est.estimator(s)

TIME0 = datetime.now()

for n in run_numbers[:]:

    sp = df_specs.loc[n, :]

    # if lat >20: continue
    print(
        "This is SLURM task", task_id, "run number", n, "lat,lon", sp["lat"], sp["lon"]
    )
    print(f"Memory usage at start: {memory_usage(-1, interval=0.1, timeout=1)} MiB")
    outdir_for_cell = dh.make_cell_output_dir(
        s.output_dir, "timeseries", sp["lat"], sp["lon"], s.variable
    )
    fname_cell = dh.get_cell_filename(outdir_for_cell, sp["lat"], sp["lon"], s)

    if s.skip_if_data_exists:
        try:
            dh.test_if_data_valid_exists(fname_cell)
            print(f"Existing valid data in {fname_cell} . Skip calculation.")
            continue
        except Exception as e:
            print(e)
            print("No valid data found. Run calculation.")

    data = obs_data.variables[s.variable][:, sp["index_lat"], sp["index_lon"]]
    df, datamin, scale = dh.create_dataframe(nct[:], nct.units, data, gmt, s.variable)

    try:
        print(
            f"took {(datetime.now() - TIME0).total_seconds()} seconds till estimator.estimate_parameters is started"
        )
        trace, dff = func_timeout(
            s.timeout,
            estimator.estimate_parameters,
            args=(df, sp["lat"], sp["lon"], s.map_estimate, TIME0),
        )
    except (FunctionTimedOut, ValueError) as error:
        raise error
        # if str(error) == "Modes larger 1 are not allowed for the censored model.":
        #     raise error
        # else:
        #     print("Sampling at", sp["lat"], sp["lon"], " timed out or failed.")
        #     print(error)
        #     logger.error(
        #         str(
        #             "lat,lon: "
        #             + str(sp["lat"])
        #             + " "
        #             + str(sp["lon"])
        #             + " : "
        #             + str(error)
        #         )
        #     )
        # continue

    df_with_cfact = estimator.estimate_timeseries(
        dff, trace, datamin, scale, s.map_estimate
    )
    dh.save_to_disk(df_with_cfact, fname_cell, sp["lat"], sp["lon"], s.storage_format)
    print(f"Memory usage at end: {memory_usage(-1, interval=0.1, timeout=1)} MiB")

obs_data.close()
nc_lsmask.close()
print(
    "Estimation completed for all cells. It took {0:.1f} minutes.".format(
        (datetime.now() - TIME0).total_seconds() / 60
    )
)
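The number of cells a single task has to process follows from how the script splits the grid cells over the Slurm array tasks, which in turn determines how much leaked memory can accumulate in one process. The sketch below restates that split in plain Python without NumPy. The function names calls_per_task and run_range are hypothetical and do not exist in ATTRICI.

```python
def calls_per_task(n_cells, n_tasks):
    """Number of cells per task, mirroring the split in the script above."""
    if n_cells % n_tasks == 0:
        return [n_cells // n_tasks] * n_tasks
    # Not divisible: give each task floor(n_cells / n_tasks) + 1 cells until
    # the cells run out; the remaining tasks get a remainder or nothing.
    size = n_cells // n_tasks + 1
    calls = []
    remaining = n_cells
    for _ in range(n_tasks):
        take = min(size, remaining)
        calls.append(take)
        remaining -= take
    return calls


def run_range(task_id, calls):
    """Inclusive (start_num, end_num) for one task; end < start means no runs."""
    start_num = sum(calls[:task_id])
    end_num = start_num + calls[task_id] - 1
    return start_num, end_num


if __name__ == "__main__":
    calls = calls_per_task(10, 4)
    print(calls)                # [3, 3, 3, 1]
    print(run_range(1, calls))  # (3, 5)
```

With the roughly 100 MB retained per cell reported above, the largest entry of this list gives a rough estimate of the extra memory one task accumulates before it finishes.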