imatge-upc / SurvLIMEpy

Local interpretability for survival models
GNU General Public License v3.0

Error while executing the test example #26

Closed coockie273 closed 4 months ago

coockie273 commented 4 months ago

Hey,

I'm trying to run the example described in the README file and I'm getting an error.

Code:


from survlimepy import SurvLimeExplainer
from survlimepy.load_datasets import Loader
from sksurv.linear_model import CoxPHSurvivalAnalysis

# Load the dataset
loader = Loader(dataset_name='veterans')
X, events, times = loader.load_data()

# Train a model
train, test = loader.preprocess_datasets(X, events, times)
model = CoxPHSurvivalAnalysis()
model.fit(train[0], train[1])

# Use SurvLimeExplainer class to find the feature importance
training_features = train[0]
training_events = [event for event, _ in train[1]]
training_times = [time for _, time in train[1]]

explainer = SurvLimeExplainer(
    training_features=training_features,
    training_events=training_events,
    training_times=training_times,
    model_output_times=model.event_times_,
)

# explanation variable will have the computed SurvLIME values
explanation = explainer.explain_instance(
    data_row=test[0].iloc[0],
    predict_fn=model.predict_cumulative_hazard_function,
    num_samples=1000,
)

I'm getting an error:

Traceback (most recent call last):
  File "C:\Users\fedot\PycharmProjects\Dyploma2\test.py", line 27, in <module>
    explanation = explainer.explain_instance(
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\survlimepy\survlime_explainer.py", line 172, in explain_instance
    b = opt_funcion_maker.solve_problem()
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\survlimepy\utils\optimisation.py", line 360, in solve_problem
    prob.solve(
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\problems\problem.py", line 503, in solve
    return solve_func(self, *args, **kwargs)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\problems\problem.py", line 1072, in _solve
    data, solving_chain, inverse_data = self.get_problem_data(
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\problems\problem.py", line 696, in get_problem_data
    data, inverse_data = solving_chain.apply(self, verbose)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\reductions\chain.py", line 76, in apply
    problem, inv = r.apply(problem)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\reductions\qp2quad_form\qp_matrix_stuffing.py", line 247, in apply
    params_to_P, params_to_q, flattened_variable = self.stuffed_objective(
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\reductions\qp2quad_form\qp_matrix_stuffing.py", line 232, in stuffed_objective
    params_to_P, params_to_q = extractor.quad_form(expr)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\utilities\coeff_extractor.py", line 263, in quad_form
    P = self.merge_P_list(P_list, P_height, num_params)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\utilities\coeff_extractor.py", line 303, in merge_P_list
    return combined.flatten_tensor(num_params)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\cvxpy\lin_ops\canon_backend.py", line 112, in flatten_tensor
    return sp.csc_matrix((self.data, (rows, cols)), shape=shape)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\scipy\sparse\_compressed.py", line 52, in __init__
    self._coo_container(arg1, shape=shape, dtype=dtype)
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\scipy\sparse\_coo.py", line 161, in __init__
    self._shape = check_shape((M, N))
  File "C:\Users\fedot\PycharmProjects\Dyploma2\venv\lib\site-packages\scipy\sparse\_sputils.py", line 313, in check_shape
    raise ValueError("'shape' elements cannot be negative")
ValueError: 'shape' elements cannot be negative

CarlosHernandezP commented 4 months ago

Sorry for the delayed response.

Could you share which versions of Python and of SurvLIMEpy's dependencies you are using?

If you don't know how to get those, just give me the output of this code (make sure to run it in the same virtual environment in which you are trying to use SurvLIMEpy):

import importlib
import sys

def get_version(package_name):
    try:
        package = importlib.import_module(package_name)
        version = package.__version__
        print(f"{package_name} version: {version}")
    except ImportError:
        print(f"{package_name} is not installed.")
    except AttributeError:
        print(f"Could not determine the version of {package_name}.")

packages = ["numpy", "cvxpy", "scikit-survival", "scikit-learn", "pandas"]

print(f"Python version: {sys.version}\n")

for package in packages:
    get_version(package)

Edit: I was able to reproduce your error using a clean Python 3.9.19 environment. Most likely some of the dependencies have changed and that is what is causing the error. We will take a look.

Thank you for bringing this up.

coockie273 commented 4 months ago

Thank you for your response.

I have tried to run the code with Python versions 3.9.0 and 3.10.4, and I still received the same errors.

When I ran your code, I received the following output:


Python version: 3.9.0 (tags/v3.9.0:9cf6752, Oct  5 2020, 15:34:40) [MSC v.1927 64 bit (AMD64)]

numpy version: 1.26.4
cvxpy version: 1.5.1
scikit-survival is not installed.
scikit-learn is not installed.
pandas version: 2.2.2

and


Python version: 3.10.4 (tags/v3.10.4:9d38120, Mar 23 2022, 23:13:41) [MSC v.1929 64 bit (AMD64)]

numpy version: 1.26.4
cvxpy version: 1.4.3
scikit-survival is not installed.
scikit-learn is not installed.
pandas version: 2.2.2
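
Note on the two outputs above: the "is not installed" lines for scikit-survival and scikit-learn are most likely an artifact of the version script rather than a sign of missing packages. `importlib.import_module` expects the import name (`sksurv`, `sklearn`), not the distribution name used by pip. A minimal sketch of an alternative, assuming Python 3.8 or newer, looks the versions up through `importlib.metadata`, which does accept distribution names. The helper name `get_version` is carried over from the script above; everything else is illustrative.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def get_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        # importlib.metadata works with pip distribution names such as
        # "scikit-survival", not with import names such as "sksurv".
        return version(dist_name)
    except PackageNotFoundError:
        return None


packages = ["numpy", "cvxpy", "scikit-survival", "scikit-learn", "pandas"]

print(f"Python version: {sys.version}\n")

for name in packages:
    found = get_version(name)
    if found is None:
        print(f"{name} is not installed.")
    else:
        print(f"{name} version: {found}")
```

This reads the metadata of installed distributions without importing the packages themselves, so it also works for packages that fail on import.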

CarlosHernandezP commented 4 months ago

Are you using Windows as your OS? It seems that this issue only arises on Windows systems; I was not able to replicate it on macOS or Linux.

The error does not stem from SurvLIMEpy directly; it is caused by cvxpy's inner workings. There have been multiple reports of this problem:

CVXPY issue #792
CVXPY issue #958

The problem originates in the sparse matrices that cvxpy uses. I don't remember the details, but somewhere the reshape function is called with a -1 parameter. NumPy then calls np.prod, which overflows on Windows machines when the integer product no longer fits in a 32-bit integer.
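
The overflow described above can be reproduced in isolation. The sketch below is only an illustration: the dimensions are made up and are not the shapes cvxpy actually builds, and the 32-bit dtype is forced explicitly so that the wrap-around shows on any platform, not only on Windows.

```python
import numpy as np

# Hypothetical dimensions whose product (2.5e9) exceeds the int32 maximum (2**31 - 1).
dims = np.array([50_000, 50_000], dtype=np.int32)

# With a 32-bit accumulator the product silently wraps around to a negative number.
wrapped = np.prod(dims, dtype=np.int32)

# With a 64-bit accumulator the product is exact.
exact = np.prod(dims, dtype=np.int64)

print(f"int32 product: {wrapped}")
print(f"int64 product: {exact}")
```

A negative element count like the wrapped value is the kind of input that makes the sparse matrix constructor raise "'shape' elements cannot be negative". Fewer samples keep the product smaller, which is why lowering num_samples avoids the error.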

To circumvent this problem, reduce the number of neighbours computed (num_samples). I have updated the README to use num_samples=500 rather than 1000, which works on my Windows machine.

I will be closing this issue as it is not directly related to survlimepy. Feel free to reopen it if this workaround does not work for you.