ysig / GraKeL

A scikit-learn compatible library for graph kernels
https://ysig.github.io/GraKeL/

GraphletSampling kernel: RuntimeWarning: invalid value encountered in true_divide #67

Closed: phillopski closed this issue 2 years ago

phillopski commented 3 years ago

Not sure where this invalid value comes from; it results in NaN values. The graphs were created from SMILES strings:

```python
from rdkit import Chem
import grakel

def molg_from_smi(smiles):
    mol = Chem.MolFromSmiles(smiles)
    # Node labels: atom index -> element symbol
    atom_with_idx = {i: atom.GetSymbol() for i, atom in enumerate(mol.GetAtoms())}
    # Edge labels: bond order, stored for both directions of each bond
    bond_with_idx = {(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()): bond.GetBondTypeAsDouble()
                     for bond in mol.GetBonds()}
    bond_with_idx.update({(bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()): bond.GetBondTypeAsDouble()
                          for bond in mol.GetBonds()})

    adj_m = Chem.GetAdjacencyMatrix(mol).tolist()

    return grakel.Graph(adj_m, node_labels=atom_with_idx, edge_labels=bond_with_idx, graph_format='all')
```

When I ran the graphlet kernel, I got this error message:

```
C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\grakel\kernels\graphlet_sampling.py:324: RuntimeWarning: invalid value encountered in true_divide
  return np.divide(km, np.sqrt(np.outer(self._X_diag, self._X_diag)))
C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\grakel\kernels\graphlet_sampling.py:283: RuntimeWarning: invalid value encountered in true_divide
  km /= np.sqrt(np.outer(Y_diag, X_diag))
Fitting 5 folds for each of 9 candidates, totalling 45 fits
[CV] END ................................................C=1; total time= 0.0s
C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\model_selection\_validation.py:610: FitFailedWarning: Estimator fit failed. The score on this train-test partition for these parameters will be set to nan. Details:
Traceback (most recent call last):
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\model_selection\_validation.py", line 593, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\svm\_base.py", line 169, in fit
    X, y = self._validate_data(X, y, dtype=np.float64,
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\base.py", line 433, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\utils\validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\utils\validation.py", line 814, in check_X_y
    X = check_array(X, accept_sparse=accept_sparse,
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\utils\validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\utils\validation.py", line 663, in check_array
    _assert_all_finite(array,
  File "C:\Users\philw\anaconda3\envs\ResearchProject\lib\site-packages\sklearn\utils\validation.py", line 103, in _assert_all_finite
    raise ValueError(
ValueError: Input contains NaN, infinity or a value too large for dtype('float64').
```
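Both warning lines point at the normalization step, which divides each kernel entry by the square root of the product of the two graphs' self-similarities. The minimal NumPy sketch below assumes (this is not confirmed in the thread) that at least one graph ends up with an all-zero graphlet feature vector, and therefore a zero self-kernel. It shows how that division alone turns a whole row and column into NaN. The feature matrix `F` is made-up toy data, not GraKeL output.

```python
import numpy as np

# Toy graphlet-count features: rows are graphs, columns are graphlet types.
# The third graph is assumed to have no sampled graphlets at all.
F = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 0.0, 0.0]])

km = F @ F.T          # unnormalized linear kernel on the counts
diag = np.diag(km)    # self-similarities k(x, x); the last one is 0

# Same normalization as in the warning line; 0 / 0 gives NaN.
# errstate only silences the RuntimeWarning, it does not fix the NaN.
with np.errstate(invalid="ignore"):
    km_norm = np.divide(km, np.sqrt(np.outer(diag, diag)))

print(np.isnan(km_norm))  # NaN across the whole third row and column
```

A single such graph is enough to make scikit-learn's finiteness check reject the entire precomputed kernel matrix, which matches the `ValueError` at the end of the traceback.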

ysig commented 3 years ago

Please make your example reproducible. What is the type of adj_m? What does it look like if you print it?

phillopski commented 3 years ago

This is the code I'm running in its entirety:

```python
from rdkit import Chem
from math import sqrt
import grakel
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score
import pandas as pd
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split

# Datasets
soly = pd.read_csv("AqSolDB_C.csv")
val = pd.read_csv("validation dataset.csv")

# Calculating GraKeL graphs
def molg_from_smi(smiles):
    mol = Chem.MolFromSmiles(smiles)
    atom_with_idx = {i: atom.GetSymbol() for i, atom in enumerate(mol.GetAtoms())}
    bond_with_idx = {(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()): bond.GetBondTypeAsDouble()
                     for bond in mol.GetBonds()}
    bond_with_idx.update({(bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()): bond.GetBondTypeAsDouble()
                          for bond in mol.GetBonds()})

    adj_m = Chem.GetAdjacencyMatrix(mol).tolist()

    return grakel.Graph(adj_m, node_labels=atom_with_idx, edge_labels=bond_with_idx, graph_format='all')

# SMILES strings from training, test and validation data
mols = soly.SMILES
mols_val = val.SMILES

# Creating graphs from SMILES
X = [molg_from_smi(mol) for mol in mols]
X_val = [molg_from_smi(mol) for mol in mols_val]

# Solubility measurements
Y = soly.Solubility
y_val = val.Solubility

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
gk = grakel.GraphletSampling(normalize=True)

# Calculate the kernel matrices
K_train = gk.fit_transform(X_train)
K_test = gk.transform(X_test)
K_val = gk.transform(X_val)

# Add SVR method
gsc = GridSearchCV(
    estimator=SVR(kernel='precomputed', max_iter=100000),
    param_grid={
        'C': [2**0, 2**1, 2**2, 2**3, 2**4, 2**5, 2**6, 2**7, 2**8],
    },
    cv=5, scoring='neg_root_mean_squared_error', verbose=2, n_jobs=-1)

grid_result = gsc.fit(K_train, y_train)
best_params = grid_result.best_params_
print(best_params)

best_svr = SVR(kernel='precomputed', C=best_params["C"])
best_svr.fit(K_train, y_train)

# Make predictions
y_pred_train = best_svr.predict(K_train)
y_pred_test = best_svr.predict(K_test)
y_pred_val = best_svr.predict(K_val)

# Print model performance results
print('Training set RMSE: %.2f' % sqrt(mean_squared_error(y_train, y_pred_train)))
print('Training set R2: %.2f' % r2_score(y_train, y_pred_train))

print('Test set RMSE: %.2f' % sqrt(mean_squared_error(y_test, y_pred_test)))
print('Test set R2: %.2f' % r2_score(y_test, y_pred_test))

print('Validation set RMSE: %.2f' % sqrt(mean_squared_error(y_val, y_pred_val)))
print('Validation set R2: %.2f' % r2_score(y_val, y_pred_val))
```
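Before calling `gsc.fit`, it can help to find which training graphs produce non-finite kernel values so the corresponding molecules can be inspected. Below is a minimal sketch with a hypothetical helper, `nonfinite_diagonal`. It looks at the diagonal because, with this kind of normalization, one bad graph spreads NaN into every row of the training matrix but only into its own diagonal entry. In the script it would be called on `K_train`. Here it runs on a small made-up matrix.

```python
import numpy as np

def nonfinite_diagonal(K):
    """Hypothetical helper: indices i where the normalized self-similarity
    K[i, i] is NaN or infinite. A graph whose unnormalized self-kernel is 0
    gets NaN here (0 / 0), and NaN also spreads to its row and column."""
    K = np.asarray(K, dtype=float)
    return np.flatnonzero(~np.isfinite(np.diag(K)))

# Toy normalized kernel matrix whose last graph produced NaN entries
K_toy = np.array([[1.0, 0.4, np.nan],
                  [0.4, 1.0, np.nan],
                  [np.nan, np.nan, np.nan]])

bad = nonfinite_diagonal(K_toy)
print(bad)  # [2]

# One possible workaround: drop those graphs from both axes
# (and filter y_train with the same `keep` indices)
keep = np.setdiff1d(np.arange(K_toy.shape[0]), bad)
K_clean = K_toy[np.ix_(keep, keep)]
```

Dropping graphs is only a stopgap. Whether that is acceptable depends on the study, and finding out why those molecules get a zero self-kernel is the more useful fix.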

Attached is the dataset I'm using: AqSolDB_C.csv

The adjacency matrix is symmetric, because bonds are undirected, and encodes which atoms in the molecule are bonded to each other. For example, this is the adjacency matrix of propane ('CCC'):

```
[[0 1 0]
 [1 0 1]
 [0 1 0]]
```
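To make the expected input shapes concrete without depending on RDKit, here is a minimal sketch that builds the three structures `molg_from_smi` passes to `grakel.Graph` by hand from a bond list. The helper name `graph_inputs_from_bonds` and the hard-coded propane bonds are hypothetical. They are chosen to mirror what `Chem.GetAdjacencyMatrix(mol).tolist()`, `GetSymbol()` and `GetBondTypeAsDouble()` should return for 'CCC'.

```python
def graph_inputs_from_bonds(symbols, bonds):
    """Hypothetical helper: build the adjacency matrix, node labels and
    edge labels in the same shapes molg_from_smi passes to grakel.Graph.

    symbols: list of atom symbols, indexed by atom position
    bonds: list of (begin_idx, end_idx, bond_order) tuples
    """
    n = len(symbols)
    # Dense 0/1 adjacency matrix as a list of lists (like ndarray.tolist())
    adj_m = [[0] * n for _ in range(n)]
    node_labels = {i: s for i, s in enumerate(symbols)}
    edge_labels = {}
    for begin, end, order in bonds:
        # Bonds are undirected, so fill both (i, j) and (j, i)
        adj_m[begin][end] = adj_m[end][begin] = 1
        edge_labels[(begin, end)] = order
        edge_labels[(end, begin)] = order
    return adj_m, node_labels, edge_labels


# Propane ('CCC'): three carbons joined by two single bonds (order 1.0)
adj_m, node_labels, edge_labels = graph_inputs_from_bonds(
    ["C", "C", "C"], [(0, 1, 1.0), (1, 2, 1.0)])
print(adj_m)        # [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
print(node_labels)  # {0: 'C', 1: 'C', 2: 'C'}
print(edge_labels)  # {(0, 1): 1.0, (1, 0): 1.0, (1, 2): 1.0, (2, 1): 1.0}
```

The matrix printed here is the list-of-lists form of the propane matrix shown above, so the adjacency input itself looks well formed.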

Could this be where the error stems from?
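One hypothesis worth testing, which is not confirmed in this thread: `GraphletSampling` samples graphlets of a fixed size `k`, and a molecule with fewer atoms than `k` (such as propane, with 3) may yield no graphlets at all, and hence a zero self-kernel. The sketch below flags such graphs directly from their adjacency matrices. The helper `smaller_than_k` is hypothetical, and `k=5` is an assumed default that should be checked against the `GraphletSampling` documentation for the installed GraKeL version.

```python
def smaller_than_k(adjacency_matrices, k=5):
    """Hypothetical helper: indices of graphs with fewer than k nodes.

    adjacency_matrices: iterable of square list-of-lists matrices, as
    produced by Chem.GetAdjacencyMatrix(mol).tolist()
    k: graphlet size (assumed default of 5; verify for your GraKeL version)
    """
    return [i for i, adj in enumerate(adjacency_matrices) if len(adj) < k]


propane = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]  # 'CCC', 3 atoms
pentane = [[0, 1, 0, 0, 0],
           [1, 0, 1, 0, 0],
           [0, 1, 0, 1, 0],
           [0, 0, 1, 0, 1],
           [0, 0, 0, 1, 0]]                   # 'CCCCC', 5 atoms
print(smaller_than_k([propane, pentane]))    # [0]
```

If the flagged indices line up with the graphs whose normalized self-kernel is NaN, that supports the hypothesis. Rerunning with an explicitly smaller graphlet size would be another quick check.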

ysig commented 3 years ago

Can you please provide us with a minimal reproducible example? This could also help you debug the code, or highlight a problem in this package if there is one. Thank you.