scikit-learn-contrib / py-earth

A Python implementation of Jerome Friedman's Multivariate Adaptive Regression Splines
http://contrib.scikit-learn.org/py-earth/
BSD 3-Clause "New" or "Revised" License

Is it possible to extract the remaining basis functions from the algorithm for further calculation? #209

Open kwchau opened 3 years ago

kwchau commented 3 years ago

As stated in the title, is there a way to get the remaining (unpruned) basis functions after fitting as a list for further manipulation in Python? The basis_ attribute is an internal py-earth object that is awkward to work with directly, and summary() returns a string that is hard to extract information from. It would be great to have a list that gives, for each basis function, the variables involved, whether each hinge is left- or right-sided, and the cutoff points (knots), so it can be used in further computations.
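One way to get that kind of structure without depending on py-earth internals is to parse the string form of each basis function. This is a minimal sketch, not py-earth API: the names Factor and parse_term are hypothetical, and it assumes py-earth prints a right-sided hinge max(0, x0 - 3.5) as h(x0-3.5) (or h(x0+3.5) for a negative knot), a left-sided hinge max(0, 3.5 - x0) as h(3.5-x0), interaction terms as factors joined with *, and the intercept as (Intercept). Compare against str(bf) on your own fitted model before relying on it.

```python
import re
from dataclasses import dataclass
from typing import List, Optional

# A number as printed by %G-style formatting, e.g. 3.5, -2, 1E+06
_NUM = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[Ee][-+]?\d+)?"
# ASSUMPTION: right-sided hinge max(0, var - knot) prints as "h(x0-3.5)",
# or as "h(x0+3.5)" when the knot is negative
_RIGHT = re.compile(rf"^h\((?P<var>.+?)(?P<op>[-+])(?P<knot>{_NUM})\)$")
# ASSUMPTION: left-sided hinge max(0, knot - var) prints as "h(3.5-x0)"
_LEFT = re.compile(rf"^h\((?P<knot>{_NUM})-(?P<var>.+)\)$")


@dataclass
class Factor:
    variable: str          # variable label, e.g. "x0"
    side: str              # "right": max(0, x - knot), "left": max(0, knot - x), "linear": x
    knot: Optional[float]  # cutoff point; None for linear factors


def parse_term(term: str) -> List[Factor]:
    """Split one basis-function string into its factors ([] for the intercept)."""
    term = term.strip()
    if term == "(Intercept)":
        return []
    factors = []
    for part in term.split("*"):  # interaction terms are products of factors
        m = _LEFT.match(part)
        if m:
            factors.append(Factor(m["var"], "left", float(m["knot"])))
            continue
        m = _RIGHT.match(part)
        if m:
            knot = float(m["knot"])
            if m["op"] == "+":  # "h(x0+3)" means max(0, x0 - (-3))
                knot = -knot
            factors.append(Factor(m["var"], "right", knot))
            continue
        # Anything else is treated as a plain linear term such as "x2"
        factors.append(Factor(part, "linear", None))
    return factors


print(parse_term("h(x0-3.5)*h(1.2-x1)"))
```

Applied to [str(bf) for bf in your_model.basis_ if not bf.is_pruned()], this would give one list of factors per remaining basis function, which can be paired with the coefficients in the same order.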

kevin-dietz commented 3 years ago

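@kwchau It is possible, but you have to do some manipulation of basis_. Something like the code below may get you what you are looking for. Note that basis_ still contains the pruned basis functions, while coef_ only has entries for the unpruned ones, so a separate counter is needed to line the two up. I wrote the snippet from memory, so it may not get you all the way there, but it should get you pretty close.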

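import pandas as pd

# your_model is assumed to be an already-fitted pyearth.Earth instance
varnames = []
pruned = []
coefs = []
i = 0  # index into coef_, which only covers unpruned basis functions

for bf in your_model.basis_:
    varnames.append(str(bf))
    pruned.append(bf.is_pruned())

    if not bf.is_pruned():
        # coef_ is indexed [output, term]; take the first output's coefficient
        coefs.append(your_model.coef_[0, i])
        i += 1
    else:
        # Zero-fill pruned coefs just to put something
        coefs.append(0)

# Construct DataFrame
summ = pd.DataFrame({'Variable': varnames, 'Pruned': pruned, 'Coefficient': coefs})

# Remove pruned basis functions (.copy() avoids a SettingWithCopyWarning below)
summ = summ.loc[~summ['Pruned']].copy()

# Clean up variable names by removing "h(" and ")"; replace(..., inplace=True)
# returns None, so the result is assigned back instead of chained
summ['Variable'] = summ['Variable'].str.replace(r"h\(", '', regex=True).str.replace(r"\)", '', regex=True)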
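For the "further calculation" part of the question, the extracted terms can be evaluated without py-earth: a fitted MARS model is a coefficient-weighted sum of basis functions (the intercept being the constant basis function 1), and each non-constant basis function is the product of its hinge or linear factors. The sketch below assumes a hypothetical representation in which each term is a (coefficient, factors) pair, each factor is a (column index, side, knot) tuple, and column index i corresponds to py-earth's xi label; eval_factor and predict_from_terms are illustrative names, not part of py-earth. If the terms were transcribed correctly, its output should agree with your_model.predict(X) up to floating-point error.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np

# Hypothetical representation: (column index, side, knot).
# side is "right" for max(0, x - knot), "left" for max(0, knot - x),
# or "linear" for x itself (knot is then ignored, use None).
Factor = Tuple[int, str, Optional[float]]
Term = Tuple[float, List[Factor]]  # (coefficient, factors); no factors = intercept


def eval_factor(X: np.ndarray, factor: Factor) -> np.ndarray:
    """Evaluate a single hinge or linear factor on every row of X."""
    col, side, knot = factor
    x = X[:, col]
    if side == "right":
        return np.maximum(0.0, x - knot)
    if side == "left":
        return np.maximum(0.0, knot - x)
    if side == "linear":
        return x
    raise ValueError(f"unknown side: {side!r}")


def predict_from_terms(X, terms: Sequence[Term]) -> np.ndarray:
    """Coefficient-weighted sum of basis functions, each a product of factors."""
    X = np.asarray(X, dtype=float)
    y = np.zeros(X.shape[0])
    for coef, factors in terms:
        basis = np.ones(X.shape[0])  # empty product = constant 1 (intercept)
        for factor in factors:
            basis *= eval_factor(X, factor)
        y += coef * basis
    return y


terms = [
    (2.0, []),                                      # intercept
    (0.5, [(0, "right", 1.0)]),                     # 0.5 * h(x0-1)
    (-1.5, [(0, "right", 1.0), (1, "left", 2.0)]),  # -1.5 * h(x0-1)*h(2-x1)
]
print(predict_from_terms([[3.0, 1.0], [0.0, 5.0]], terms))  # → [0. 2.]
```

Keeping the side as an explicit label rather than folding it into a signed knot keeps each factor readable and matches the left/right wording used in the question.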