Under certain conditions, "huber_ipw_forest_cf" and "huber_ipw_forest" may return NaN values. The cause is a division by zero in lines 290 to 296, which happens when either 1 - p_x = 0 or p_x = 0. These forest estimators return a different result on each call even with a fixed seed, so the bug is intermittent and fairly rare, but it can be reproduced within a hundred iterations or so.
Code :
import numpy as np
from numpy.random import default_rng
from med_bench.src.get_simulated_data import simulate_data
from med_bench.src.get_estimation import get_estimation
data = simulate_data(1000, default_rng(321), False, False, 1, 5, 123, "continuous", 0.5, 0.5, 0.5, 0.5)
x = data[0]
t = data[1].ravel()
m = data[2]
y = data[3].ravel()
estimator = "huber_ipw_forest"
for _ in range(1, 100):
    res = get_estimation(x, t, m, y, estimator, 5)[0:5]
    if np.any(np.isnan(res)):
        print(res)
Output :
~/med_bench/src/benchmark_mediation.py:290: RuntimeWarning: invalid value encountered in true_divide
y1m1 = np.sum(y * t / p_x) / np.sum(t / p_x)
~/med_bench/src/benchmark_mediation.py:295: RuntimeWarning: divide by zero encountered in true_divide
y0m1 = np.sum(y * (1 - t) * p_xm / ((1 - p_xm) * p_x)) /\
~/med_bench/src/benchmark_mediation.py:296: RuntimeWarning: divide by zero encountered in true_divide
np.sum((1 - t) * p_xm / ((1 - p_xm) * p_x))
~/med_bench/src/benchmark_mediation.py:296: RuntimeWarning: invalid value encountered in double_scalars
np.sum((1 - t) * p_xm / ((1 - p_xm) * p_x))
(nan, nan, 2.7370075614318736, nan, nan)
In this case, I believe p_x = 0.
Note: "huber_ipw_forest_cf" fails far less often than "huber_ipw_forest", so the bug is much harder to reproduce with it.
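One possible mitigation, sketched here only as a suggestion, is to clip the estimated propensities away from 0 and 1 before dividing. The two expressions below are copied from the warnings above. The helper name `clipped_ipw_terms` and the threshold `eps` are hypothetical and do not exist in med_bench, and I have not checked how clipping affects the estimates on real data.

```python
import numpy as np


def clipped_ipw_terms(y, t, p_x, p_xm, eps=1e-6):
    """Hypothetical helper: compute y1m1 and y0m1 with clipped propensities.

    eps is an arbitrary illustrative threshold, not a value taken from med_bench.
    """
    # Keep both propensities inside [eps, 1 - eps] so no denominator is exactly 0.
    p_x = np.clip(p_x, eps, 1 - eps)
    p_xm = np.clip(p_xm, eps, 1 - eps)

    # Same expression as line 290 in the warning output.
    y1m1 = np.sum(y * t / p_x) / np.sum(t / p_x)

    # Same expression as lines 295 and 296 in the warning output.
    w = (1 - t) * p_xm / ((1 - p_xm) * p_x)
    y0m1 = np.sum(y * w) / np.sum(w)
    return y1m1, y0m1


# Toy inputs with degenerate propensities (exact 0 and 1), as a forest may predict.
y = np.array([1.0, 2.0, 3.0, 4.0])
t = np.array([1.0, 0.0, 1.0, 0.0])
p_x = np.array([0.0, 0.5, 1.0, 0.5])
p_xm = np.array([0.5, 1.0, 0.5, 0.0])

y1m1, y0m1 = clipped_ipw_terms(y, t, p_x, p_xm)
print(np.isfinite([y1m1, y0m1]).all())
```

Clipping only removes the exact zeros in the denominators. A sample where every unit has t = 0 (or every unit has t = 1) would still give 0/0, and very small propensities still produce very large weights.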