Open kaalen opened 4 months ago
According to the error message, it looks like you are trying to explain about 8 million variables. The matrix that raises the memory error needs O(N^2) memory, where N is the number of input variables, so this is not feasible. You can use the max_shap_value_inputs attribute to sample and approximate the Shapley values, but max_shap_value_inputs should be greater than N, and 8 million features will still take a long time. In conclusion, I fear you need to reformulate the problem you want to solve before using this tool, by reducing the number of input variables.
You can try the code below with different combinations of nb_features, nb_events, and max_shap_value_inputs to see what fits your problem and your computer.
import numpy as np
np.random.seed(42)
import pandas as pd

def gen_xy(nb_features=10, nb_events=10, x_type=np.float32, date_type=np.float32):
    # Random features, random event times, and a small noise term.
    X_train = np.random.rand(nb_events, nb_features).astype(x_type)
    np_time = np.random.rand(nb_events)
    # np.clip expects the array first, then the lower and upper bounds.
    noise = np.clip(np.random.rand(nb_events) * 0.01, 0, 1)
    np_is_living = X_train[:, 0] < np_time + noise  # <--- dumb y
    # Structured array with an event flag and an event time.
    y_train = np.empty(nb_events, dtype=[('event', bool), ('time', date_type)])
    y_train['event'] = np_is_living
    y_train['time'] = np_time
    X_train = pd.DataFrame(X_train, columns=['f' + str(i) for i in range(1, nb_features + 1)])
    return X_train, y_train

X_train, y_train = gen_xy(nb_events=1800)
X_test, y_test = gen_xy(nb_events=1800)

from sksurv.ensemble import RandomSurvivalForest
rsf = RandomSurvivalForest(
    n_estimators=10, max_depth=25, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, random_state=42
)
rsf.fit(X_train, y_train)

from survshap import SurvivalModelExplainer, PredictSurvSHAP, ModelSurvSHAP
rsf_exp = SurvivalModelExplainer(rsf, X_test, y_test)
# Global explanation; max_shap_value_inputs caps the number of sampled inputs.
exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42, max_shap_value_inputs=20)
exp1_survshap_global_rsf.fit(rsf_exp)
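On the quadratic memory cost mentioned above: the traceback in this thread shows the weights being expanded with W = np.diag(shap_kernel_weights) before computing X.T @ W @ X. The sketch below is not a patch to survshap; it only illustrates, on made-up sizes (n_rows, p, and the random X and w are assumptions for the example), that the same product can be formed by scaling the rows of X, which never builds the n_rows by n_rows matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, p = 200, 5  # hypothetical sizes, chosen only for illustration

# Binary matrix standing in for the simplified inputs, one row per sampled subset.
X = rng.integers(0, 2, size=(n_rows, p)).astype(float)
# One positive weight per row, standing in for the kernel weights.
w = rng.random(n_rows) + 0.1

# Dense version: materialises an (n_rows, n_rows) diagonal matrix.
W = np.diag(w)
dense = X.T @ W @ X

# Broadcast version: scale each row of X by its weight, O(n_rows * p) memory.
light = (X * w[:, None]).T @ X

print(np.allclose(dense, light))
```

Both expressions give the same p by p matrix; only the intermediate storage differs.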
I have no idea where the 8 million variables would be coming from. The dataset I was testing this on has only 23 input variables, which is pretty basic.
I tried running the sample code you shared with the max_shap_value_inputs parameter. The sample as provided, with max_shap_value_inputs=20, also fails to complete: it hangs at about 17%, with no error and no progress after more than 2 hours. I was not running any other significant processes on my machine at the time.
I also tried an even smaller number of input variables (5) and a training dataset with only 100 records. This time I get a LinAlgError: Singular matrix.
---------------------------------------------------------------------------
LinAlgError                               Traceback (most recent call last)
Cell In[4], line 12
      9 rsf_exp = SurvivalModelExplainer(rsf, X_test, y_test)
     11 exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42, max_shap_value_inputs=5)
---> 12 exp1_survshap_global_rsf.fit(rsf_exp)

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\model_explanations\object.py:76, in ModelSurvSHAP.fit(self, explainer, new_observations, timestamps, save_individual_explanations, **kwargs)
     69 if new_observations is None:
     70     new_observations = explainer.data
     72 (
     73     self.full_result,
     74     self.individual_explanations,
     75     self.timestamps,
---> 76 ) = calculate_individual_explanations(
     77     explainer,
     78     new_observations,
     79     self.function_type,
     80     self.path,
     81     self.B,
     82     self.max_shap_value_inputs,
     83     self.random_state,
     84     self.calculation_method,
     85     self.aggregation_method,
     86     timestamps,
     87     save_individual_explanations,
     88     **kwargs
     89 )
     91 names = explainer.y.dtype.names
     92 self.event_ind = explainer.y[names[0]]

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\model_explanations\utils.py:127, in calculate_individual_explanations(explainer, new_observations, function_type, path, B, max_shap_value_inputs, random_state, calculation_method, aggregation_method, timestamps, save_individual_explanations, **kwargs)
    117 for i in tqdm(range(len(new_observations))):
    118     survSHAP_obj = PredictSurvSHAP(
    119         function_type=function_type,
    120         path=path,
   (...)
    125         random_state=random_state,
    126     )
--> 127     survSHAP_obj.fit(explainer, new_observations.iloc[[i]], timestamps)
    128     if save_individual_explanations:
    129         individual_explanations.append(survSHAP_obj)

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\predict_explanations\object.py:81, in PredictSurvSHAP.fit(self, explainer, new_observation, timestamps, y_true)
     72 self.y_true_time = y_true[names[1]]
     74 if self.calculation_method == "kernel":
     75     (
     76         self.result,
     77         self.predicted_function,
     78         self.baseline_function,
     79         self.timestamps,
     80         self.r2,
---> 81     ) = shap_kernel(
     82         explainer,
     83         new_observation,
     84         self.function,
     85         self.aggregation_method,
     86         timestamps,
     87         self.max_shap_value_inputs,
     88     )
     89 elif self.calculation_method == "sampling":
     90     (
     91         self.result,
     92         self.predicted_function,
   (...)
    104         self.exact,
    105     )

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\predict_explanations\utils.py:106, in shap_kernel(explainer, new_observation, function_type, aggregation_method, timestamps, max_shap_value_inputs)
    101     print(
    102         f"Approximate Survival Shapley will sample only {max_shap_value_inputs} values instead of 2**{p} for Exact Shapley"
    103     )
    105 kernel_weights = generate_shap_kernel_weights(simplified_inputs, p)
--> 106 shap_values, r2 = calculate_shap_values(
    107     explainer,
    108     function_type,
    109     baseline_f,
    110     explainer.data,
    111     simplified_inputs,
    112     kernel_weights,
    113     new_observation,
    114     timestamps,
    115 )
    117 variable_names = explainer.data.columns
    118 result = prepare_result_df(new_observation, variable_names, shap_values, timestamps, aggregation_method)

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\predict_explanations\utils.py:160, in calculate_shap_values(model, function_type, avg_function, data, simplified_inputs, shap_kernel_weights, new_observation, timestamps)
    158 W = np.diag(shap_kernel_weights)
    159 X = np.array(simplified_inputs)
--> 160 R = np.linalg.inv(X.T @ W @ X) @ (X.T @ W)
    161 y = (
    162     make_prediction_for_simplified_input(model, function_type, data, simplified_inputs, new_observation, timestamps)
    163     - avg_function
    164 )
    165 shap_values = R @ y

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\numpy\linalg\linalg.py:561, in inv(a)
    559 signature = 'D->D' if isComplexType(t) else 'd->d'
    560 extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
--> 561 ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
    562 return wrap(ainv.astype(result_t, copy=False))

File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\numpy\linalg\linalg.py:112, in _raise_linalgerror_singular(err, flag)
    111 def _raise_linalgerror_singular(err, flag):
--> 112     raise LinAlgError("Singular matrix")

LinAlgError: Singular matrix
Could you provide a minimal snippet of code that represents your data/code and raises the error?
If the computed matrix is not invertible, you may add a small amount of random noise in your data to avoid linear dependencies between columns.
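To make the non-invertibility concrete, here is a small sketch of the expression from the traceback, np.linalg.inv(X.T @ W @ X). It is not survshap code, and the matrix X, weights w, and targets y are invented for the example. It shows one way such a matrix can lose rank, namely when there are fewer rows in X than columns, and how a least-squares solve still returns an answer in that case. Whether this is what happened in the failing run is not confirmed in this thread.

```python
import numpy as np

p = 5
# Hypothetical case: only 3 rows (sampled subsets) for 5 features.
X = np.array([
    [1, 0, 0, 1, 0],
    [0, 1, 0, 0, 1],
    [1, 1, 1, 0, 0],
], dtype=float)
w = np.array([0.5, 0.3, 0.2])      # made-up positive weights
y = np.array([0.1, -0.2, 0.05])    # made-up targets

# Same value as X.T @ np.diag(w) @ X, a 5 x 5 matrix.
A = (X * w[:, None]).T @ X
rank = np.linalg.matrix_rank(A)
print(rank, "<", p)  # rank is bounded by the number of rows of X

# Weighted least squares via the sqrt-weighted system; lstsq returns the
# minimum-norm solution instead of failing when A is rank-deficient.
sw = np.sqrt(w)
phi, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
print(phi)
```

A rank check like this on the matrix being inverted is a quick way to tell whether the problem is too few distinct rows or genuinely dependent columns.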
@kaalen, could you provide an update, please? What did you do?
Issue Description
I used a Random Survival Forest with 10 estimators and a max depth of 25 on approximately 1,800 data samples. The full dataset contains approximately 200,000 samples, but I intentionally used only a very small subset when I encountered this error. When attempting to fit a ModelSurvSHAP on this very small dummy Random Survival Forest, I encounter the following error:
MemoryError: Unable to allocate 512. TiB for an array with shape (8388608, 8388608) and data type float64
I'm using survshap version 0.4.2.
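The shape in this error can be checked with simple arithmetic. 8,388,608 is exactly 2**23, and elsewhere in this thread the dataset is reported to have 23 input variables, so the shape is consistent with one row per subset of 23 features rather than 8 million features. The thread does not confirm this as the cause; the snippet below only verifies the numbers.

```python
n = 8_388_608                 # side length reported in the MemoryError
print(n == 2 ** 23)           # True: one row per subset of 23 items

bytes_needed = n * n * 8      # float64 uses 8 bytes per element
tib = bytes_needed / 2 ** 40  # 1 TiB = 2**40 bytes
print(tib)                    # 512.0, matching the error message
```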
Minimal Reproducible Code Sample
Error Trace: