judithabk6 / med_bench

BSD 3-Clause "New" or "Revised" License

huber_ipw_forest may return NaN #30

Open sami6mz opened 1 year ago

sami6mz commented 1 year ago

Under certain conditions, "huber_ipw_forest_cf" and "huber_ipw_forest" may return NaN values. The error comes from a division by zero in lines 290 to 296 of benchmark_mediation.py, which happens when either 1-p_x = 0 or p_x = 0. These forest estimators are not deterministic even with a fixed seed, so the bug is fairly rare, but it can be reproduced within a hundred iterations or so.

Code:

import numpy as np
from numpy.random import default_rng
from med_bench.src.get_simulated_data import simulate_data
from med_bench.src.get_estimation import get_estimation

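# Simulated dataset (arguments exactly as in the failing run)
data = simulate_data(1000, default_rng(321), False, False, 1, 5, 123, "continuous", 0.5, 0.5, 0.5, 0.5)
x = data[0]          # covariates
t = data[1].ravel()  # treatment
m = data[2]          # mediator
y = data[3].ravel()  # outcome

estimator = "huber_ipw_forest"
# Repeat the estimation: the forest is random, so only some runs produce NaN
for _ in range(100):
    res = get_estimation(x, t, m, y, estimator, 5)[:5]
    if np.any(np.isnan(res)):
        print(res)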

Output:

~/med_bench/src/benchmark_mediation.py:290: RuntimeWarning: invalid value encountered in true_divide
  y1m1 = np.sum(y * t / p_x) / np.sum(t / p_x)
~/med_bench/src/benchmark_mediation.py:295: RuntimeWarning: divide by zero encountered in true_divide
  y0m1 = np.sum(y * (1 - t) * p_xm / ((1 - p_xm) * p_x)) /\
~/med_bench/src/benchmark_mediation.py:296: RuntimeWarning: divide by zero encountered in true_divide
  np.sum((1 - t) * p_xm / ((1 - p_xm) * p_x))
~/med_bench/src/benchmark_mediation.py:296: RuntimeWarning: invalid value encountered in double_scalars
  np.sum((1 - t) * p_xm / ((1 - p_xm) * p_x))
(nan, nan, 2.7370075614318736, nan, nan)

In this case, I believe p_x = 0. Note: it is much harder to make "huber_ipw_forest_cf" fail than "huber_ipw_forest".
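
To make the failure mode concrete, here is a minimal sketch on toy data (not the med_bench code). It reuses the weighted-mean expression from the line 290 warning and shows that a single untreated unit with p_x = 0 gives 0/0 and turns the whole sum into NaN. This matches the "invalid value" warning above. The helper name `ipw_mean_treated` and its `clip` argument are hypothetical. Clipping the propensities away from 0 and 1 is only one possible mitigation, and the 0.01 threshold is an arbitrary choice for illustration.

```python
import numpy as np


def ipw_mean_treated(y, t, p_x, clip=None):
    """Weighted mean of y over treated units, same expression as the y1m1 line.

    `clip` is a hypothetical parameter (not part of med_bench): when set,
    propensities are bounded to [clip, 1 - clip] before dividing.
    """
    if clip is not None:
        p_x = np.clip(p_x, clip, 1 - clip)
    # Silence the RuntimeWarnings here so the example output stays readable
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.sum(y * t / p_x) / np.sum(t / p_x)


# Toy data: the second unit is untreated (t = 0) with estimated propensity exactly 0
y = np.array([1.0, 2.0, 3.0, 4.0])
t = np.array([1.0, 0.0, 1.0, 0.0])
p_x = np.array([0.6, 0.0, 0.7, 0.3])

raw = ipw_mean_treated(y, t, p_x)                 # 0 / 0 for the second unit
clipped = ipw_mean_treated(y, t, p_x, clip=0.01)  # finite after clipping

print(np.isnan(raw), np.isfinite(clipped))
```

Note that clipping changes the estimator whenever a propensity actually hits the bound, so whether it is acceptable here is a design decision for the maintainers.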

houssamzenati commented 4 months ago

@brash6 @judithabk6 @bthirion @zbakhm