cerlymarco / linear-tree

A python library to build Model Trees with Linear Models at the leaves.
MIT License
338 stars 54 forks

finding breakpoint #18

Closed ZhengLiu1119 closed 2 years ago

ZhengLiu1119 commented 2 years ago

Hello,

Thank you for your nice tool. I am using LinearTreeRegressor to fit a continuous piecewise-linear curve. It works well; I am wondering, is it possible to show the locations (the coordinates) of the breakpoints?

thank you

cerlymarco commented 2 years ago

Hi, thanks for your feedback

What do you mean by "location/coordinates"? Could you please provide a dummy example?

If you support the project, don't forget to leave a star ;-)

DanieleBaranzini commented 2 years ago

@cerlymarco hi Marco, your project is really spot on. Linear trees and derivations of this approach deserve more investigation for their clear potential ... kudos to you.

I'll try to fork and contribute more to this soon.

Daniele ergonomica

cerlymarco commented 2 years ago

Hi @DanieleBaranzini,

This would be a great pleasure!

ZhengLiu1119 commented 2 years ago

Hi, thank you for the quick answer @cerlymarco

Here is an example:

import numpy as np
from sklearn.linear_model import LinearRegression
from lineartree import LinearTreeRegressor
import matplotlib.pyplot as plt

np.random.seed(9999)

# Piecewise-linear signal with breakpoints at x = -15 and x = 10, plus noise
x = np.random.normal(0, 1, 10000) * 10
y = np.where(x < -15, -2 * x + 3, np.where(x < 10, x + 48, -4 * x + 98)) + np.random.normal(0, 8, 10000)

# Rescale and keep only the points with x < 0.98
y = y / (80 / 0.35) + 0.04
x = x / (60 / 0.08) + 0.945
pos = np.where(x < 0.98)
y = y[pos]
x = x[pos]

plt.scatter(x, y, s=5, color='b', marker='.', label='scatter plt')
plt.grid()

# %% decision tree
model = LinearTreeRegressor(base_estimator=LinearRegression())
model.fit(x.reshape(-1, 1), y)
xd = np.linspace(0.9, 0.98, 1000)
yd = model.predict(xd.reshape(-1, 1))
plt.plot(xd, yd, "r", label='decision tree')

The final curve consists of several segments, which are connected to each other at breakpoints. I would like to determine the coordinates (x, y) of each breakpoint. Is that possible?

cerlymarco commented 2 years ago

@ZhengLiu1119, given this data:

np.random.seed(9999)
x = np.random.normal(0, 1, 10000) * 10
y = np.where(x < -15, -2 * x + 3, np.where(x < 10, x + 48, -4 * x + 98)) + np.random.normal(0, 8, 10000)

y = y / (80 / 0.35) + 0.04
x = x / (60 / 0.08) + 0.945
pos = np.where(x < 0.98)
y = y[pos]
x = x[pos]

plt.scatter(x, y, s=5, color='b', marker='.', label='scatter plt')
plt.grid()

[image: scatter plot of the generated data]

on which we build our LinearTree (with max_depth=2):

model = LinearTreeRegressor(base_estimator=LinearRegression(), max_depth=2)
model.fit(x.reshape(-1, 1), y)

in my opinion, you can easily use the apply method to map each sample to its predicted leaf (the same interface as sklearn's DecisionTree) and use the leaf indices to locate the change points:

xd = np.linspace(0.9, 0.98, 1000)
yd = model.predict(xd.reshape(-1, 1))

leaves_mapping = model.apply(xd.reshape(-1, 1))
change_points = np.where(np.diff(leaves_mapping, prepend=leaves_mapping[0]) != 0)[0]
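The diff trick above can be checked on a toy leaf mapping: a change point is any index where the leaf id differs from the one before it. A minimal self-contained sketch (the leaf ids here are made up for illustration, not real model output):

```python
import numpy as np

# Hypothetical leaf ids for 8 consecutive grid points (not real model output)
leaves_mapping = np.array([3, 3, 3, 5, 5, 8, 8, 8])

# prepend the first value so the diff keeps the array length and
# index 0 can never be flagged as a change point
change_points = np.where(np.diff(leaves_mapping, prepend=leaves_mapping[0]) != 0)[0]

print(change_points)  # indices 3 and 5, where the leaf id changes
```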

The final plot results in:

plt.scatter(x, y, s=0.4, color='b', marker='.', label='scatter plt')
plt.plot(xd, yd, "r", label='decision tree', linewidth=4)

for cp in change_points:
    plt.axvline(xd[cp], c='green', linewidth=2, linestyle='--')

plt.grid()

[image: fitted piecewise curve with change points marked by dashed green lines]
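To answer the original question about the (x, y) coordinates of each breakpoint, the change-point indices can simply be read back into the prediction grid; each breakpoint is the grid point at that index, accurate up to the grid resolution. A self-contained sketch, using toy stand-ins for the xd, yd, and change_points computed above:

```python
import numpy as np

# Toy stand-ins for the prediction grid and change-point indices
# computed above (not real model output)
xd = np.linspace(0.9, 0.98, 1000)
yd = np.abs(xd - 0.94)              # a toy piecewise-linear curve
change_points = np.array([500])     # index where the slope changes

# Each breakpoint's coordinates are the grid point at that index;
# accuracy is limited by the grid step (0.08 / 999 here), so a finer
# linspace gives more precise breakpoint locations
breakpoints = [(xd[cp], yd[cp]) for cp in change_points]
print(breakpoints)
```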

HERE is the running notebook.

All the best